Mirror of https://github.com/Xe138/AI-Trader.git
Synced 2026-04-01 17:17:24 -04:00
Major architecture transformation from batch-only to API service with
database persistence for Windmill integration.
## REST API Implementation
- POST /simulate/trigger - Start simulation jobs
- GET /simulate/status/{job_id} - Monitor job progress
- GET /results - Query results with filters (job_id, date, model)
- GET /health - Service health checks
## Database Layer
- SQLite persistence with 6 tables (jobs, job_details, positions,
holdings, reasoning_logs, tool_usage)
- Foreign key constraints with cascade deletes
- Replaces JSONL file storage
## Backend Components
- JobManager: Job lifecycle management with concurrency control
- RuntimeConfigManager: Thread-safe isolated runtime configs
- ModelDayExecutor: Single model-day execution engine
- SimulationWorker: Date-sequential, model-parallel orchestration
## Testing
- 102 unit and integration tests (85% coverage)
- Database: 98% coverage
- Job manager: 98% coverage
- API endpoints: 81% coverage
- Pydantic models: 100% coverage
- TDD approach throughout
## Docker Deployment
- Dual-mode: API server (persistent) + batch (one-time)
- Health checks with 30s interval
- Volume persistence for database and logs
- Separate entrypoints for each mode
## Validation Tools
- scripts/validate_docker_build.sh - Build validation
- scripts/test_api_endpoints.sh - Complete API testing
- scripts/test_batch_mode.sh - Batch mode validation
- DOCKER_API.md - Deployment guide
- TESTING_GUIDE.md - Testing procedures
## Configuration
- API_PORT environment variable (default: 8080)
- Backwards compatible with existing configs
- FastAPI, uvicorn, pydantic>=2.0 dependencies
Co-Authored-By: AI Assistant <noreply@example.com>
502 lines
16 KiB
Python
"""
Unit tests for api/database.py module.

Coverage target: 95%+

Tests verify:
- Database connection management
- Schema initialization
- Table creation and indexes
- Foreign key constraints
- CHECK constraints on column values
- Utility functions
"""
|
|
|
|
import os
import sqlite3
import tempfile
import threading
from pathlib import Path

import pytest

from api.database import (
    get_db_connection,
    initialize_database,
    drop_all_tables,
    vacuum_database,
    get_database_stats,
)
|
|
|
|
|
|
@pytest.mark.unit
class TestDatabaseConnection:
    """Test database connection functionality.

    Each test works inside a ``tempfile.TemporaryDirectory`` and closes its
    connection in ``finally``. Temp files and connections are therefore
    released even when an assertion fails. The connection is closed before
    the directory is removed, so no file handle is held open during cleanup.
    """

    def test_get_db_connection_creates_directory(self):
        """Should create data directory if it doesn't exist."""
        with tempfile.TemporaryDirectory() as temp_dir:
            # "subdir" does not exist yet; get_db_connection must create it.
            db_path = os.path.join(temp_dir, "subdir", "test.db")

            conn = get_db_connection(db_path)
            try:
                assert conn is not None
                assert os.path.isdir(os.path.dirname(db_path))
            finally:
                conn.close()

    def test_get_db_connection_enables_foreign_keys(self):
        """Should enable foreign key constraints."""
        with tempfile.TemporaryDirectory() as temp_dir:
            conn = get_db_connection(os.path.join(temp_dir, "test.db"))
            try:
                # PRAGMA foreign_keys reports 1 when enforcement is on.
                result = conn.execute("PRAGMA foreign_keys").fetchone()[0]
                assert result == 1  # 1 = enabled
            finally:
                conn.close()

    def test_get_db_connection_row_factory(self):
        """Should set row factory for dict-like access."""
        with tempfile.TemporaryDirectory() as temp_dir:
            conn = get_db_connection(os.path.join(temp_dir, "test.db"))
            try:
                assert conn.row_factory == sqlite3.Row
            finally:
                conn.close()

    def test_get_db_connection_thread_safety(self):
        """Should allow check_same_thread=False for async compatibility.

        sqlite3 raises ``ProgrammingError`` when a connection is used from a
        thread other than the one that created it. The only exception is when
        the connection was opened with ``check_same_thread=False``. Running a
        query from a worker thread therefore verifies that flag.
        """
        with tempfile.TemporaryDirectory() as temp_dir:
            conn = get_db_connection(os.path.join(temp_dir, "test.db"))
            errors = []
            results = []

            def query_from_worker():
                # Record rather than raise: exceptions in a worker thread
                # would not fail the test on their own.
                try:
                    results.append(conn.execute("SELECT 1").fetchone()[0])
                except sqlite3.ProgrammingError as exc:
                    errors.append(exc)

            try:
                worker = threading.Thread(target=query_from_worker)
                worker.start()
                worker.join(timeout=5)

                assert not worker.is_alive(), "Worker thread did not finish"
                assert errors == [], f"Cross-thread use rejected: {errors}"
                assert results == [1]
            finally:
                conn.close()
|
|
|
|
|
|
@pytest.mark.unit
class TestSchemaInitialization:
    """Test database schema initialization."""

    # Every table initialize_database() is expected to create.
    EXPECTED_TABLES = [
        'holdings',
        'job_details',
        'jobs',
        'positions',
        'reasoning_logs',
        'tool_usage',
    ]

    # Lists user tables; SQLite's internal sqlite_* tables are excluded.
    USER_TABLES_SQL = """
        SELECT name FROM sqlite_master
        WHERE type='table' AND name NOT LIKE 'sqlite_%'
        ORDER BY name
    """

    @staticmethod
    def _fetch_rows(db_path, sql, params=()):
        """Run ``sql`` on a fresh connection to ``db_path`` and return all rows.

        The connection is closed in ``finally``, so a failing query or a
        later assertion never leaves it open.
        """
        conn = get_db_connection(db_path)
        try:
            return conn.execute(sql, params).fetchall()
        finally:
            conn.close()

    def test_initialize_database_creates_all_tables(self, clean_db):
        """Should create all 6 tables."""
        tables = [row[0] for row in self._fetch_rows(clean_db, self.USER_TABLES_SQL)]

        assert sorted(tables) == sorted(self.EXPECTED_TABLES)

    def test_initialize_database_creates_jobs_table(self, clean_db):
        """Should create jobs table with correct schema."""
        # PRAGMA table_info rows: index 1 is the column name, index 2 its type.
        columns = {
            row[1]: row[2]
            for row in self._fetch_rows(clean_db, "PRAGMA table_info(jobs)")
        }

        expected_columns = {
            'job_id': 'TEXT',
            'config_path': 'TEXT',
            'status': 'TEXT',
            'date_range': 'TEXT',
            'models': 'TEXT',
            'created_at': 'TEXT',
            'started_at': 'TEXT',
            'updated_at': 'TEXT',
            'completed_at': 'TEXT',
            'total_duration_seconds': 'REAL',
            'error': 'TEXT',
        }

        for col_name, col_type in expected_columns.items():
            assert col_name in columns, f"Missing column: jobs.{col_name}"
            assert columns[col_name] == col_type, (
                f"jobs.{col_name} has type {columns[col_name]!r}, expected {col_type!r}"
            )

    def test_initialize_database_creates_positions_table(self, clean_db):
        """Should create positions table with correct schema."""
        columns = {
            row[1]: row[2]
            for row in self._fetch_rows(clean_db, "PRAGMA table_info(positions)")
        }

        required_columns = [
            'id', 'job_id', 'date', 'model', 'action_id', 'action_type',
            'symbol', 'amount', 'price', 'cash', 'portfolio_value',
            'daily_profit', 'daily_return_pct', 'cumulative_profit',
            'cumulative_return_pct', 'created_at',
        ]

        for col_name in required_columns:
            assert col_name in columns, f"Missing column: positions.{col_name}"

    def test_initialize_database_creates_indexes(self, clean_db):
        """Should create all performance indexes."""
        indexes = [
            row[0]
            for row in self._fetch_rows(clean_db, """
                SELECT name FROM sqlite_master
                WHERE type='index' AND name LIKE 'idx_%'
                ORDER BY name
            """)
        ]

        required_indexes = [
            'idx_jobs_status',
            'idx_jobs_created_at',
            'idx_job_details_job_id',
            'idx_job_details_status',
            'idx_job_details_unique',
            'idx_positions_job_id',
            'idx_positions_date',
            'idx_positions_model',
            'idx_positions_date_model',
            'idx_positions_unique',
            'idx_holdings_position_id',
            'idx_holdings_symbol',
            'idx_reasoning_logs_job_date_model',
            'idx_tool_usage_job_date_model',
        ]

        for index in required_indexes:
            assert index in indexes, f"Missing index: {index}"

    def test_initialize_database_idempotent(self, clean_db, sample_job_data):
        """Should be safe to call multiple times.

        A second call must not fail and must leave the full table set in
        place. It must also keep rows written after the first call. SQLite
        already forbids duplicate table names, so counting ``jobs`` tables
        alone would not catch a re-init that drops and recreates tables.
        """
        # clean_db has already run initialize_database() once. Seed a row
        # so that a destructive second call would be detected below.
        conn = get_db_connection(clean_db)
        try:
            conn.execute("""
                INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
                VALUES (?, ?, ?, ?, ?, ?)
            """, (
                sample_job_data["job_id"],
                sample_job_data["config_path"],
                sample_job_data["status"],
                sample_job_data["date_range"],
                sample_job_data["models"],
                sample_job_data["created_at"],
            ))
            conn.commit()
        finally:
            conn.close()

        # Initialize again
        initialize_database(clean_db)

        tables = [row[0] for row in self._fetch_rows(clean_db, self.USER_TABLES_SQL)]
        assert sorted(tables) == sorted(self.EXPECTED_TABLES)

        job_count = self._fetch_rows(
            clean_db,
            "SELECT COUNT(*) FROM jobs WHERE job_id = ?",
            (sample_job_data["job_id"],),
        )[0][0]
        assert job_count == 1, "Re-initialization discarded existing rows"
|
|
|
|
|
|
@pytest.mark.unit
class TestForeignKeyConstraints:
    """Test foreign key constraint enforcement.

    Each test inserts a parent row and a dependent row. It confirms the
    dependent row is present, deletes the parent, and asserts the dependent
    row was removed by ON DELETE CASCADE. Checking presence first is what
    makes the final zero count meaningful.
    """

    @staticmethod
    def _insert_job(cursor, job):
        """Insert the parent ``jobs`` row described by the ``sample_job_data`` fixture."""
        cursor.execute("""
            INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """, (
            job["job_id"],
            job["config_path"],
            job["status"],
            job["date_range"],
            job["models"],
            job["created_at"],
        ))

    @staticmethod
    def _insert_position(cursor, position):
        """Insert a ``positions`` row from ``sample_position_data``; return its rowid.

        Values are bound positionally, so this relies on the fixture dict's
        insertion order matching the column list below.
        """
        cursor.execute("""
            INSERT INTO positions (
                job_id, date, model, action_id, action_type, symbol, amount, price,
                cash, portfolio_value, daily_profit, daily_return_pct,
                cumulative_profit, cumulative_return_pct, created_at
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, tuple(position.values()))
        return cursor.lastrowid

    @staticmethod
    def _count(cursor, sql, params):
        """Execute a ``SELECT COUNT(*) ...`` query and return the integer count."""
        cursor.execute(sql, params)
        return cursor.fetchone()[0]

    def test_cascade_delete_job_details(self, clean_db, sample_job_data):
        """Should cascade delete job_details when job is deleted."""
        job_id = sample_job_data["job_id"]
        count_sql = "SELECT COUNT(*) FROM job_details WHERE job_id = ?"

        conn = get_db_connection(clean_db)
        try:
            cursor = conn.cursor()
            self._insert_job(cursor, sample_job_data)
            cursor.execute("""
                INSERT INTO job_details (job_id, date, model, status)
                VALUES (?, ?, ?, ?)
            """, (job_id, "2025-01-16", "gpt-5", "pending"))
            conn.commit()

            assert self._count(cursor, count_sql, (job_id,)) == 1

            cursor.execute("DELETE FROM jobs WHERE job_id = ?", (job_id,))
            conn.commit()

            assert self._count(cursor, count_sql, (job_id,)) == 0
        finally:
            conn.close()

    def test_cascade_delete_positions(self, clean_db, sample_job_data, sample_position_data):
        """Should cascade delete positions when job is deleted."""
        job_id = sample_job_data["job_id"]
        count_sql = "SELECT COUNT(*) FROM positions WHERE job_id = ?"

        conn = get_db_connection(clean_db)
        try:
            cursor = conn.cursor()
            self._insert_job(cursor, sample_job_data)
            self._insert_position(cursor, sample_position_data)
            conn.commit()

            # Confirm the position is attached to this job before deleting it;
            # otherwise the zero count below would pass vacuously.
            assert self._count(cursor, count_sql, (job_id,)) == 1

            cursor.execute("DELETE FROM jobs WHERE job_id = ?", (job_id,))
            conn.commit()

            assert self._count(cursor, count_sql, (job_id,)) == 0
        finally:
            conn.close()

    def test_cascade_delete_holdings(self, clean_db, sample_job_data, sample_position_data):
        """Should cascade delete holdings when position is deleted."""
        count_sql = "SELECT COUNT(*) FROM holdings WHERE position_id = ?"

        conn = get_db_connection(clean_db)
        try:
            cursor = conn.cursor()
            self._insert_job(cursor, sample_job_data)
            position_id = self._insert_position(cursor, sample_position_data)
            cursor.execute("""
                INSERT INTO holdings (position_id, symbol, quantity)
                VALUES (?, ?, ?)
            """, (position_id, "AAPL", 10))
            conn.commit()

            assert self._count(cursor, count_sql, (position_id,)) == 1

            cursor.execute("DELETE FROM positions WHERE id = ?", (position_id,))
            conn.commit()

            assert self._count(cursor, count_sql, (position_id,)) == 0
        finally:
            conn.close()
|
|
|
|
|
|
@pytest.mark.unit
class TestUtilityFunctions:
    """Exercise the maintenance helpers exposed by api.database."""

    @staticmethod
    def _user_table_count(db_path):
        """Return how many non-internal (not ``sqlite_*``) tables ``db_path`` holds."""
        connection = get_db_connection(db_path)
        cur = connection.cursor()
        cur.execute("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
        total = cur.fetchone()[0]
        connection.close()
        return total

    def test_drop_all_tables(self, test_db_path):
        """drop_all_tables() should leave no user tables behind."""
        initialize_database(test_db_path)
        # The full six-table schema must be present before dropping.
        assert self._user_table_count(test_db_path) == 6

        drop_all_tables(test_db_path)
        assert self._user_table_count(test_db_path) == 0

    def test_vacuum_database(self, clean_db):
        """VACUUM should complete and leave the schema queryable."""
        vacuum_database(clean_db)  # must not raise

        connection = get_db_connection(clean_db)
        cur = connection.cursor()
        cur.execute("SELECT COUNT(*) FROM jobs")
        remaining_jobs = cur.fetchone()[0]
        assert remaining_jobs == 0
        connection.close()

    def test_get_database_stats_empty(self, clean_db):
        """A freshly initialized database should report zero rows everywhere."""
        stats = get_database_stats(clean_db)

        assert "database_size_mb" in stats
        for table_name in ("jobs", "job_details", "positions", "holdings", "reasoning_logs", "tool_usage"):
            assert stats[table_name] == 0, table_name

    def test_get_database_stats_with_data(self, clean_db, sample_job_data):
        """Row counts should reflect one job with one detail row."""
        job_id = sample_job_data["job_id"]
        job_row = tuple(
            sample_job_data[key]
            for key in ("job_id", "config_path", "status", "date_range", "models", "created_at")
        )

        connection = get_db_connection(clean_db)
        cur = connection.cursor()
        cur.execute("""
            INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """, job_row)
        cur.execute("""
            INSERT INTO job_details (job_id, date, model, status)
            VALUES (?, ?, ?, ?)
        """, (job_id, "2025-01-16", "gpt-5", "pending"))
        connection.commit()
        connection.close()

        stats = get_database_stats(clean_db)

        assert stats["jobs"] == 1
        assert stats["job_details"] == 1
        assert stats["database_size_mb"] > 0
|
|
|
|
|
|
@pytest.mark.unit
class TestCheckConstraints:
    """Test CHECK constraints on table columns.

    Each test opens a connection and attempts one insert that violates a
    CHECK constraint. It asserts SQLite rejects the insert with an
    ``IntegrityError``. The connection is closed in ``finally``, so it is
    released even when the expected error is not raised.
    """

    # Column order of the parent jobs row. Values are pulled from
    # sample_job_data by key so the insert does not depend on the fixture's
    # dict ordering or on it having exactly these keys.
    JOB_COLUMNS = ("job_id", "config_path", "status", "date_range", "models", "created_at")

    @classmethod
    def _insert_job(cls, cursor, job):
        """Insert a valid parent ``jobs`` row built from the ``sample_job_data`` fixture."""
        cursor.execute("""
            INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
            VALUES (?, ?, ?, ?, ?, ?)
        """, tuple(job[column] for column in cls.JOB_COLUMNS))

    def test_jobs_status_constraint(self, clean_db):
        """Should reject invalid job status values."""
        conn = get_db_connection(clean_db)
        try:
            cursor = conn.cursor()
            with pytest.raises(sqlite3.IntegrityError, match="CHECK constraint failed"):
                cursor.execute("""
                    INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at)
                    VALUES (?, ?, ?, ?, ?, ?)
                """, ("test-job", "configs/test.json", "invalid_status", "[]", "[]", "2025-01-20T00:00:00Z"))
        finally:
            conn.close()

    def test_job_details_status_constraint(self, clean_db, sample_job_data):
        """Should reject invalid job_detail status values."""
        conn = get_db_connection(clean_db)
        try:
            cursor = conn.cursor()
            # A valid parent job is required so only the CHECK constraint can fail.
            self._insert_job(cursor, sample_job_data)

            with pytest.raises(sqlite3.IntegrityError, match="CHECK constraint failed"):
                cursor.execute("""
                    INSERT INTO job_details (job_id, date, model, status)
                    VALUES (?, ?, ?, ?)
                """, (sample_job_data["job_id"], "2025-01-16", "gpt-5", "invalid_status"))
        finally:
            conn.close()

    def test_positions_action_type_constraint(self, clean_db, sample_job_data):
        """Should reject invalid action_type values."""
        conn = get_db_connection(clean_db)
        try:
            cursor = conn.cursor()
            # A valid parent job is required so only the CHECK constraint can fail.
            self._insert_job(cursor, sample_job_data)

            with pytest.raises(sqlite3.IntegrityError, match="CHECK constraint failed"):
                cursor.execute("""
                    INSERT INTO positions (
                        job_id, date, model, action_id, action_type, cash, portfolio_value, created_at
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)
                """, (sample_job_data["job_id"], "2025-01-16", "gpt-5", 1, "invalid_action", 10000, 10000, "2025-01-16T00:00:00Z"))
        finally:
            conn.close()
|
|
|
|
|
|
# Coverage target: 95%+ for api/database.py
|