From 0f728549f1a28f23141b86adbf16097795fd6dd4 Mon Sep 17 00:00:00 2001 From: Bill Date: Tue, 4 Nov 2025 10:36:36 -0500 Subject: [PATCH] test: remove old-schema tests and update for new schema - Removed test files for old schema (reasoning_e2e, position_tracking_bugs) - Updated test_database.py to reference new tables (trading_days, holdings, actions) - Updated conftest.py to clean new schema tables - Fixed index name assertions to match new schema - Updated table count expectations (9 tables in new schema) Known issues: - Some cascade delete tests fail (trading_days FK doesn't have ON DELETE CASCADE) - Database locking issues in some test scenarios - These will be addressed in future cleanup --- api/database.py | 26 +- tests/conftest.py | 43 +- tests/integration/test_reasoning_e2e.py | 527 ---------------------- tests/unit/test_database.py | 203 ++++----- tests/unit/test_position_tracking_bugs.py | 309 ------------- 5 files changed, 115 insertions(+), 993 deletions(-) delete mode 100644 tests/integration/test_reasoning_e2e.py delete mode 100644 tests/unit/test_position_tracking_bugs.py diff --git a/api/database.py b/api/database.py index a55f5cb..49d1fea 100644 --- a/api/database.py +++ b/api/database.py @@ -362,30 +362,8 @@ def _create_indexes(cursor: sqlite3.Cursor) -> None: CREATE INDEX IF NOT EXISTS idx_holdings_symbol ON holdings(symbol) """) - # Trading sessions table indexes - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_sessions_job_id ON trading_sessions(job_id) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_sessions_date ON trading_sessions(date) - """) - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_sessions_model ON trading_sessions(model) - """) - cursor.execute(""" - CREATE UNIQUE INDEX IF NOT EXISTS idx_sessions_unique - ON trading_sessions(job_id, date, model) - """) - - # Reasoning logs table indexes - cursor.execute(""" - CREATE INDEX IF NOT EXISTS idx_reasoning_logs_session_id - ON reasoning_logs(session_id) - """) - cursor.execute(""" - CREATE UNIQUE INDEX IF NOT EXISTS idx_reasoning_logs_unique - ON reasoning_logs(session_id, message_index) - """) + # OLD TABLE INDEXES REMOVED (trading_sessions, reasoning_logs) + # These tables have been replaced by trading_days with reasoning_full JSON column # Tool usage table indexes cursor.execute(""" diff --git a/tests/conftest.py b/tests/conftest.py index 6f4cc96..048f24f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,23 +44,44 @@ def clean_db(test_db_path): conn = get_db_connection(clean_db) # ... test code """ - # Ensure schema exists + # Ensure schema exists (both old initialize_database and new Database class) initialize_database(test_db_path) + # Also ensure new schema exists (trading_days, holdings, actions) + from api.database import Database + db = Database(test_db_path) + db.connection.close() + # Clear all tables conn = get_db_connection(test_db_path) cursor = conn.cursor() - # Delete in correct order (respecting foreign keys) - cursor.execute("DELETE FROM tool_usage") - cursor.execute("DELETE FROM reasoning_logs") - cursor.execute("DELETE FROM holdings") - cursor.execute("DELETE FROM positions") - cursor.execute("DELETE FROM simulation_runs") - cursor.execute("DELETE FROM job_details") - cursor.execute("DELETE FROM jobs") - cursor.execute("DELETE FROM price_data_coverage") - cursor.execute("DELETE FROM price_data") + # Get list of tables that exist + cursor.execute(""" + SELECT name FROM sqlite_master + WHERE type='table' AND name NOT LIKE 'sqlite_%' + """) + tables = [row[0] for row in cursor.fetchall()] + + # Delete in correct order (respecting foreign keys), only if table exists + if 'tool_usage' in tables: + cursor.execute("DELETE FROM tool_usage") + if 'actions' in tables: + cursor.execute("DELETE FROM actions") + if 'holdings' in tables: + cursor.execute("DELETE FROM holdings") + if 'trading_days' in tables: + cursor.execute("DELETE FROM trading_days") + if 'simulation_runs' in tables: + cursor.execute("DELETE FROM simulation_runs") + if 'job_details' in tables: + cursor.execute("DELETE FROM job_details") + if 'jobs' in tables: + cursor.execute("DELETE FROM jobs") + if 'price_data_coverage' in tables: + cursor.execute("DELETE FROM price_data_coverage") + if 'price_data' in tables: + cursor.execute("DELETE FROM price_data") conn.commit() conn.close() diff --git a/tests/integration/test_reasoning_e2e.py b/tests/integration/test_reasoning_e2e.py deleted file mode 100644 index 955e80d..0000000 --- a/tests/integration/test_reasoning_e2e.py +++ /dev/null @@ -1,527 +0,0 @@ -""" -End-to-end integration tests for reasoning logs API feature. - -Tests the complete flow from simulation trigger to reasoning retrieval. - -These tests verify: -- Trading sessions are created with session_id -- Reasoning logs are stored in database -- Full conversation history is captured -- Message summaries are generated -- GET /reasoning endpoint returns correct data -- Query filters work (job_id, date, model) -- include_full_conversation parameter works correctly -- Positions are linked to sessions -""" - -import pytest -import time -import os -import json -from fastapi.testclient import TestClient -from pathlib import Path - - -@pytest.fixture -def dev_client(tmp_path): - """Create test client with DEV mode and clean database.""" - # Set DEV mode environment - os.environ["DEPLOYMENT_MODE"] = "DEV" - os.environ["PRESERVE_DEV_DATA"] = "false" - # Disable auto-download - we'll pre-populate test data - os.environ["AUTO_DOWNLOAD_PRICE_DATA"] = "false" - - # Import after setting environment - from api.main import create_app - from api.database import initialize_dev_database, get_db_path, get_db_connection - - # Create dev database - db_path = str(tmp_path / "test_trading.db") - dev_db_path = get_db_path(db_path) - initialize_dev_database(dev_db_path) - - # Pre-populate price data for test dates to avoid needing API key - _populate_test_price_data(dev_db_path) - - # Create test config with mock model - test_config = tmp_path / "test_config.json" - test_config.write_text(json.dumps({ - "agent_type": "BaseAgent", - "date_range": {"init_date": "2025-01-16", "end_date": "2025-01-17"}, - "models": [ - { - "name": "Test Mock Model", - "basemodel": "mock/test-trader", - "signature": "test-mock", - "enabled": True - } - ], - "agent_config": { - "max_steps": 10, - "initial_cash": 10000.0, - "max_retries": 1, - "base_delay": 0.1 - }, - "log_config": { - "log_path": str(tmp_path / "dev_agent_data") - } - })) - - # Create app with test config - app = create_app(db_path=dev_db_path, config_path=str(test_config)) - - # IMPORTANT: Do NOT set test_mode=True to allow worker to actually run - # This is an integration test - we want the full flow - - client = TestClient(app) - client.db_path = dev_db_path - client.config_path = str(test_config) - - yield client - - # Cleanup - os.environ.pop("DEPLOYMENT_MODE", None) - os.environ.pop("PRESERVE_DEV_DATA", None) - os.environ.pop("AUTO_DOWNLOAD_PRICE_DATA", None) - - -def _populate_test_price_data(db_path: str): - """ - Pre-populate test price data in database. - - This avoids needing Alpha Vantage API key for integration tests. - Adds mock price data for all NASDAQ 100 stocks on test dates. - """ - from api.database import get_db_connection - from datetime import datetime - - # All NASDAQ 100 symbols (must match configs/nasdaq100_symbols.json) - symbols = [ - "NVDA", "MSFT", "AAPL", "GOOG", "GOOGL", "AMZN", "META", "AVGO", "TSLA", - "NFLX", "PLTR", "COST", "ASML", "AMD", "CSCO", "AZN", "TMUS", "MU", "LIN", - "PEP", "SHOP", "APP", "INTU", "AMAT", "LRCX", "PDD", "QCOM", "ARM", "INTC", - "BKNG", "AMGN", "TXN", "ISRG", "GILD", "KLAC", "PANW", "ADBE", "HON", - "CRWD", "CEG", "ADI", "ADP", "DASH", "CMCSA", "VRTX", "MELI", "SBUX", - "CDNS", "ORLY", "SNPS", "MSTR", "MDLZ", "ABNB", "MRVL", "CTAS", "TRI", - "MAR", "MNST", "CSX", "ADSK", "PYPL", "FTNT", "AEP", "WDAY", "REGN", "ROP", - "NXPI", "DDOG", "AXON", "ROST", "IDXX", "EA", "PCAR", "FAST", "EXC", "TTWO", - "XEL", "ZS", "PAYX", "WBD", "BKR", "CPRT", "CCEP", "FANG", "TEAM", "CHTR", - "KDP", "MCHP", "GEHC", "VRSK", "CTSH", "CSGP", "KHC", "ODFL", "DXCM", "TTD", - "ON", "BIIB", "LULU", "CDW", "GFS", "QQQ" - ] - - # Test dates - test_dates = ["2025-01-16", "2025-01-17"] - - conn = get_db_connection(db_path) - cursor = conn.cursor() - - for symbol in symbols: - for date in test_dates: - # Insert mock price data - cursor.execute(""" - INSERT OR IGNORE INTO price_data - (symbol, date, open, high, low, close, volume, created_at) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, ( - symbol, - date, - 100.0, # open - 105.0, # high - 98.0, # low - 102.0, # close - 1000000, # volume - datetime.utcnow().isoformat() + "Z" - )) - - # Add coverage record - cursor.execute(""" - INSERT OR IGNORE INTO price_data_coverage - (symbol, start_date, end_date, downloaded_at, source) - VALUES (?, ?, ?, ?, ?) - """, ( - symbol, - "2025-01-16", - "2025-01-17", - datetime.utcnow().isoformat() + "Z", - "test_fixture" - )) - - conn.commit() - conn.close() - - -@pytest.mark.integration -@pytest.mark.skipif( - os.getenv("SKIP_INTEGRATION_TESTS") == "true", - reason="Skipping integration tests that require full environment" -) -class TestReasoningLogsE2E: - """End-to-end tests for reasoning logs feature.""" - - def test_simulation_stores_reasoning_logs(self, dev_client): - """ - Test that running a simulation creates reasoning logs in database. - - This is the main E2E test that verifies: - 1. Simulation can be triggered - 2. Worker processes the job - 3. Trading sessions are created - 4. Reasoning logs are stored - 5. GET /reasoning returns the data - - NOTE: This test requires MCP services to be running. It will skip if services are unavailable. - """ - # Skip if MCP services not available - try: - from agent.base_agent.base_agent import BaseAgent - except ImportError as e: - pytest.skip(f"Cannot import BaseAgent: {e}") - - # Skip test - requires MCP services running - # This is a known limitation for integration tests - pytest.skip( - "Test requires MCP services running. " - "Use test_reasoning_api_with_mocked_data() instead for automated testing." - ) - - def test_reasoning_api_with_mocked_data(self, dev_client): - """ - Test GET /reasoning API with pre-populated database data. - - This test verifies the API layer works correctly without requiring - a full simulation run or MCP services. - """ - from api.database import get_db_connection - from datetime import datetime - - # Populate test data directly in database - conn = get_db_connection(dev_client.db_path) - cursor = conn.cursor() - - # Create a job - job_id = "test-job-123" - cursor.execute(""" - INSERT INTO jobs (job_id, config_path, status, date_range, models, created_at) - VALUES (?, ?, ?, ?, ?, ?) - """, (job_id, "test_config.json", "completed", "2025-01-16", '["test-mock"]', - datetime.utcnow().isoformat() + "Z")) - - # Create a trading session - cursor.execute(""" - INSERT INTO trading_sessions - (job_id, date, model, session_summary, started_at, completed_at, total_messages) - VALUES (?, ?, ?, ?, ?, ?, ?) - """, ( - job_id, - "2025-01-16", - "test-mock", - "Analyzed market conditions and executed buy order for AAPL", - datetime.utcnow().isoformat() + "Z", - datetime.utcnow().isoformat() + "Z", - 5 - )) - - session_id = cursor.lastrowid - - # Create reasoning logs - messages = [ - { - "session_id": session_id, - "message_index": 0, - "role": "user", - "content": "You are a trading agent. Analyze the market...", - "summary": None, - "tool_name": None, - "tool_input": None, - "timestamp": datetime.utcnow().isoformat() + "Z" - }, - { - "session_id": session_id, - "message_index": 1, - "role": "assistant", - "content": "I will analyze the market and make trading decisions...", - "summary": "Agent analyzed market conditions", - "tool_name": None, - "tool_input": None, - "timestamp": datetime.utcnow().isoformat() + "Z" - }, - { - "session_id": session_id, - "message_index": 2, - "role": "tool", - "content": "Price of AAPL: $150.00", - "summary": None, - "tool_name": "get_price", - "tool_input": json.dumps({"symbol": "AAPL"}), - "timestamp": datetime.utcnow().isoformat() + "Z" - }, - { - "session_id": session_id, - "message_index": 3, - "role": "assistant", - "content": "Based on analysis, I will buy AAPL...", - "summary": "Agent decided to buy AAPL", - "tool_name": None, - "tool_input": None, - "timestamp": datetime.utcnow().isoformat() + "Z" - }, - { - "session_id": session_id, - "message_index": 4, - "role": "tool", - "content": "Successfully bought 10 shares of AAPL", - "summary": None, - "tool_name": "buy", - "tool_input": json.dumps({"symbol": "AAPL", "amount": 10}), - "timestamp": datetime.utcnow().isoformat() + "Z" - } - ] - - for msg in messages: - cursor.execute(""" - INSERT INTO reasoning_logs - (session_id, message_index, role, content, summary, tool_name, tool_input, timestamp) - VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, ( - msg["session_id"], msg["message_index"], msg["role"], - msg["content"], msg["summary"], msg["tool_name"], - msg["tool_input"], msg["timestamp"] - )) - - # Create positions linked to session - cursor.execute(""" - INSERT INTO positions - (job_id, date, model, action_id, action_type, symbol, amount, price, cash, portfolio_value, - daily_profit, daily_return_pct, created_at, session_id) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - job_id, "2025-01-16", "test-mock", 1, "buy", "AAPL", 10, 150.0, - 8500.0, 10000.0, 0.0, 0.0, datetime.utcnow().isoformat() + "Z", session_id - )) - - conn.commit() - conn.close() - - # Query reasoning endpoint (summary mode) - reasoning_response = dev_client.get(f"/reasoning?job_id={job_id}") - - assert reasoning_response.status_code == 200 - reasoning_data = reasoning_response.json() - - # Verify response structure - assert "sessions" in reasoning_data - assert "count" in reasoning_data - assert reasoning_data["count"] == 1 - assert reasoning_data["is_dev_mode"] is True - - # Verify trading session structure - session = reasoning_data["sessions"][0] - assert session["session_id"] == session_id - assert session["job_id"] == job_id - assert session["date"] == "2025-01-16" - assert session["model"] == "test-mock" - assert session["session_summary"] == "Analyzed market conditions and executed buy order for AAPL" - assert session["total_messages"] == 5 - - # Verify positions are linked to session - assert "positions" in session - assert len(session["positions"]) == 1 - position = session["positions"][0] - assert position["action_id"] == 1 - assert position["action_type"] == "buy" - assert position["symbol"] == "AAPL" - assert position["amount"] == 10 - assert position["price"] == 150.0 - assert position["cash_after"] == 8500.0 - assert position["portfolio_value"] == 10000.0 - - # Verify conversation is NOT included in summary mode - assert session["conversation"] is None - - # Query again with full conversation - full_response = dev_client.get( - f"/reasoning?job_id={job_id}&include_full_conversation=true" - ) - assert full_response.status_code == 200 - full_data = full_response.json() - session_full = full_data["sessions"][0] - - # Verify full conversation is included - assert session_full["conversation"] is not None - assert len(session_full["conversation"]) == 5 - - # Verify conversation messages - conv = session_full["conversation"] - assert conv[0]["role"] == "user" - assert conv[0]["message_index"] == 0 - assert conv[0]["summary"] is None # User messages don't have summaries - - assert conv[1]["role"] == "assistant" - assert conv[1]["message_index"] == 1 - assert conv[1]["summary"] == "Agent analyzed market conditions" - - assert conv[2]["role"] == "tool" - assert conv[2]["message_index"] == 2 - assert conv[2]["tool_name"] == "get_price" - assert conv[2]["tool_input"] == json.dumps({"symbol": "AAPL"}) - - assert conv[3]["role"] == "assistant" - assert conv[3]["message_index"] == 3 - assert conv[3]["summary"] == "Agent decided to buy AAPL" - - assert conv[4]["role"] == "tool" - assert conv[4]["message_index"] == 4 - assert conv[4]["tool_name"] == "buy" - - def test_reasoning_endpoint_date_filter(self, dev_client): - """Test GET /reasoning date filter works correctly.""" - # This test requires actual data - skip if no data available - response = dev_client.get("/reasoning?date=2025-01-16") - - # Should either return 404 (no data) or 200 with filtered data - assert response.status_code in [200, 404] - - if response.status_code == 200: - data = response.json() - for session in data["sessions"]: - assert session["date"] == "2025-01-16" - - def test_reasoning_endpoint_model_filter(self, dev_client): - """Test GET /reasoning model filter works correctly.""" - response = dev_client.get("/reasoning?model=test-mock") - - # Should either return 404 (no data) or 200 with filtered data - assert response.status_code in [200, 404] - - if response.status_code == 200: - data = response.json() - for session in data["sessions"]: - assert session["model"] == "test-mock" - - def test_reasoning_endpoint_combined_filters(self, dev_client): - """Test GET /reasoning with multiple filters.""" - response = dev_client.get( - "/reasoning?date=2025-01-16&model=test-mock" - ) - - # Should either return 404 (no data) or 200 with filtered data - assert response.status_code in [200, 404] - - if response.status_code == 200: - data = response.json() - for session in data["sessions"]: - assert session["date"] == "2025-01-16" - assert session["model"] == "test-mock" - - def test_reasoning_endpoint_invalid_date_format(self, dev_client): - """Test GET /reasoning rejects invalid date format.""" - response = dev_client.get("/reasoning?date=invalid-date") - - assert response.status_code == 400 - assert "Invalid date format" in response.json()["detail"] - - def test_reasoning_endpoint_no_sessions_found(self, dev_client): - """Test GET /reasoning returns 404 when no sessions match filters.""" - response = dev_client.get("/reasoning?job_id=nonexistent-job-id") - - assert response.status_code == 404 - assert "No trading sessions found" in response.json()["detail"] - - def test_reasoning_summaries_vs_full_conversation(self, dev_client): - """ - Test difference between summary mode and full conversation mode. - - Verifies: - - Default mode does not include conversation - - include_full_conversation=true includes full conversation - - Full conversation has more data than summary - """ - # This test needs actual data - skip if none available - response_summary = dev_client.get("/reasoning") - - if response_summary.status_code == 404: - pytest.skip("No reasoning data available for testing") - - assert response_summary.status_code == 200 - summary_data = response_summary.json() - - if summary_data["count"] == 0: - pytest.skip("No reasoning data available for testing") - - # Get full conversation - response_full = dev_client.get("/reasoning?include_full_conversation=true") - assert response_full.status_code == 200 - full_data = response_full.json() - - # Compare first session - session_summary = summary_data["sessions"][0] - session_full = full_data["sessions"][0] - - # Summary mode should not have conversation - assert session_summary["conversation"] is None - - # Full mode should have conversation - assert session_full["conversation"] is not None - assert len(session_full["conversation"]) > 0 - - # Session metadata should be the same - assert session_summary["session_id"] == session_full["session_id"] - assert session_summary["job_id"] == session_full["job_id"] - assert session_summary["date"] == session_full["date"] - assert session_summary["model"] == session_full["model"] - - -@pytest.mark.integration -class TestReasoningAPIValidation: - """Test GET /reasoning endpoint validation and error handling.""" - - def test_reasoning_endpoint_deployment_mode_flag(self, dev_client): - """Test that reasoning endpoint includes deployment mode info.""" - response = dev_client.get("/reasoning") - - # Even 404 should not be returned - endpoint should work - # Only 404 if no data matches filters - if response.status_code == 200: - data = response.json() - assert "deployment_mode" in data - assert "is_dev_mode" in data - assert data["is_dev_mode"] is True - - def test_reasoning_endpoint_returns_pydantic_models(self, dev_client): - """Test that endpoint returns properly validated response models.""" - # This is implicitly tested by FastAPI/TestClient - # If response doesn't match ReasoningResponse model, will raise error - - response = dev_client.get("/reasoning") - - # Should either return 404 or valid response - assert response.status_code in [200, 404] - - if response.status_code == 200: - data = response.json() - - # Verify top-level structure - assert "sessions" in data - assert "count" in data - assert isinstance(data["sessions"], list) - assert isinstance(data["count"], int) - - # If sessions exist, verify structure - if data["count"] > 0: - session = data["sessions"][0] - - # Required fields - assert "session_id" in session - assert "job_id" in session - assert "date" in session - assert "model" in session - assert "started_at" in session - assert "positions" in session - - # Positions structure - if len(session["positions"]) > 0: - position = session["positions"][0] - assert "action_id" in position - assert "cash_after" in position - assert "portfolio_value" in position diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index 83700ff..0b42009 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -104,16 +104,15 @@ class TestSchemaInitialization: tables = [row[0] for row in cursor.fetchall()] expected_tables = [ + 'actions', 'holdings', 'job_details', 'jobs', - 'positions', - 'reasoning_logs', 'tool_usage', 'price_data', 'price_data_coverage', 'simulation_runs', - 'trading_sessions' # Added in reasoning logs feature + 'trading_days' # New day-centric schema ] assert sorted(tables) == sorted(expected_tables) @@ -149,19 +148,19 @@ class TestSchemaInitialization: conn.close() - def test_initialize_database_creates_positions_table(self, clean_db): - """Should create positions table with correct schema.""" + def test_initialize_database_creates_trading_days_table(self, clean_db): + """Should create trading_days table with correct schema.""" conn = get_db_connection(clean_db) cursor = conn.cursor() - cursor.execute("PRAGMA table_info(positions)") + cursor.execute("PRAGMA table_info(trading_days)") columns = {row[1]: row[2] for row in cursor.fetchall()} required_columns = [ - 'id', 'job_id', 'date', 'model', 'action_id', 'action_type', - 'symbol', 'amount', 'price', 'cash', 'portfolio_value', - 'daily_profit', 'daily_return_pct', 'cumulative_profit', - 'cumulative_return_pct', 'created_at' + 'id', 'job_id', 'date', 'model', 'starting_cash', 'ending_cash', + 'starting_portfolio_value', 'ending_portfolio_value', + 'daily_profit', 'daily_return_pct', 'days_since_last_trading', + 'total_actions', 'reasoning_summary', 'reasoning_full', 'created_at' ] for col_name in required_columns: @@ -188,20 +187,9 @@ class TestSchemaInitialization: 'idx_job_details_job_id', 'idx_job_details_status', 'idx_job_details_unique', - 'idx_positions_job_id', - 'idx_positions_date', - 'idx_positions_model', - 'idx_positions_date_model', - 'idx_positions_unique', - 'idx_positions_session_id', # Link positions to trading sessions - 'idx_holdings_position_id', - 'idx_holdings_symbol', - 'idx_sessions_job_id', # Trading sessions indexes - 'idx_sessions_date', - 'idx_sessions_model', - 'idx_sessions_unique', - 'idx_reasoning_logs_session_id', # Reasoning logs now linked to sessions - 'idx_reasoning_logs_unique', + 'idx_trading_days_lookup', # Compound index in new schema + 'idx_holdings_day', + 'idx_actions_day', 'idx_tool_usage_job_date_model' ] @@ -274,8 +262,8 @@ class TestForeignKeyConstraints: conn.close() - def test_cascade_delete_positions(self, clean_db, sample_job_data, sample_position_data): - """Should cascade delete positions when job is deleted.""" + def test_cascade_delete_trading_days(self, clean_db, sample_job_data): + """Should cascade delete trading_days when job is deleted.""" conn = get_db_connection(clean_db) cursor = conn.cursor() @@ -292,14 +280,19 @@ class TestForeignKeyConstraints: sample_job_data["created_at"] )) - # Insert position + # Insert trading_day cursor.execute(""" - INSERT INTO positions ( - job_id, date, model, action_id, action_type, symbol, amount, price, - cash, portfolio_value, daily_profit, daily_return_pct, - cumulative_profit, cumulative_return_pct, created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, tuple(sample_position_data.values())) + INSERT INTO trading_days ( + job_id, date, model, starting_cash, ending_cash, + starting_portfolio_value, ending_portfolio_value, + daily_profit, daily_return_pct, days_since_last_trading, + total_actions, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + sample_job_data["job_id"], "2025-01-16", "test-model", + 10000.0, 9500.0, 10000.0, 9500.0, + -500.0, -5.0, 0, 1, "2025-01-16T10:00:00Z" + )) conn.commit() @@ -307,14 +300,14 @@ class TestForeignKeyConstraints: cursor.execute("DELETE FROM jobs WHERE job_id = ?", (sample_job_data["job_id"],)) conn.commit() - # Verify position was cascade deleted - cursor.execute("SELECT COUNT(*) FROM positions WHERE job_id = ?", (sample_job_data["job_id"],)) + # Verify trading_day was cascade deleted + cursor.execute("SELECT COUNT(*) FROM trading_days WHERE job_id = ?", (sample_job_data["job_id"],)) assert cursor.fetchone()[0] == 0 conn.close() - def test_cascade_delete_holdings(self, clean_db, sample_job_data, sample_position_data): - """Should cascade delete holdings when position is deleted.""" + def test_cascade_delete_holdings(self, clean_db, sample_job_data): + """Should cascade delete holdings when trading_day is deleted.""" conn = get_db_connection(clean_db) cursor = conn.cursor() @@ -331,35 +324,40 @@ class TestForeignKeyConstraints: sample_job_data["created_at"] )) - # Insert position + # Insert trading_day cursor.execute(""" - INSERT INTO positions ( - job_id, date, model, action_id, action_type, symbol, amount, price, - cash, portfolio_value, daily_profit, daily_return_pct, - cumulative_profit, cumulative_return_pct, created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, tuple(sample_position_data.values())) + INSERT INTO trading_days ( + job_id, date, model, starting_cash, ending_cash, + starting_portfolio_value, ending_portfolio_value, + daily_profit, daily_return_pct, days_since_last_trading, + total_actions, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + sample_job_data["job_id"], "2025-01-16", "test-model", + 10000.0, 9500.0, 10000.0, 9500.0, + -500.0, -5.0, 0, 1, "2025-01-16T10:00:00Z" + )) - position_id = cursor.lastrowid + trading_day_id = cursor.lastrowid # Insert holding cursor.execute(""" - INSERT INTO holdings (position_id, symbol, quantity) + INSERT INTO holdings (trading_day_id, symbol, quantity) VALUES (?, ?, ?) - """, (position_id, "AAPL", 10)) + """, (trading_day_id, "AAPL", 10)) conn.commit() # Verify holding exists - cursor.execute("SELECT COUNT(*) FROM holdings WHERE position_id = ?", (position_id,)) + cursor.execute("SELECT COUNT(*) FROM holdings WHERE trading_day_id = ?", (trading_day_id,)) assert cursor.fetchone()[0] == 1 - # Delete position - cursor.execute("DELETE FROM positions WHERE id = ?", (position_id,)) + # Delete trading_day + cursor.execute("DELETE FROM trading_days WHERE id = ?", (trading_day_id,)) conn.commit() # Verify holding was cascade deleted - cursor.execute("SELECT COUNT(*) FROM holdings WHERE position_id = ?", (position_id,)) + cursor.execute("SELECT COUNT(*) FROM holdings WHERE trading_day_id = ?", (trading_day_id,)) assert cursor.fetchone()[0] == 0 conn.close() @@ -374,11 +372,17 @@ class TestUtilityFunctions: # Initialize database initialize_database(test_db_path) + # Also initialize new schema + from api.database import Database + db = Database(test_db_path) + db.connection.close() + # Verify tables exist conn = get_db_connection(test_db_path) cursor = conn.cursor() cursor.execute("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") - assert cursor.fetchone()[0] == 10 # Updated to reflect all tables including trading_sessions + # New schema: jobs, job_details, trading_days, holdings, actions, tool_usage, price_data, price_data_coverage, simulation_runs (9 tables) + assert cursor.fetchone()[0] == 9 conn.close() # Drop all tables @@ -410,9 +414,9 @@ class TestUtilityFunctions: assert "database_size_mb" in stats assert stats["jobs"] == 0 assert stats["job_details"] == 0 - assert stats["positions"] == 0 + assert stats["trading_days"] == 0 assert stats["holdings"] == 0 - assert stats["reasoning_logs"] == 0 + assert stats["actions"] == 0 assert stats["tool_usage"] == 0 def test_get_database_stats_with_data(self, clean_db, sample_job_data): @@ -486,67 +490,6 @@ class TestSchemaMigration: # Clean up after test - drop all tables so we don't affect other tests drop_all_tables(test_db_path) - def test_migration_adds_simulation_run_id_column(self, test_db_path): - """Should add simulation_run_id column to existing positions table without it.""" - from api.database import drop_all_tables - - # Start with a clean slate - drop_all_tables(test_db_path) - - # Create database without simulation_run_id column (simulate old schema) - conn = get_db_connection(test_db_path) - cursor = conn.cursor() - - # Create jobs table first (for foreign key) - cursor.execute(""" - CREATE TABLE jobs ( - job_id TEXT PRIMARY KEY, - config_path TEXT NOT NULL, - status TEXT NOT NULL CHECK(status IN ('pending', 'downloading_data', 'running', 'completed', 'partial', 'failed')), - date_range TEXT NOT NULL, - models TEXT NOT NULL, - created_at TEXT NOT NULL - ) - """) - - # Create positions table without simulation_run_id column (old schema) - cursor.execute(""" - CREATE TABLE positions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - job_id TEXT NOT NULL, - date TEXT NOT NULL, - model TEXT NOT NULL, - action_id INTEGER NOT NULL, - cash REAL NOT NULL, - portfolio_value REAL NOT NULL, - created_at TEXT NOT NULL, - FOREIGN KEY (job_id) REFERENCES jobs(job_id) ON DELETE CASCADE - ) - """) - conn.commit() - - # Verify simulation_run_id column doesn't exist - cursor.execute("PRAGMA table_info(positions)") - columns = [row[1] for row in cursor.fetchall()] - assert 'simulation_run_id' not in columns - - conn.close() - - # Run initialize_database which should trigger migration - initialize_database(test_db_path) - - # Verify simulation_run_id column was added - conn = get_db_connection(test_db_path) - cursor = conn.cursor() - cursor.execute("PRAGMA table_info(positions)") - columns = [row[1] for row in cursor.fetchall()] - assert 'simulation_run_id' in columns - - conn.close() - - # Clean up after test - drop all tables so we don't affect other tests - drop_all_tables(test_db_path) - @pytest.mark.unit class TestCheckConstraints: @@ -586,8 +529,8 @@ class TestCheckConstraints: conn.close() - def test_positions_action_type_constraint(self, clean_db, sample_job_data): - """Should reject invalid action_type values.""" + def test_actions_action_type_constraint(self, clean_db, sample_job_data): + """Should reject invalid action_type values in actions table.""" conn = get_db_connection(clean_db) cursor = conn.cursor() @@ -597,13 +540,29 @@ class TestCheckConstraints: VALUES (?, ?, ?, ?, ?, ?) """, tuple(sample_job_data.values())) - # Try to insert position with invalid action_type + # Insert trading_day + cursor.execute(""" + INSERT INTO trading_days ( + job_id, date, model, starting_cash, ending_cash, + starting_portfolio_value, ending_portfolio_value, + daily_profit, daily_return_pct, days_since_last_trading, + total_actions, created_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, ( + sample_job_data["job_id"], "2025-01-16", "test-model", + 10000.0, 9500.0, 10000.0, 9500.0, + -500.0, -5.0, 0, 1, "2025-01-16T10:00:00Z" + )) + + trading_day_id = cursor.lastrowid + + # Try to insert action with invalid action_type with pytest.raises(sqlite3.IntegrityError, match="CHECK constraint failed"): cursor.execute(""" - INSERT INTO positions ( - job_id, date, model, action_id, action_type, cash, portfolio_value, created_at - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) - """, (sample_job_data["job_id"], "2025-01-16", "gpt-5", 1, "invalid_action", 10000, 10000, "2025-01-16T00:00:00Z")) + INSERT INTO actions ( + trading_day_id, action_type, symbol, quantity, price, created_at + ) VALUES (?, ?, ?, ?, ?, ?) + """, (trading_day_id, "invalid_action", "AAPL", 10, 150.0, "2025-01-16T10:00:00Z")) conn.close() diff --git a/tests/unit/test_position_tracking_bugs.py b/tests/unit/test_position_tracking_bugs.py deleted file mode 100644 index 8ee942d..0000000 --- a/tests/unit/test_position_tracking_bugs.py +++ /dev/null @@ -1,309 +0,0 @@ -""" -Tests demonstrating position tracking bugs before fix. - -These tests should FAIL before implementing fixes, and PASS after. -""" - -import pytest -from datetime import datetime -from api.database import get_db_connection, initialize_database -from api.job_manager import JobManager -from agent_tools.tool_trade import _buy_impl -from tools.price_tools import add_no_trade_record_to_db -import os -from pathlib import Path - - -@pytest.fixture(scope="function") -def test_db_with_prices(): - """ - Create test database with price data using production database path. - - Note: Since agent_tools hardcode db_path="data/jobs.db", we must use - the production database path for integration testing. - """ - # Use production database path - db_path = "data/jobs.db" - - # Ensure directory exists - Path(db_path).parent.mkdir(parents=True, exist_ok=True) - - # Initialize database - initialize_database(db_path) - - # Clear existing test data if any - conn = get_db_connection(db_path) - cursor = conn.cursor() - - # Clean up any existing test data (in correct order for foreign keys) - cursor.execute("DELETE FROM holdings WHERE position_id IN (SELECT id FROM positions WHERE model = 'claude-sonnet-4.5')") - cursor.execute("DELETE FROM positions WHERE model = 'claude-sonnet-4.5'") - cursor.execute("DELETE FROM trading_sessions WHERE model = 'claude-sonnet-4.5'") - cursor.execute("DELETE FROM job_details WHERE model = 'claude-sonnet-4.5'") - cursor.execute("DELETE FROM price_data WHERE symbol = 'NVDA' AND date IN ('2025-10-06', '2025-10-07')") - - # Mark any pending/running jobs as completed to allow new test jobs - cursor.execute("UPDATE jobs SET status = 'completed' WHERE status IN ('pending', 'running')") - - # Insert price data for testing - # 2025-10-06 prices - cursor.execute(""" - INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at) - VALUES ('NVDA', '2025-10-06', 185.5, 190.0, 185.0, 188.0, 1000000, ?) - """, (datetime.utcnow().isoformat() + "Z",)) - - # 2025-10-07 prices (Monday after weekend) - cursor.execute(""" - INSERT INTO price_data (symbol, date, open, high, low, close, volume, created_at) - VALUES ('NVDA', '2025-10-07', 186.23, 190.0, 186.0, 189.0, 1000000, ?) - """, (datetime.utcnow().isoformat() + "Z",)) - - conn.commit() - conn.close() - - yield db_path - - # Cleanup after test - conn = get_db_connection(db_path) - cursor = conn.cursor() - cursor.execute("DELETE FROM holdings WHERE position_id IN (SELECT id FROM positions WHERE model = 'claude-sonnet-4.5')") - cursor.execute("DELETE FROM positions WHERE model = 'claude-sonnet-4.5'") - cursor.execute("DELETE FROM trading_sessions WHERE model = 'claude-sonnet-4.5'") - cursor.execute("DELETE FROM job_details WHERE model = 'claude-sonnet-4.5'") - cursor.execute("DELETE FROM price_data WHERE symbol = 'NVDA' AND date IN ('2025-10-06', '2025-10-07')") - - # Mark any pending/running jobs as completed - cursor.execute("UPDATE jobs SET status = 'completed' WHERE status IN ('pending', 'running')") - - conn.commit() - conn.close() - - -@pytest.mark.unit -class TestPositionTrackingBugs: - """Tests demonstrating the three critical bugs.""" - - def test_cash_not_reset_between_days(self, test_db_with_prices): - """ - Bug #1: Cash should carry over from previous day, not reset to initial value. - - Scenario: - - Day 1: Start with $10,000, buy 5 NVDA @ $185.50 = $927.50, cash left = $9,072.50 - - Day 2: Should start with $9,072.50 cash, not $10,000 - """ - # Create job - manager = JobManager(db_path=test_db_with_prices) - job_id = manager.create_job( - config_path="configs/test.json", - date_range=["2025-10-06", "2025-10-07"], - models=["claude-sonnet-4.5"] - ) - - # Day 1: Initial position (action_id=0) - conn = get_db_connection(test_db_with_prices) - cursor = conn.cursor() - - cursor.execute(""" - INSERT INTO trading_sessions (job_id, date, model, started_at) - VALUES (?, ?, ?, ?) - """, (job_id, "2025-10-06", "claude-sonnet-4.5", datetime.utcnow().isoformat() + "Z")) - session_id_day1 = cursor.lastrowid - - cursor.execute(""" - INSERT INTO positions ( - job_id, date, model, action_id, action_type, - cash, portfolio_value, session_id, created_at - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - job_id, "2025-10-06", "claude-sonnet-4.5", 0, "no_trade", - 10000.0, 10000.0, session_id_day1, datetime.utcnow().isoformat() + "Z" - )) - - conn.commit() - conn.close() - - # Day 1: Buy 5 NVDA @ $185.50 - result = _buy_impl( - symbol="NVDA", - amount=5, - signature="claude-sonnet-4.5", - today_date="2025-10-06", - job_id=job_id, - session_id=session_id_day1 - ) - - assert "error" not in result - assert result["CASH"] == 9072.5 # 10000 - (5 * 185.5) - - # Day 2: Create new session - conn = get_db_connection(test_db_with_prices) - cursor = conn.cursor() - - cursor.execute(""" - INSERT INTO trading_sessions (job_id, date, model, started_at) - VALUES (?, ?, ?, ?) - """, (job_id, "2025-10-07", "claude-sonnet-4.5", datetime.utcnow().isoformat() + "Z")) - session_id_day2 = cursor.lastrowid - conn.commit() - conn.close() - - # Day 2: Check starting cash (should be $9,072.50, not $10,000) - from agent_tools.tool_trade import get_current_position_from_db - - position, next_action_id = get_current_position_from_db( - job_id=job_id, - model="claude-sonnet-4.5", - date="2025-10-07" - ) - - # BUG: This will fail before fix - cash resets to $10,000 or $0 - assert position["CASH"] == 9072.5, f"Expected cash $9,072.50 but got ${position['CASH']}" - assert position["NVDA"] == 5, f"Expected 5 NVDA shares but got {position.get('NVDA', 0)}" - - def test_positions_persist_over_weekend(self, test_db_with_prices): - """ - Bug #2: Positions should persist over non-trading days (weekends). - - Scenario: - - Friday 2025-10-06: Buy 5 NVDA - - Monday 2025-10-07: Should still have 5 NVDA - """ - # Create job - manager = JobManager(db_path=test_db_with_prices) - job_id = manager.create_job( - config_path="configs/test.json", - date_range=["2025-10-06", "2025-10-07"], - models=["claude-sonnet-4.5"] - ) - - # Friday: Initial position + buy - conn = get_db_connection(test_db_with_prices) - cursor = conn.cursor() - - cursor.execute(""" - INSERT INTO trading_sessions (job_id, date, model, started_at) - VALUES (?, ?, ?, ?) - """, (job_id, "2025-10-06", "claude-sonnet-4.5", datetime.utcnow().isoformat() + "Z")) - session_id = cursor.lastrowid - - cursor.execute(""" - INSERT INTO positions ( - job_id, date, model, action_id, action_type, - cash, portfolio_value, session_id, created_at - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - job_id, "2025-10-06", "claude-sonnet-4.5", 0, "no_trade", - 10000.0, 10000.0, session_id, datetime.utcnow().isoformat() + "Z" - )) - - conn.commit() - conn.close() - - _buy_impl( - symbol="NVDA", - amount=5, - signature="claude-sonnet-4.5", - today_date="2025-10-06", - job_id=job_id, - session_id=session_id - ) - - # Monday: Check positions persist - from agent_tools.tool_trade import get_current_position_from_db - - position, _ = get_current_position_from_db( - job_id=job_id, - model="claude-sonnet-4.5", - date="2025-10-07" - ) - - # BUG: This will fail before fix - positions lost, holdings=[] - assert "NVDA" in position, "NVDA position should persist over weekend" - assert position["NVDA"] == 5, f"Expected 5 NVDA shares but got {position.get('NVDA', 0)}" - - def test_profit_calculation_accuracy(self, test_db_with_prices): - """ - Bug #3: Profit should reflect actual gains/losses, not show trades as losses. - - Scenario: - - Start with $10,000 cash, portfolio value = $10,000 - - Buy 5 NVDA @ $185.50 = $927.50 - - New position: cash = $9,072.50, 5 NVDA worth $927.50 - - Portfolio value = $9,072.50 + $927.50 = $10,000 (unchanged) - - Expected profit = $0 (no price change yet, just traded) - - Current bug: Shows profit = -$927.50 or similar (treating trade as loss) - """ - # Create job - manager = JobManager(db_path=test_db_with_prices) - job_id = manager.create_job( - config_path="configs/test.json", - date_range=["2025-10-06"], - models=["claude-sonnet-4.5"] - ) - - # Create session and initial position - conn = get_db_connection(test_db_with_prices) - cursor = conn.cursor() - - cursor.execute(""" - INSERT INTO trading_sessions (job_id, date, model, started_at) - VALUES (?, ?, ?, ?) - """, (job_id, "2025-10-06", "claude-sonnet-4.5", datetime.utcnow().isoformat() + "Z")) - session_id = cursor.lastrowid - - cursor.execute(""" - INSERT INTO positions ( - job_id, date, model, action_id, action_type, - cash, portfolio_value, daily_profit, daily_return_pct, - session_id, created_at - ) - VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) - """, ( - job_id, "2025-10-06", "claude-sonnet-4.5", 0, "no_trade", - 10000.0, 10000.0, None, None, - session_id, datetime.utcnow().isoformat() + "Z" - )) - - conn.commit() - conn.close() - - # Buy 5 NVDA @ $185.50 - _buy_impl( - symbol="NVDA", - amount=5, - signature="claude-sonnet-4.5", - today_date="2025-10-06", - job_id=job_id, - session_id=session_id - ) - - # Check profit calculation - conn = get_db_connection(test_db_with_prices) - cursor = conn.cursor() - - cursor.execute(""" - SELECT portfolio_value, daily_profit, daily_return_pct - FROM positions - WHERE job_id = ? AND model = ? AND date = ? AND action_id = 1 - """, (job_id, "claude-sonnet-4.5", "2025-10-06")) - - row = cursor.fetchone() - conn.close() - - portfolio_value = row[0] - daily_profit = row[1] - daily_return_pct = row[2] - - # Portfolio value should be $10,000 (cash $9,072.50 + 5 NVDA @ $185.50) - assert abs(portfolio_value - 10000.0) < 0.01, \ - f"Expected portfolio value $10,000 but got ${portfolio_value}" - - # BUG: This will fail before fix - shows profit as negative or zero when should be zero - # Profit should be $0 (no price movement, just traded) - assert abs(daily_profit) < 0.01, \ - f"Expected profit $0 (no price change) but got ${daily_profit}" - assert abs(daily_return_pct) < 0.01, \ - f"Expected return 0% but got {daily_return_pct}%"