mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-21 00:57:24 -04:00
feat: add database helper methods for trading_days schema
Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
224
api/database.py
224
api/database.py
@@ -540,3 +540,227 @@ def get_database_stats(db_path: str = "data/jobs.db") -> dict:
|
|||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
"""Database wrapper class with helper methods for trading_days schema."""
|
||||||
|
|
||||||
|
def __init__(self, db_path: str = None):
|
||||||
|
"""Initialize database connection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path to SQLite database file.
|
||||||
|
If None, uses default from deployment config.
|
||||||
|
"""
|
||||||
|
if db_path is None:
|
||||||
|
from tools.deployment_config import get_database_path
|
||||||
|
db_path = get_database_path()
|
||||||
|
|
||||||
|
self.db_path = db_path
|
||||||
|
self.connection = sqlite3.connect(db_path, check_same_thread=False)
|
||||||
|
self.connection.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
def create_trading_day(
|
||||||
|
self,
|
||||||
|
job_id: str,
|
||||||
|
model: str,
|
||||||
|
date: str,
|
||||||
|
starting_cash: float,
|
||||||
|
starting_portfolio_value: float,
|
||||||
|
daily_profit: float,
|
||||||
|
daily_return_pct: float,
|
||||||
|
ending_cash: float,
|
||||||
|
ending_portfolio_value: float,
|
||||||
|
reasoning_summary: str = None,
|
||||||
|
reasoning_full: str = None,
|
||||||
|
total_actions: int = 0,
|
||||||
|
session_duration_seconds: float = None,
|
||||||
|
days_since_last_trading: int = 1
|
||||||
|
) -> int:
|
||||||
|
"""Create a new trading day record.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
trading_day_id
|
||||||
|
"""
|
||||||
|
cursor = self.connection.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO trading_days (
|
||||||
|
job_id, model, date,
|
||||||
|
starting_cash, starting_portfolio_value,
|
||||||
|
daily_profit, daily_return_pct,
|
||||||
|
ending_cash, ending_portfolio_value,
|
||||||
|
reasoning_summary, reasoning_full,
|
||||||
|
total_actions, session_duration_seconds,
|
||||||
|
days_since_last_trading,
|
||||||
|
completed_at
|
||||||
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
job_id, model, date,
|
||||||
|
starting_cash, starting_portfolio_value,
|
||||||
|
daily_profit, daily_return_pct,
|
||||||
|
ending_cash, ending_portfolio_value,
|
||||||
|
reasoning_summary, reasoning_full,
|
||||||
|
total_actions, session_duration_seconds,
|
||||||
|
days_since_last_trading
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.connection.commit()
|
||||||
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
def get_previous_trading_day(
|
||||||
|
self,
|
||||||
|
job_id: str,
|
||||||
|
model: str,
|
||||||
|
current_date: str
|
||||||
|
) -> dict:
|
||||||
|
"""Get the most recent trading day before current_date.
|
||||||
|
|
||||||
|
Handles weekends/holidays by finding actual previous trading day.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys: id, date, ending_cash, ending_portfolio_value
|
||||||
|
or None if no previous day exists
|
||||||
|
"""
|
||||||
|
cursor = self.connection.execute(
|
||||||
|
"""
|
||||||
|
SELECT id, date, ending_cash, ending_portfolio_value
|
||||||
|
FROM trading_days
|
||||||
|
WHERE job_id = ? AND model = ? AND date < ?
|
||||||
|
ORDER BY date DESC
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(job_id, model, current_date)
|
||||||
|
)
|
||||||
|
|
||||||
|
row = cursor.fetchone()
|
||||||
|
if row:
|
||||||
|
return {
|
||||||
|
"id": row[0],
|
||||||
|
"date": row[1],
|
||||||
|
"ending_cash": row[2],
|
||||||
|
"ending_portfolio_value": row[3]
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_ending_holdings(self, trading_day_id: int) -> list:
|
||||||
|
"""Get ending holdings for a trading day.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts with keys: symbol, quantity
|
||||||
|
"""
|
||||||
|
cursor = self.connection.execute(
|
||||||
|
"""
|
||||||
|
SELECT symbol, quantity
|
||||||
|
FROM holdings
|
||||||
|
WHERE trading_day_id = ?
|
||||||
|
ORDER BY symbol
|
||||||
|
""",
|
||||||
|
(trading_day_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [{"symbol": row[0], "quantity": row[1]} for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
def get_starting_holdings(self, trading_day_id: int) -> list:
|
||||||
|
"""Get starting holdings from previous day's ending holdings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts with keys: symbol, quantity
|
||||||
|
Empty list if first trading day
|
||||||
|
"""
|
||||||
|
# Get previous trading day
|
||||||
|
cursor = self.connection.execute(
|
||||||
|
"""
|
||||||
|
SELECT td_prev.id
|
||||||
|
FROM trading_days td_current
|
||||||
|
JOIN trading_days td_prev ON
|
||||||
|
td_prev.job_id = td_current.job_id AND
|
||||||
|
td_prev.model = td_current.model AND
|
||||||
|
td_prev.date < td_current.date
|
||||||
|
WHERE td_current.id = ?
|
||||||
|
ORDER BY td_prev.date DESC
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(trading_day_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
row = cursor.fetchone()
|
||||||
|
if not row:
|
||||||
|
# First trading day - no previous holdings
|
||||||
|
return []
|
||||||
|
|
||||||
|
previous_day_id = row[0]
|
||||||
|
|
||||||
|
# Get previous day's ending holdings
|
||||||
|
return self.get_ending_holdings(previous_day_id)
|
||||||
|
|
||||||
|
def create_holding(
|
||||||
|
self,
|
||||||
|
trading_day_id: int,
|
||||||
|
symbol: str,
|
||||||
|
quantity: int
|
||||||
|
) -> int:
|
||||||
|
"""Create a holding record.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
holding_id
|
||||||
|
"""
|
||||||
|
cursor = self.connection.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO holdings (trading_day_id, symbol, quantity)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
""",
|
||||||
|
(trading_day_id, symbol, quantity)
|
||||||
|
)
|
||||||
|
self.connection.commit()
|
||||||
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
def create_action(
|
||||||
|
self,
|
||||||
|
trading_day_id: int,
|
||||||
|
action_type: str,
|
||||||
|
symbol: str = None,
|
||||||
|
quantity: int = None,
|
||||||
|
price: float = None
|
||||||
|
) -> int:
|
||||||
|
"""Create an action record.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
action_id
|
||||||
|
"""
|
||||||
|
cursor = self.connection.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO actions (trading_day_id, action_type, symbol, quantity, price)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(trading_day_id, action_type, symbol, quantity, price)
|
||||||
|
)
|
||||||
|
self.connection.commit()
|
||||||
|
return cursor.lastrowid
|
||||||
|
|
||||||
|
def get_actions(self, trading_day_id: int) -> list:
|
||||||
|
"""Get all actions for a trading day.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of dicts with keys: action_type, symbol, quantity, price, created_at
|
||||||
|
"""
|
||||||
|
cursor = self.connection.execute(
|
||||||
|
"""
|
||||||
|
SELECT action_type, symbol, quantity, price, created_at
|
||||||
|
FROM actions
|
||||||
|
WHERE trading_day_id = ?
|
||||||
|
ORDER BY created_at
|
||||||
|
""",
|
||||||
|
(trading_day_id,)
|
||||||
|
)
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"action_type": row[0],
|
||||||
|
"symbol": row[1],
|
||||||
|
"quantity": row[2],
|
||||||
|
"price": row[3],
|
||||||
|
"created_at": row[4]
|
||||||
|
}
|
||||||
|
for row in cursor.fetchall()
|
||||||
|
]
|
||||||
|
|||||||
288
tests/unit/test_database_helpers.py
Normal file
288
tests/unit/test_database_helpers.py
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
import pytest
|
||||||
|
from datetime import datetime
|
||||||
|
from api.database import Database
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabaseHelpers:
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db(self, tmp_path):
|
||||||
|
"""Create test database with schema."""
|
||||||
|
import importlib
|
||||||
|
migration_module = importlib.import_module('api.migrations.001_trading_days_schema')
|
||||||
|
create_trading_days_schema = migration_module.create_trading_days_schema
|
||||||
|
|
||||||
|
db_path = tmp_path / "test.db"
|
||||||
|
db = Database(str(db_path))
|
||||||
|
|
||||||
|
# Create jobs table (prerequisite)
|
||||||
|
db.connection.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS jobs (
|
||||||
|
job_id TEXT PRIMARY KEY,
|
||||||
|
status TEXT,
|
||||||
|
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
create_trading_days_schema(db)
|
||||||
|
return db
|
||||||
|
|
||||||
|
def test_create_trading_day(self, db):
|
||||||
|
"""Test creating a new trading day record."""
|
||||||
|
# Insert job first
|
||||||
|
db.connection.execute(
|
||||||
|
"INSERT INTO jobs (job_id, status) VALUES (?, ?)",
|
||||||
|
("test-job", "running")
|
||||||
|
)
|
||||||
|
|
||||||
|
trading_day_id = db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-15",
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=9500.0,
|
||||||
|
ending_portfolio_value=9500.0
|
||||||
|
)
|
||||||
|
|
||||||
|
assert trading_day_id is not None
|
||||||
|
|
||||||
|
# Verify record created
|
||||||
|
cursor = db.connection.execute(
|
||||||
|
"SELECT * FROM trading_days WHERE id = ?",
|
||||||
|
(trading_day_id,)
|
||||||
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
|
||||||
|
def test_get_previous_trading_day(self, db):
|
||||||
|
"""Test retrieving previous trading day."""
|
||||||
|
# Setup: Create job and two trading days
|
||||||
|
db.connection.execute(
|
||||||
|
"INSERT INTO jobs (job_id, status) VALUES (?, ?)",
|
||||||
|
("test-job", "running")
|
||||||
|
)
|
||||||
|
|
||||||
|
day1_id = db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-15",
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=9500.0,
|
||||||
|
ending_portfolio_value=9500.0
|
||||||
|
)
|
||||||
|
|
||||||
|
day2_id = db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-16",
|
||||||
|
starting_cash=9500.0,
|
||||||
|
starting_portfolio_value=9500.0,
|
||||||
|
daily_profit=-500.0,
|
||||||
|
daily_return_pct=-5.0,
|
||||||
|
ending_cash=9700.0,
|
||||||
|
ending_portfolio_value=9700.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test: Get previous day from day2
|
||||||
|
previous = db.get_previous_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
current_date="2025-01-16"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert previous is not None
|
||||||
|
assert previous["date"] == "2025-01-15"
|
||||||
|
assert previous["ending_cash"] == 9500.0
|
||||||
|
|
||||||
|
def test_get_previous_trading_day_with_weekend_gap(self, db):
|
||||||
|
"""Test retrieving previous trading day across weekend."""
|
||||||
|
db.connection.execute(
|
||||||
|
"INSERT INTO jobs (job_id, status) VALUES (?, ?)",
|
||||||
|
("test-job", "running")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Friday
|
||||||
|
db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-17", # Friday
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=9500.0,
|
||||||
|
ending_portfolio_value=9500.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test: Get previous from Monday (should find Friday)
|
||||||
|
previous = db.get_previous_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
current_date="2025-01-20" # Monday
|
||||||
|
)
|
||||||
|
|
||||||
|
assert previous is not None
|
||||||
|
assert previous["date"] == "2025-01-17"
|
||||||
|
|
||||||
|
def test_get_ending_holdings(self, db):
|
||||||
|
"""Test retrieving ending holdings for a trading day."""
|
||||||
|
db.connection.execute(
|
||||||
|
"INSERT INTO jobs (job_id, status) VALUES (?, ?)",
|
||||||
|
("test-job", "running")
|
||||||
|
)
|
||||||
|
|
||||||
|
trading_day_id = db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-15",
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=9000.0,
|
||||||
|
ending_portfolio_value=10000.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add holdings
|
||||||
|
db.create_holding(trading_day_id, "AAPL", 10)
|
||||||
|
db.create_holding(trading_day_id, "MSFT", 5)
|
||||||
|
|
||||||
|
# Test
|
||||||
|
holdings = db.get_ending_holdings(trading_day_id)
|
||||||
|
|
||||||
|
assert len(holdings) == 2
|
||||||
|
assert {"symbol": "AAPL", "quantity": 10} in holdings
|
||||||
|
assert {"symbol": "MSFT", "quantity": 5} in holdings
|
||||||
|
|
||||||
|
def test_get_starting_holdings_first_day(self, db):
|
||||||
|
"""Test starting holdings for first trading day (should be empty)."""
|
||||||
|
db.connection.execute(
|
||||||
|
"INSERT INTO jobs (job_id, status) VALUES (?, ?)",
|
||||||
|
("test-job", "running")
|
||||||
|
)
|
||||||
|
|
||||||
|
trading_day_id = db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-15",
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=9500.0,
|
||||||
|
ending_portfolio_value=9500.0
|
||||||
|
)
|
||||||
|
|
||||||
|
holdings = db.get_starting_holdings(trading_day_id)
|
||||||
|
|
||||||
|
assert holdings == []
|
||||||
|
|
||||||
|
def test_get_starting_holdings_from_previous_day(self, db):
|
||||||
|
"""Test starting holdings derived from previous day's ending."""
|
||||||
|
db.connection.execute(
|
||||||
|
"INSERT INTO jobs (job_id, status) VALUES (?, ?)",
|
||||||
|
("test-job", "running")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Day 1
|
||||||
|
day1_id = db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-15",
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=9000.0,
|
||||||
|
ending_portfolio_value=10000.0
|
||||||
|
)
|
||||||
|
db.create_holding(day1_id, "AAPL", 10)
|
||||||
|
|
||||||
|
# Day 2
|
||||||
|
day2_id = db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-16",
|
||||||
|
starting_cash=9000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=8500.0,
|
||||||
|
ending_portfolio_value=9500.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test: Day 2 starting = Day 1 ending
|
||||||
|
holdings = db.get_starting_holdings(day2_id)
|
||||||
|
|
||||||
|
assert len(holdings) == 1
|
||||||
|
assert holdings[0]["symbol"] == "AAPL"
|
||||||
|
assert holdings[0]["quantity"] == 10
|
||||||
|
|
||||||
|
def test_create_action(self, db):
|
||||||
|
"""Test creating an action record."""
|
||||||
|
db.connection.execute(
|
||||||
|
"INSERT INTO jobs (job_id, status) VALUES (?, ?)",
|
||||||
|
("test-job", "running")
|
||||||
|
)
|
||||||
|
|
||||||
|
trading_day_id = db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-15",
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=9500.0,
|
||||||
|
ending_portfolio_value=9500.0
|
||||||
|
)
|
||||||
|
|
||||||
|
action_id = db.create_action(
|
||||||
|
trading_day_id=trading_day_id,
|
||||||
|
action_type="buy",
|
||||||
|
symbol="AAPL",
|
||||||
|
quantity=10,
|
||||||
|
price=100.0
|
||||||
|
)
|
||||||
|
|
||||||
|
assert action_id is not None
|
||||||
|
|
||||||
|
# Verify
|
||||||
|
cursor = db.connection.execute(
|
||||||
|
"SELECT * FROM actions WHERE id = ?",
|
||||||
|
(action_id,)
|
||||||
|
)
|
||||||
|
row = cursor.fetchone()
|
||||||
|
assert row is not None
|
||||||
|
|
||||||
|
def test_get_actions(self, db):
|
||||||
|
"""Test retrieving all actions for a trading day."""
|
||||||
|
db.connection.execute(
|
||||||
|
"INSERT INTO jobs (job_id, status) VALUES (?, ?)",
|
||||||
|
("test-job", "running")
|
||||||
|
)
|
||||||
|
|
||||||
|
trading_day_id = db.create_trading_day(
|
||||||
|
job_id="test-job",
|
||||||
|
model="gpt-4",
|
||||||
|
date="2025-01-15",
|
||||||
|
starting_cash=10000.0,
|
||||||
|
starting_portfolio_value=10000.0,
|
||||||
|
daily_profit=0.0,
|
||||||
|
daily_return_pct=0.0,
|
||||||
|
ending_cash=9500.0,
|
||||||
|
ending_portfolio_value=9500.0
|
||||||
|
)
|
||||||
|
|
||||||
|
db.create_action(trading_day_id, "buy", "AAPL", 10, 100.0)
|
||||||
|
db.create_action(trading_day_id, "sell", "MSFT", 5, 50.0)
|
||||||
|
|
||||||
|
actions = db.get_actions(trading_day_id)
|
||||||
|
|
||||||
|
assert len(actions) == 2
|
||||||
Reference in New Issue
Block a user