mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-01 17:17:24 -04:00
init update
This commit is contained in:
0
tools/__init__.py
Normal file
0
tools/__init__.py
Normal file
142
tools/general_tools.py
Normal file
142
tools/general_tools.py
Normal file
@@ -0,0 +1,142 @@
|
||||
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
def _load_runtime_env() -> dict:
|
||||
path = os.environ.get("RUNTIME_ENV_PATH")
|
||||
try:
|
||||
if os.path.exists(path):
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if isinstance(data, dict):
|
||||
return data
|
||||
except Exception:
|
||||
pass
|
||||
return {}
|
||||
|
||||
|
||||
def get_config_value(key: str, default=None):
    """Look up a configuration value, preferring the runtime env file.

    Falls back to the process environment (with *default*) when *key*
    is not present in the runtime overrides.
    """
    overrides = _load_runtime_env()
    try:
        return overrides[key]
    except KeyError:
        return os.getenv(key, default)
|
||||
|
||||
def write_config_value(key: str, value):
    """Persist ``key -> value`` into the JSON file named by RUNTIME_ENV_PATH.

    Re-reads the current runtime overrides, updates the single key, and
    rewrites the whole file.

    Args:
        key: Configuration key to set.
        value: JSON-serializable value to store.

    Raises:
        ValueError: if the RUNTIME_ENV_PATH environment variable is unset
            (previously this surfaced as an opaque TypeError from open(None)).
    """
    # NOTE: the original annotated value as `any` (the builtin), which is
    # not a type; the annotation is dropped rather than guessed.
    path = os.environ.get("RUNTIME_ENV_PATH")
    if not path:
        raise ValueError("RUNTIME_ENV_PATH environment variable is not set")
    runtime_env = _load_runtime_env()
    runtime_env[key] = value
    with open(path, "w", encoding="utf-8") as f:
        json.dump(runtime_env, f, ensure_ascii=False, indent=4)
|
||||
|
||||
def extract_conversation(conversation: dict, output_type: str):
    """Extract information from a conversation payload.

    Args:
        conversation: A mapping that includes 'messages' (list of dicts or
            objects with attributes).
        output_type: 'final' to return the model's final answer content;
            'all' to return the full messages list.

    Returns:
        For 'final': the final assistant content string if found, otherwise None.
        For 'all': the original messages list (or empty list if missing).

    Raises:
        ValueError: for any other output_type.
    """

    def _read(obj, key, default=None):
        # Uniform field access for dicts and attribute-style messages.
        return obj.get(key, default) if isinstance(obj, dict) else getattr(obj, key, default)

    def _read_path(obj, keys, default=None):
        node = obj
        for k in keys:
            node = _read(node, k, None)
            if node is None:
                return default
        return node

    messages = _read(conversation, "messages", []) or []

    if output_type == "all":
        return messages

    if output_type != "final":
        raise ValueError("output_type must be 'final' or 'all'")

    # First pass: last message that finished with 'stop' and has real content.
    for msg in reversed(messages):
        body = _read(msg, "content")
        if (_read_path(msg, ["response_metadata", "finish_reason"]) == "stop"
                and isinstance(body, str) and body.strip()):
            return body

    # Second pass: last AI-like message that neither invokes a tool nor
    # looks like a tool response, with non-empty string content.
    for msg in reversed(messages):
        body = _read(msg, "content")
        extra = _read(msg, "additional_kwargs", {}) or {}
        if isinstance(extra, dict):
            calls = extra.get("tool_calls")
        else:
            calls = getattr(extra, "tool_calls", None)

        invokes_tool = isinstance(calls, list)
        # Tool responses typically carry 'tool_call_id' or a string 'name'.
        from_tool = (_read(msg, "tool_call_id") is not None
                     or isinstance(_read(msg, "name"), str))

        if not invokes_tool and not from_tool and isinstance(body, str) and body.strip():
            return body

    return None
|
||||
|
||||
|
||||
def extract_tool_messages(conversation: dict):
    """Return all ToolMessage-like entries from the conversation.

    A message counts as a ToolMessage when it carries a truthy
    'tool_call_id', or has a string 'name' and no 'finish_reason'
    (normal AI messages carry one in their response_metadata).

    Supports both dict-based and object-based messages.
    """

    def _read(obj, key, default=None):
        return obj.get(key, default) if isinstance(obj, dict) else getattr(obj, key, default)

    def _read_path(obj, keys, default=None):
        node = obj
        for k in keys:
            node = _read(node, k, None)
            if node is None:
                return default
        return node

    msgs = _read(conversation, "messages", []) or []
    return [
        m for m in msgs
        if _read(m, "tool_call_id")
        or (isinstance(_read(m, "name"), str)
            and not _read_path(m, ["response_metadata", "finish_reason"]))
    ]
|
||||
|
||||
|
||||
def extract_first_tool_message_content(conversation: dict):
    """Return the content of the first ToolMessage if available, else None."""
    tool_msgs = extract_tool_messages(conversation)
    if not tool_msgs:
        return None

    head = tool_msgs[0]
    return head.get("content") if isinstance(head, dict) else getattr(head, "content", None)
|
||||
|
||||
370
tools/price_tools.py
Normal file
370
tools/price_tools.py
Normal file
@@ -0,0 +1,370 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
import sys
|
||||
|
||||
# 将项目根目录加入 Python 路径,便于从子目录直接运行本文件
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
from tools.general_tools import get_config_value
|
||||
|
||||
# Static ticker universe used throughout this module (profit and price
# lookups iterate over it).
# NOTE(review): labelled NASDAQ-100 but this is a point-in-time snapshot
# including dual-class listings — confirm against the intended index
# constituents before relying on completeness.
all_nasdaq_100_symbols = [
    "NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA",
    "NFLX", "PLTR", "COST", "ASML", "AMD", "CSCO", "AZN", "TMUS", "MU", "LIN",
    "PEP", "SHOP", "APP", "INTU", "AMAT", "LRCX", "PDD", "QCOM", "ARM", "INTC",
    "BKNG", "AMGN", "TXN", "ISRG", "GILD", "KLAC", "PANW", "ADBE", "HON",
    "CRWD", "CEG", "ADI", "ADP", "DASH", "CMCSA", "VRTX", "MELI", "SBUX",
    "CDNS", "ORLY", "SNPS", "MSTR", "MDLZ", "ABNB", "MRVL", "CTAS", "TRI",
    "MAR", "MNST", "CSX", "ADSK", "PYPL", "FTNT", "AEP", "WDAY", "REGN", "ROP",
    "NXPI", "DDOG", "AXON", "ROST", "IDXX", "EA", "PCAR", "FAST", "EXC", "TTWO",
    "XEL", "ZS", "PAYX", "WBD", "BKR", "CPRT", "CCEP", "FANG", "TEAM", "CHTR",
    "KDP", "MCHP", "GEHC", "VRSK", "CTSH", "CSGP", "KHC", "ODFL", "DXCM", "TTD",
    "ON", "BIIB", "LULU", "CDW", "GFS"
]
|
||||
|
||||
def get_yesterday_date(today_date: str) -> str:
    """Return the previous trading day for *today_date*.

    Steps back one calendar day, then keeps stepping back over Saturdays
    and Sundays.  NOTE: market holidays are not handled here, only weekends.

    Args:
        today_date: Date string, YYYY-MM-DD.

    Returns:
        The previous weekday as a YYYY-MM-DD string.
    """
    prev = datetime.strptime(today_date, "%Y-%m-%d") - timedelta(days=1)
    while prev.weekday() > 4:  # 5 = Saturday, 6 = Sunday
        prev -= timedelta(days=1)
    return prev.strftime("%Y-%m-%d")
|
||||
|
||||
def get_open_prices(today_date: str, symbols: List[str], merged_path: Optional[str] = None) -> Dict[str, Optional[float]]:
    """Read opening ("1. buy price") quotes for *symbols* on *today_date*.

    Args:
        today_date: Date string, YYYY-MM-DD.
        symbols: Ticker symbols to look up.
        merged_path: Optional custom path to merged.jsonl; defaults to
            <project root>/data/merged.jsonl.

    Returns:
        {f"{symbol}_price": open price or None}.  Symbols without a bar
        for the date are absent; unparsable prices map to None.
    """
    targets = set(symbols)
    lookup: Dict[str, Optional[float]] = {}

    if merged_path is None:
        source = Path(__file__).resolve().parents[1] / "data" / "merged.jsonl"
    else:
        source = Path(merged_path)

    if not source.exists():
        return lookup

    with source.open("r", encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            try:
                record = json.loads(raw)
            except Exception:
                continue
            meta = record.get("Meta Data", {}) if isinstance(record, dict) else {}
            ticker = meta.get("2. Symbol")
            if ticker not in targets:
                continue
            series = record.get("Time Series (Daily)", {})
            if not isinstance(series, dict):
                continue
            bar = series.get(today_date)
            if isinstance(bar, dict):
                raw_open = bar.get("1. buy price")
                try:
                    lookup[f'{ticker}_price'] = float(raw_open) if raw_open is not None else None
                except Exception:
                    lookup[f'{ticker}_price'] = None

    return lookup
|
||||
|
||||
def get_yesterday_open_and_close_price(today_date: str, symbols: List[str], merged_path: Optional[str] = None) -> tuple[Dict[str, Optional[float]], Dict[str, Optional[float]]]:
    """Read yesterday's buy (open) and sell (close) prices from data/merged.jsonl.

    Args:
        today_date: Date string, YYYY-MM-DD, representing today.
        symbols: Ticker symbols to look up.
        merged_path: Optional custom path to merged.jsonl; defaults to
            <project root>/data/merged.jsonl.

    Returns:
        Tuple of (buy-price dict, sell-price dict), each keyed by
        f"{symbol}_price"; values are None when no data is found.
    """
    wanted = set(symbols)
    buy_results: Dict[str, Optional[float]] = {}
    sell_results: Dict[str, Optional[float]] = {}

    if merged_path is None:
        base_dir = Path(__file__).resolve().parents[1]
        merged_file = base_dir / "data" / "merged.jsonl"
    else:
        merged_file = Path(merged_path)

    if not merged_file.exists():
        return buy_results, sell_results

    # Weekend-adjusted previous trading day.
    yesterday_date = get_yesterday_date(today_date)

    with merged_file.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                doc = json.loads(line)
            except Exception:
                continue
            meta = doc.get("Meta Data", {}) if isinstance(doc, dict) else {}
            sym = meta.get("2. Symbol")
            if sym not in wanted:
                continue
            series = doc.get("Time Series (Daily)", {})
            if not isinstance(series, dict):
                continue

            # Try the previous trading day's bar first.
            bar = series.get(yesterday_date)
            if isinstance(bar, dict):
                buy_val = bar.get("1. buy price")   # open-price field
                sell_val = bar.get("4. sell price")  # close-price field

                try:
                    buy_price = float(buy_val) if buy_val is not None else None
                    sell_price = float(sell_val) if sell_val is not None else None
                    buy_results[f'{sym}_price'] = buy_price
                    sell_results[f'{sym}_price'] = sell_price
                except Exception:
                    buy_results[f'{sym}_price'] = None
                    sell_results[f'{sym}_price'] = None
            else:
                # No bar for yesterday: scan backwards for the most recent
                # trading day with data.
                # NOTE(review): this fallback restarts from today-1 and
                # decrements BEFORE checking, so the first date probed is
                # today-2 (weekend-skipped) — it does not re-check the
                # weekend-adjusted yesterday_date above; confirm intended.
                today_dt = datetime.strptime(today_date, "%Y-%m-%d")
                yesterday_dt = today_dt - timedelta(days=1)
                current_date = yesterday_dt
                found_data = False

                # Look back at most 5 candidate trading days.
                for _ in range(5):
                    current_date -= timedelta(days=1)
                    # Skip weekends.
                    while current_date.weekday() >= 5:
                        current_date -= timedelta(days=1)

                    check_date = current_date.strftime("%Y-%m-%d")
                    bar = series.get(check_date)
                    if isinstance(bar, dict):
                        buy_val = bar.get("1. buy price")
                        sell_val = bar.get("4. sell price")

                        try:
                            buy_price = float(buy_val) if buy_val is not None else None
                            sell_price = float(sell_val) if sell_val is not None else None
                            buy_results[f'{sym}_price'] = buy_price
                            sell_results[f'{sym}_price'] = sell_price
                            found_data = True
                            break
                        except Exception:
                            continue

                if not found_data:
                    buy_results[f'{sym}_price'] = None
                    sell_results[f'{sym}_price'] = None

    return buy_results, sell_results
|
||||
|
||||
def get_yesterday_profit(today_date: str, yesterday_buy_prices: Dict[str, Optional[float]], yesterday_sell_prices: Dict[str, Optional[float]], yesterday_init_position: Dict[str, float]) -> Dict[str, float]:
    """Per-symbol profit at today's open: (yesterday close - yesterday open) * weight.

    Args:
        today_date: Date string, YYYY-MM-DD (not used in the computation itself).
        yesterday_buy_prices: {f"{symbol}_price": open price} for yesterday.
        yesterday_sell_prices: {f"{symbol}_price": close price} for yesterday.
        yesterday_init_position: {symbol: weight} held at yesterday's open.

    Returns:
        {symbol: profit rounded to 4 decimals}; 0.0 when either price is
        missing or the position weight is not strictly positive.
    """
    profits: Dict[str, float] = {}

    for ticker in all_nasdaq_100_symbols:
        key = f'{ticker}_price'
        opened = yesterday_buy_prices.get(key)
        closed = yesterday_sell_prices.get(key)
        weight = yesterday_init_position.get(ticker, 0.0)

        if opened is not None and closed is not None and weight > 0:
            profits[ticker] = round((closed - opened) * weight, 4)
        else:
            profits[ticker] = 0.0

    return profits
|
||||
|
||||
def get_today_init_position(today_date: str, modelname: str) -> Dict[str, float]:
    """Position held at today's open, i.e. the last record of the previous
    trading day in data/agent_data/{modelname}/position/position.jsonl.

    When several records share that date, the one with the largest id wins.

    Args:
        today_date: Date string, YYYY-MM-DD.
        modelname: Model name used to build the file path.

    Returns:
        {symbol: weight}; {} when the file or a matching record is missing.
    """
    root = Path(__file__).resolve().parents[1]
    position_file = root / "data" / "agent_data" / modelname / "position" / "position.jsonl"

    if not position_file.exists():
        print(f"Position file {position_file} does not exist")
        return {}

    target_date = get_yesterday_date(today_date)
    best_id = -1
    best_positions: Dict[str, float] = {}

    with position_file.open("r", encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            try:
                entry = json.loads(raw)
                if entry.get("date") == target_date:
                    entry_id = entry.get("id", 0)
                    if entry_id > best_id:
                        best_id = entry_id
                        best_positions = entry.get("positions", {})
            except Exception:
                continue

    return best_positions
|
||||
|
||||
def get_latest_position(today_date: str, modelname: str) -> tuple[Dict[str, float], int]:
    """Latest recorded position for *modelname*.

    Reads data/agent_data/{modelname}/position/position.jsonl.  Prefers the
    record with the largest id dated *today_date*; when today has no
    records, falls back to the previous trading day.

    Args:
        today_date: Date string, YYYY-MM-DD.
        modelname: Model name used to build the file path.

    Returns:
        (positions, max_id):
        - positions: {symbol: weight}; {} when nothing is found.
        - max_id: id of the selected record; -1 when nothing is found.

    Note:
        The original annotated the return as Dict[str, float] while actually
        returning a tuple (the docstring already described the tuple); the
        annotation is corrected here.
    """
    base_dir = Path(__file__).resolve().parents[1]
    position_file = base_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"

    if not position_file.exists():
        return {}, -1

    # Try today's records first.
    positions, max_id = _scan_positions_for_date(position_file, today_date)
    if max_id >= 0:
        return positions, max_id

    # No record for today: fall back to the previous trading day.
    return _scan_positions_for_date(position_file, get_yesterday_date(today_date))


def _scan_positions_for_date(position_file: Path, date: str) -> tuple[Dict[str, float], int]:
    """Scan a position.jsonl file and return (positions, id) of the
    highest-id record whose "date" equals *date*; ({}, -1) when none matches."""
    max_id = -1
    latest: Dict[str, float] = {}
    with position_file.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                doc = json.loads(line)
                if doc.get("date") == date:
                    current_id = doc.get("id", -1)
                    if current_id > max_id:
                        max_id = current_id
                        latest = doc.get("positions", {})
            except Exception:
                # Skip malformed lines; best-effort parse like the callers expect.
                continue
    return latest, max_id
|
||||
|
||||
def add_no_trade_record(today_date: str, modelname: str):
    """Append a "no_trade" record for *today_date*, carrying forward the most
    recent position from data/agent_data/{modelname}/position/position.jsonl.

    Args:
        today_date: Date string, YYYY-MM-DD.
        modelname: Model name used to build the file path.

    Returns:
        None
    """
    positions, action_id = get_latest_position(today_date, modelname)
    print(positions, action_id)

    record = {
        "date": today_date,
        "id": action_id + 1,
        "this_action": {"action": "no_trade", "symbol": "", "amount": 0},
        "positions": positions,
    }

    target = Path(__file__).resolve().parents[1] / "data" / "agent_data" / modelname / "position" / "position.jsonl"
    with target.open("a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")
    return
|
||||
|
||||
if __name__ == "__main__":
    # Smoke-test driver: reads TODAY_DATE / SIGNATURE from config, exercises
    # the price/position helpers, then appends a no-trade record for today.
    today_date = get_config_value("TODAY_DATE")
    signature = get_config_value("SIGNATURE")
    if signature is None:
        raise ValueError("SIGNATURE environment variable is not set")
    print(today_date, signature)
    yesterday_date = get_yesterday_date(today_date)
    # print(yesterday_date)
    today_buy_price = get_open_prices(today_date, all_nasdaq_100_symbols)
    # print(today_buy_price)
    yesterday_buy_prices, yesterday_sell_prices = get_yesterday_open_and_close_price(today_date, all_nasdaq_100_symbols)
    # print(yesterday_buy_prices)
    # print(yesterday_sell_prices)
    today_init_position = get_today_init_position(today_date, signature)
    # print(today_init_position)
    latest_position, latest_action_id = get_latest_position(today_date, signature)
    print(latest_position, latest_action_id)
    yesterday_profit = get_yesterday_profit(today_date, yesterday_buy_prices, yesterday_sell_prices, today_init_position)
    # print(yesterday_profit)
    # NOTE(review): running this module appends a record as a side effect.
    add_no_trade_record(today_date, signature)
|
||||
872
tools/result_tools.py
Normal file
872
tools/result_tools.py
Normal file
@@ -0,0 +1,872 @@
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import sys
|
||||
|
||||
# Add project root directory to Python path to allow running this file from subdirectories
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from tools.price_tools import (
|
||||
get_yesterday_date,
|
||||
get_open_prices,
|
||||
get_yesterday_open_and_close_price,
|
||||
get_today_init_position,
|
||||
get_latest_position,
|
||||
all_nasdaq_100_symbols
|
||||
)
|
||||
from tools.general_tools import get_config_value
|
||||
|
||||
|
||||
def calculate_portfolio_value(positions: Dict[str, float], prices: Dict[str, Optional[float]], cash: float = 0.0) -> float:
    """Total portfolio value: cash plus the sum of shares * price.

    Args:
        positions: {symbol: shares}; any "CASH" entry is skipped here
            (pass the cash balance via *cash* instead).
        prices: {f"{symbol}_price": price}.
        cash: Cash balance added to the total.

    Returns:
        Total value; holdings with a missing price or non-positive share
        count contribute nothing.
    """
    total = cash
    for ticker, shares in positions.items():
        if ticker == "CASH":
            continue
        quote = prices.get(f'{ticker}_price')
        if quote is not None and shares > 0:
            total += shares * quote
    return total
|
||||
|
||||
|
||||
def get_available_date_range(modelname: str) -> Tuple[str, str]:
    """Earliest and latest dates present in the model's position log.

    Args:
        modelname: Model name.

    Returns:
        (earliest, latest) date strings in YYYY-MM-DD format; ("", "")
        when the file is missing or holds no dated records.
    """
    log_path = Path(__file__).resolve().parents[1] / "data" / "agent_data" / modelname / "position" / "position.jsonl"

    if not log_path.exists():
        return "", ""

    seen: List[str] = []
    with log_path.open("r", encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            try:
                record = json.loads(raw)
                when = record.get("date")
                if when:
                    seen.append(when)
            except Exception:
                continue

    if not seen:
        return "", ""

    # Lexicographic min/max equals chronological order for YYYY-MM-DD keys.
    return min(seen), max(seen)
|
||||
|
||||
|
||||
def get_daily_portfolio_values(modelname: str, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, float]:
    """Get daily portfolio values.

    For each date in the position log (restricted to [start_date, end_date]),
    takes the highest-id position record, prices it with that day's closing
    ("4. sell price") quotes from data/merged.jsonl, and adds the "CASH"
    entry as the cash balance.

    Args:
        modelname: Model name.
        start_date: Start date in YYYY-MM-DD format, uses earliest date if None.
        end_date: End date in YYYY-MM-DD format, uses latest date if None.

    Returns:
        Dictionary of daily portfolio values in format {date: portfolio_value};
        {} when either data file is missing or no date range is available.
    """
    base_dir = Path(__file__).resolve().parents[1]
    position_file = base_dir / "data" / "agent_data" / modelname / "position" / "position.jsonl"
    merged_file = base_dir / "data" / "merged.jsonl"

    if not position_file.exists() or not merged_file.exists():
        return {}

    # Get available date range if not specified
    if start_date is None or end_date is None:
        earliest_date, latest_date = get_available_date_range(modelname)
        if not earliest_date or not latest_date:
            return {}

        if start_date is None:
            start_date = earliest_date
        if end_date is None:
            end_date = latest_date

    # Read position data (best-effort: malformed lines are skipped)
    position_data = []
    with position_file.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                doc = json.loads(line)
                position_data.append(doc)
            except Exception:
                continue

    # Read price data, keyed by symbol -> daily time series
    price_data = {}
    with merged_file.open("r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                doc = json.loads(line)
                meta = doc.get("Meta Data", {})
                symbol = meta.get("2. Symbol")
                if symbol:
                    price_data[symbol] = doc.get("Time Series (Daily)", {})
            except Exception:
                continue

    # Calculate daily portfolio values
    daily_values = {}

    # Group position data by date
    positions_by_date = {}
    for record in position_data:
        date = record.get("date")
        if date:
            if date not in positions_by_date:
                positions_by_date[date] = []
            positions_by_date[date].append(record)

    # For each date, sort records by id and take latest position
    for date, records in positions_by_date.items():
        if start_date and date < start_date:
            continue
        if end_date and date > end_date:
            continue

        # Sort by id and take latest position
        latest_record = max(records, key=lambda x: x.get("id", 0))
        positions = latest_record.get("positions", {})

        # Get daily prices
        daily_prices = {}
        for symbol in all_nasdaq_100_symbols:
            if symbol in price_data:
                symbol_prices = price_data[symbol]
                if date in symbol_prices:
                    price_info = symbol_prices[date]
                    # NOTE(review): buy_price is read but never used below.
                    buy_price = price_info.get("1. buy price")
                    sell_price = price_info.get("4. sell price")
                    # Use closing (sell) price to calculate value
                    if sell_price is not None:
                        daily_prices[f'{symbol}_price'] = float(sell_price)

        # Calculate portfolio value (CASH weight doubles as the cash balance)
        cash = positions.get("CASH", 0.0)
        portfolio_value = calculate_portfolio_value(positions, daily_prices, cash)
        daily_values[date] = portfolio_value

    return daily_values
|
||||
|
||||
|
||||
def calculate_daily_returns(portfolio_values: Dict[str, float]) -> List[float]:
    """Day-over-day simple returns from a {date: value} series.

    Args:
        portfolio_values: Daily portfolio values keyed by YYYY-MM-DD date.

    Returns:
        Returns in chronological order; days whose previous value is not
        positive are skipped.  Empty list when fewer than two days exist.
    """
    if len(portfolio_values) < 2:
        return []

    ordered = [portfolio_values[day] for day in sorted(portfolio_values)]
    out: List[float] = []
    for prev, curr in zip(ordered, ordered[1:]):
        if prev > 0:
            out.append((curr - prev) / prev)
    return out
|
||||
|
||||
|
||||
def calculate_sharpe_ratio(returns: List[float], risk_free_rate: float = 0.02) -> float:
    """Annualized Sharpe ratio of a daily return series.

    Args:
        returns: Daily returns.
        risk_free_rate: Annualized risk-free rate.

    Returns:
        (annualized return - risk-free rate) / annualized volatility,
        assuming 252 trading days per year; 0.0 for fewer than two
        observations or zero volatility.
    """
    if not returns or len(returns) < 2:
        return 0.0

    series = np.array(returns)

    # Annualize mean and sample standard deviation (252 trading days).
    ann_return = np.mean(series) * 252
    ann_vol = np.std(series, ddof=1) * np.sqrt(252)

    if ann_vol == 0:
        return 0.0

    return (ann_return - risk_free_rate) / ann_vol
|
||||
|
||||
|
||||
def calculate_max_drawdown(portfolio_values: Dict[str, float]) -> Tuple[float, str, str]:
    """Maximum peak-to-trough drawdown of a {date: value} series.

    Args:
        portfolio_values: Daily portfolio values keyed by date.

    Returns:
        (max drawdown as a fraction of the peak, peak date, trough date);
        (0.0, "", "") for an empty input.
    """
    if not portfolio_values:
        return 0.0, "", ""

    worst = 0.0
    start = end = ""
    running_peak = None
    running_peak_date = None

    for day in sorted(portfolio_values):
        value = portfolio_values[day]
        # Track the highest value seen so far and when it occurred.
        if running_peak is None or value > running_peak:
            running_peak = value
            running_peak_date = day

        dd = (running_peak - value) / running_peak
        if dd > worst:
            worst = dd
            start = running_peak_date
            end = day

    return worst, start, end
|
||||
|
||||
|
||||
def calculate_cumulative_return(portfolio_values: Dict[str, float]) -> float:
    """Total return over the whole period of a {date: value} series.

    Args:
        portfolio_values: Daily portfolio values keyed by date.

    Returns:
        (final - initial) / initial; 0.0 when the series is empty or the
        initial value is zero.
    """
    if not portfolio_values:
        return 0.0

    ordered = sorted(portfolio_values)
    first = portfolio_values[ordered[0]]
    last = portfolio_values[ordered[-1]]

    if first == 0:
        return 0.0
    return (last - first) / first
|
||||
|
||||
|
||||
def calculate_annualized_return(portfolio_values: Dict[str, float]) -> float:
    """Annualized (geometric) return of a {date: value} series.

    Args:
        portfolio_values: Daily portfolio values keyed by YYYY-MM-DD date.

    Returns:
        (1 + total_return) ** (365 / days) - 1 over the calendar span;
        0.0 for an empty series, a zero initial value, or a zero-day span.
    """
    if not portfolio_values:
        return 0.0

    ordered = sorted(portfolio_values)
    first_value = portfolio_values[ordered[0]]
    last_value = portfolio_values[ordered[-1]]

    if first_value == 0:
        return 0.0

    # Calendar-day span between the first and last observation.
    span_days = (datetime.strptime(ordered[-1], "%Y-%m-%d")
                 - datetime.strptime(ordered[0], "%Y-%m-%d")).days
    if span_days == 0:
        return 0.0

    growth = (last_value - first_value) / first_value
    return (1 + growth) ** (365 / span_days) - 1
|
||||
|
||||
|
||||
def calculate_volatility(returns: List[float]) -> float:
    """Annualized volatility (sample std * sqrt(252)) of daily returns.

    Args:
        returns: Daily returns.

    Returns:
        Annualized volatility; 0.0 for fewer than two observations.
    """
    if not returns or len(returns) < 2:
        return 0.0

    # Sample standard deviation, scaled to a 252-trading-day year.
    return np.std(np.array(returns), ddof=1) * np.sqrt(252)
|
||||
|
||||
|
||||
def calculate_win_rate(returns: List[float]) -> float:
    """Fraction of days with a strictly positive return.

    Args:
        returns: Daily returns.

    Returns:
        Winning-day count divided by total days; 0.0 for an empty list.
    """
    if not returns:
        return 0.0

    winners = len([r for r in returns if r > 0])
    return winners / len(returns)
|
||||
|
||||
|
||||
def calculate_profit_loss_ratio(returns: List[float]) -> float:
    """Average winning return divided by average losing magnitude.

    Args:
        returns: Daily returns.

    Returns:
        mean(gains) / |mean(losses)|; 0.0 when the input is empty, when
        either side has no observations, or when the average loss is zero.
    """
    if not returns:
        return 0.0

    gains = [r for r in returns if r > 0]
    losses = [r for r in returns if r < 0]

    if not gains or not losses:
        return 0.0

    mean_gain = np.mean(gains)
    mean_loss = abs(np.mean(losses))

    if mean_loss == 0:
        return 0.0
    return mean_gain / mean_loss
|
||||
|
||||
|
||||
def _empty_metrics(error_message: str) -> Dict[str, any]:
    """Return a zeroed metrics payload carrying *error_message*.

    Both failure paths of calculate_all_metrics previously duplicated this
    17-key dictionary verbatim; building it in one place keeps the error
    schema consistent with the success schema.
    """
    return {
        "error": error_message,
        "portfolio_values": {},
        "daily_returns": [],
        "sharpe_ratio": 0.0,
        "max_drawdown": 0.0,
        "max_drawdown_start": "",
        "max_drawdown_end": "",
        "cumulative_return": 0.0,
        "annualized_return": 0.0,
        "volatility": 0.0,
        "win_rate": 0.0,
        "profit_loss_ratio": 0.0,
        "total_trading_days": 0,
        "start_date": "",
        "end_date": ""
    }


def calculate_all_metrics(modelname: str, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, any]:
    """
    Calculate all performance metrics

    Args:
        modelname: Model name
        start_date: Start date in YYYY-MM-DD format, uses earliest date if None
        end_date: End date in YYYY-MM-DD format, uses latest date if None

    Returns:
        Dictionary containing all metrics; on failure, the same schema with
        zeroed values and an "error" key describing the problem.
    """
    # Fill in any missing range endpoint from the data actually available.
    if start_date is None or end_date is None:
        earliest_date, latest_date = get_available_date_range(modelname)
        if not earliest_date or not latest_date:
            return _empty_metrics("Unable to get available data date range")

        if start_date is None:
            start_date = earliest_date
        if end_date is None:
            end_date = latest_date

    # Fetch the daily portfolio value series for the requested window.
    portfolio_values = get_daily_portfolio_values(modelname, start_date, end_date)

    if not portfolio_values:
        return _empty_metrics("Unable to get portfolio data")

    # Calculate daily returns
    daily_returns = calculate_daily_returns(portfolio_values)

    # Calculate various metrics
    sharpe_ratio = calculate_sharpe_ratio(daily_returns)
    max_drawdown, drawdown_start, drawdown_end = calculate_max_drawdown(portfolio_values)
    cumulative_return = calculate_cumulative_return(portfolio_values)
    annualized_return = calculate_annualized_return(portfolio_values)
    volatility = calculate_volatility(daily_returns)
    win_rate = calculate_win_rate(daily_returns)
    profit_loss_ratio = calculate_profit_loss_ratio(daily_returns)

    # Report the date range actually covered by the data, which may be
    # narrower than the requested window.
    sorted_dates = sorted(portfolio_values.keys())
    start_date_actual = sorted_dates[0] if sorted_dates else ""
    end_date_actual = sorted_dates[-1] if sorted_dates else ""

    return {
        "portfolio_values": portfolio_values,
        "daily_returns": daily_returns,
        "sharpe_ratio": round(sharpe_ratio, 4),
        "max_drawdown": round(max_drawdown, 4),
        "max_drawdown_start": drawdown_start,
        "max_drawdown_end": drawdown_end,
        "cumulative_return": round(cumulative_return, 4),
        "annualized_return": round(annualized_return, 4),
        "volatility": round(volatility, 4),
        "win_rate": round(win_rate, 4),
        "profit_loss_ratio": round(profit_loss_ratio, 4),
        "total_trading_days": len(portfolio_values),
        "start_date": start_date_actual,
        "end_date": end_date_actual
    }
|
||||
|
||||
|
||||
def print_performance_report(metrics: Dict[str, any]) -> None:
    """
    Print performance report

    Args:
        metrics: Dictionary containing all metrics
    """
    divider = "=" * 60
    print(divider)
    print("Portfolio Performance Report")
    print(divider)

    # Bail out early when the payload carries an error marker.
    if "error" in metrics:
        print(f"Error: {metrics['error']}")
        return

    print(f"Analysis Period: {metrics['start_date']} to {metrics['end_date']}")
    print(f"Trading Days: {metrics['total_trading_days']}")
    print()

    print("Return Metrics:")
    print(f"  Cumulative Return: {metrics['cumulative_return']:.2%}")
    print(f"  Annualized Return: {metrics['annualized_return']:.2%}")
    print(f"  Annualized Volatility: {metrics['volatility']:.2%}")
    print()

    print("Risk Metrics:")
    print(f"  Sharpe Ratio: {metrics['sharpe_ratio']:.4f}")
    print(f"  Maximum Drawdown: {metrics['max_drawdown']:.2%}")
    # The drawdown window is printed only when both endpoints are known.
    if metrics['max_drawdown_start'] and metrics['max_drawdown_end']:
        print(f"  Drawdown Period: {metrics['max_drawdown_start']} to {metrics['max_drawdown_end']}")
    print()

    print("Trading Statistics:")
    print(f"  Win Rate: {metrics['win_rate']:.2%}")
    print(f"  Profit/Loss Ratio: {metrics['profit_loss_ratio']:.4f}")
    print()

    # Summarize the portfolio value series, if one was supplied.
    values_by_date = metrics['portfolio_values']
    if values_by_date:
        ordered = sorted(values_by_date.keys())
        opening = values_by_date[ordered[0]]
        closing = values_by_date[ordered[-1]]

        print("Portfolio Value:")
        print(f"  Initial Value: ${opening:,.2f}")
        print(f"  Final Value: ${closing:,.2f}")
        print(f"  Value Change: ${closing - opening:,.2f}")
|
||||
|
||||
|
||||
def get_next_id(filepath: Path) -> int:
    """
    Get next ID number

    Args:
        filepath: JSONL file path

    Returns:
        Next ID number
    """
    # A missing file means no records yet: IDs start at 0.
    if not filepath.exists():
        return 0

    highest = -1
    with filepath.open("r", encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            # Corrupt or non-dict lines are skipped rather than fatal.
            try:
                candidate = json.loads(raw).get("id", -1)
                if candidate > highest:
                    highest = candidate
            except Exception:
                continue

    return highest + 1
|
||||
|
||||
|
||||
def save_metrics_to_jsonl(metrics: Dict[str, any], modelname: str, output_dir: Optional[str] = None) -> str:
    """
    Incrementally save metrics to JSONL format

    Args:
        metrics: Dictionary containing all metrics
        modelname: Model name
        output_dir: Output directory, defaults to data/agent_data/{modelname}/metrics/

    Returns:
        Path to saved file
    """
    project_root = Path(__file__).resolve().parents[1]

    # Resolve the target directory, defaulting to the per-model metrics
    # folder under the project root, and make sure it exists.
    if output_dir is None:
        target_dir = project_root / "data" / "agent_data" / modelname / "metrics"
    else:
        target_dir = Path(output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # All runs append to the same fixed file.
    filepath = target_dir / "performance_metrics.jsonl"

    # Assemble the record, assigning the next sequential ID.
    record = {
        "id": get_next_id(filepath),
        "model_name": modelname,
        "analysis_period": {
            "start_date": metrics.get("start_date", ""),
            "end_date": metrics.get("end_date", ""),
            "total_trading_days": metrics.get("total_trading_days", 0)
        },
        "performance_metrics": {
            "sharpe_ratio": metrics.get("sharpe_ratio", 0.0),
            "max_drawdown": metrics.get("max_drawdown", 0.0),
            "max_drawdown_period": {
                "start_date": metrics.get("max_drawdown_start", ""),
                "end_date": metrics.get("max_drawdown_end", "")
            },
            "cumulative_return": metrics.get("cumulative_return", 0.0),
            "annualized_return": metrics.get("annualized_return", 0.0),
            "volatility": metrics.get("volatility", 0.0),
            "win_rate": metrics.get("win_rate", 0.0),
            "profit_loss_ratio": metrics.get("profit_loss_ratio", 0.0)
        },
        "portfolio_summary": {}
    }

    # Summarize start/end portfolio values when a series is present.
    portfolio_values = metrics.get("portfolio_values", {})
    if portfolio_values:
        dates = sorted(portfolio_values.keys())
        first_value = portfolio_values[dates[0]]
        last_value = portfolio_values[dates[-1]]

        record["portfolio_summary"] = {
            "initial_value": first_value,
            "final_value": last_value,
            "value_change": last_value - first_value,
            "value_change_percent": ((last_value - first_value) / first_value) if first_value > 0 else 0.0
        }

    # Append one JSON object per line (JSONL) so history accumulates.
    with filepath.open("a", encoding="utf-8") as handle:
        handle.write(json.dumps(record, ensure_ascii=False) + "\n")

    return str(filepath)
|
||||
|
||||
|
||||
def get_latest_metrics(modelname: str, output_dir: Optional[str] = None) -> Optional[Dict[str, any]]:
    """
    Get latest performance metrics record

    Args:
        modelname: Model name
        output_dir: Output directory, defaults to data/agent_data/{modelname}/metrics/

    Returns:
        Latest metrics record, or None if no records exist
    """
    project_root = Path(__file__).resolve().parents[1]

    if output_dir is None:
        metrics_dir = project_root / "data" / "agent_data" / modelname / "metrics"
    else:
        metrics_dir = Path(output_dir)

    filepath = metrics_dir / "performance_metrics.jsonl"

    if not filepath.exists():
        return None

    # "Latest" means the record with the highest ID, not the last line:
    # a single pass tracks the running maximum.
    newest = None
    newest_id = -1

    with filepath.open("r", encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            try:
                record = json.loads(raw)
                record_id = record.get("id", -1)
                if record_id > newest_id:
                    newest_id = record_id
                    newest = record
            except Exception:
                continue

    return newest
|
||||
|
||||
|
||||
def get_metrics_history(modelname: str, output_dir: Optional[str] = None, limit: Optional[int] = None) -> List[Dict[str, any]]:
    """
    Get performance metrics history

    Args:
        modelname: Model name
        output_dir: Output directory, defaults to data/agent_data/{modelname}/metrics/
        limit: Limit number of records returned, None returns all records

    Returns:
        List of metrics records, sorted by ID
    """
    project_root = Path(__file__).resolve().parents[1]

    if output_dir is None:
        metrics_dir = project_root / "data" / "agent_data" / modelname / "metrics"
    else:
        metrics_dir = Path(output_dir)

    filepath = metrics_dir / "performance_metrics.jsonl"

    if not filepath.exists():
        return []

    # Collect every parseable record; corrupt lines are skipped.
    history = []
    with filepath.open("r", encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            try:
                history.append(json.loads(raw))
            except Exception:
                continue

    # Order by ID so the newest record is last.
    history.sort(key=lambda record: record.get("id", 0))

    # A positive limit keeps only the most recent records.
    if limit is not None and limit > 0:
        history = history[-limit:]

    return history
|
||||
|
||||
|
||||
def print_metrics_summary(modelname: str, output_dir: Optional[str] = None) -> None:
    """
    Print performance metrics summary

    Args:
        modelname: Model name
        output_dir: Output directory
    """
    print(f"📊 Model '{modelname}' Performance Metrics Summary")
    print("=" * 60)

    # Get history records
    history = get_metrics_history(modelname, output_dir)

    if not history:
        print("❌ No history records found")
        return

    print(f"📈 Total Records: {len(history)}")

    # Show latest record
    latest = history[-1]
    print(f"🕒 Latest Record (ID: {latest['id']}):")
    print(f"  Analysis Period: {latest['analysis_period']['start_date']} to {latest['analysis_period']['end_date']}")
    print(f"  Trading Days: {latest['analysis_period']['total_trading_days']}")

    latest_metrics = latest['performance_metrics']
    print(f"  Sharpe Ratio: {latest_metrics['sharpe_ratio']}")
    print(f"  Maximum Drawdown: {latest_metrics['max_drawdown']:.2%}")
    print(f"  Cumulative Return: {latest_metrics['cumulative_return']:.2%}")
    print(f"  Annualized Return: {latest_metrics['annualized_return']:.2%}")

    # Show trends (if multiple records exist)
    if len(history) > 1:
        print(f"\n📊 Trend Analysis (Last {min(5, len(history))} Records):")

        recent_records = history[-5:] if len(history) >= 5 else history

        # BUG FIX: the header previously advertised a "Time" column that the
        # row format never printed; header and separator now match the rows.
        print("ID | Cum Ret | Ann Ret | Sharpe")
        print("-" * 40)

        for record in recent_records:
            # Renamed from `metrics` to avoid shadowing the latest-record
            # metrics local above.
            row_metrics = record['performance_metrics']
            print(f"{record['id']:2d} | {row_metrics['cumulative_return']:8.2%} | {row_metrics['annualized_return']:8.2%} | {row_metrics['sharpe_ratio']:8.4f}")
|
||||
|
||||
|
||||
def calculate_and_save_metrics(modelname: str, start_date: Optional[str] = None, end_date: Optional[str] = None, output_dir: Optional[str] = None, print_report: bool = True) -> Dict[str, any]:
    """
    Entry function to calculate all metrics and save in JSONL format

    Args:
        modelname: Model name (SIGNATURE)
        start_date: Start date in YYYY-MM-DD format, uses earliest date if None
        end_date: End date in YYYY-MM-DD format, uses latest date if None
        output_dir: Output directory, defaults to data/agent_data/{modelname}/metrics/
        print_report: Whether to print report

    Returns:
        Dictionary containing all metrics and saved file path
    """
    print(f"Analyzing model: {modelname}")

    # Resolve any missing range endpoint from the available data, echoing
    # the defaults that will be used.
    if start_date is None or end_date is None:
        earliest_date, latest_date = get_available_date_range(modelname)
        if earliest_date and latest_date:
            if start_date is None:
                start_date = earliest_date
                print(f"Using default start date: {start_date}")
            if end_date is None:
                end_date = latest_date
                print(f"Using default end date: {end_date}")
        else:
            print("❌ Unable to get available data date range")

    # Calculate all metrics
    metrics = calculate_all_metrics(modelname, start_date, end_date)

    if "error" in metrics:
        print(f"Error: {metrics['error']}")
        return metrics

    # Persist as JSONL; a failed save is reported in the payload rather
    # than aborting the whole analysis.
    try:
        destination = save_metrics_to_jsonl(metrics, modelname, output_dir)
        print(f"Metrics saved to: {destination}")
        metrics["saved_file"] = destination

        # Look up the ID that was assigned to the record just appended.
        newest = get_latest_metrics(modelname, output_dir)
        if newest:
            metrics["record_id"] = newest["id"]
            print(f"Record ID: {newest['id']}")
    except Exception as exc:
        print(f"Error saving file: {exc}")
        metrics["save_error"] = str(exc)

    # Print report
    if print_report:
        print_performance_report(metrics)

    return metrics
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: analyze the model named by the SIGNATURE setting.
    modelname = get_config_value("SIGNATURE")
    if modelname is None:
        # NOTE(review): `sys` is used here — confirm the file's import block
        # (not visible in this chunk) actually imports sys.
        print("错误: 未设置 SIGNATURE 环境变量")
        print("请设置环境变量 SIGNATURE,例如: export SIGNATURE=claude-3.7-sonnet")
        sys.exit(1)

    # Compute, persist, and report the metrics via the entry function.
    result = calculate_and_save_metrics(modelname)
|
||||
Reference in New Issue
Block a user