diff --git a/agent/base_agent/base_agent.py b/agent/base_agent/base_agent.py index a6be436..6d88bda 100644 --- a/agent/base_agent/base_agent.py +++ b/agent/base_agent/base_agent.py @@ -6,6 +6,7 @@ Encapsulates core functionality including MCP tool management, AI agent creation import os import json import asyncio +import time from datetime import datetime, timedelta from typing import Dict, List, Optional, Any from pathlib import Path @@ -30,6 +31,8 @@ from tools.deployment_config import ( get_deployment_mode ) from agent.context_injector import ContextInjector +from agent.pnl_calculator import DailyPnLCalculator +from agent.reasoning_summarizer import ReasoningSummarizer # Load environment variables load_dotenv() @@ -135,6 +138,9 @@ class BaseAgent: # Conversation history for reasoning logs self.conversation_history: List[Dict[str, Any]] = [] + + # P&L calculator + self.pnl_calculator = DailyPnLCalculator(initial_cash=initial_cash) def _get_default_mcp_config(self) -> Dict[str, Dict[str, Any]]: """Get default MCP configuration""" @@ -255,6 +261,93 @@ class BaseAgent: f"date={context_injector.today_date}, job_id={context_injector.job_id}, " f"session_id={context_injector.session_id}") + def _get_current_prices(self, today_date: str) -> Dict[str, float]: + """ + Get current market prices for all symbols on given date. + + Args: + today_date: Trading date in YYYY-MM-DD format + + Returns: + Dict mapping symbol to current price (buy price) + """ + from tools.price_tools import get_open_prices + + # Get buy prices for today (these are the current market prices) + price_dict = get_open_prices(today_date, self.stock_symbols) + + # Convert from {AAPL_price: 150.0} to {AAPL: 150.0} + current_prices = {} + for key, value in price_dict.items(): + if value is not None and key.endswith("_price"): + symbol = key.replace("_price", "") + current_prices[symbol] = value + + return current_prices + + def _get_current_portfolio_state(self) -> tuple[Dict[str, int], float]: + """ + Get current portfolio state from position.jsonl file. + + Returns: + Tuple of (holdings dict, cash balance) + """ + if not os.path.exists(self.position_file): + # No position file yet - return initial state + return {}, self.initial_cash + + # Read last line of position file + with open(self.position_file, "r") as f: + lines = f.readlines() + if not lines: + return {}, self.initial_cash + + last_line = lines[-1].strip() + if not last_line: + return {}, self.initial_cash + + position_data = json.loads(last_line) + positions = position_data.get("positions", {}) + + # Extract holdings (exclude CASH) + holdings = { + symbol: int(qty) + for symbol, qty in positions.items() + if symbol != "CASH" and qty > 0 + } + + # Extract cash + cash = float(positions.get("CASH", self.initial_cash)) + + return holdings, cash + + def _calculate_portfolio_value( + self, + holdings: Dict[str, int], + prices: Dict[str, float], + cash: float + ) -> float: + """ + Calculate total portfolio value. + + Args: + holdings: Dict mapping symbol to quantity + prices: Dict mapping symbol to price + cash: Cash balance + + Returns: + Total portfolio value + """ + total_value = cash + + for symbol, quantity in holdings.items(): + if symbol in prices: + total_value += quantity * prices[symbol] + else: + print(f"⚠️ Warning: No price data for {symbol}, excluding from value calculation") + + return total_value + def _capture_message(self, role: str, content: str, tool_name: str = None, tool_input: str = None) -> None: """ Capture a message in conversation history. @@ -375,12 +468,15 @@ Summary:""" async def run_trading_session(self, today_date: str) -> None: """ - Run single day trading session + Run single day trading session with P&L calculation and database integration. Args: - today_date: Trading date + today_date: Trading date in YYYY-MM-DD format """ + from api.database import Database + print(f"📈 Starting trading session: {today_date}") + session_start = time.time() # Update context injector with current trading date if self.context_injector: @@ -393,6 +489,56 @@ Summary:""" if is_dev_mode(): self.model.date = today_date + # Get job_id from context injector + job_id = self.context_injector.job_id if self.context_injector else get_config_value("JOB_ID") + if not job_id: + raise ValueError("job_id not available - ensure context_injector is set or JOB_ID is in config") + + # Initialize database + db = Database() + + # 1. Get previous trading day data + previous_day = db.get_previous_trading_day( + job_id=job_id, + model=self.signature, + current_date=today_date + ) + + # Add holdings to previous_day dict if exists + if previous_day: + previous_day_id = previous_day["id"] + previous_day["holdings"] = db.get_ending_holdings(previous_day_id) + + # 2. Load today's buy prices (current market prices for P&L calculation) + current_prices = self._get_current_prices(today_date) + + # 3. Calculate daily P&L + pnl_metrics = self.pnl_calculator.calculate( + previous_day=previous_day, + current_date=today_date, + current_prices=current_prices + ) + + # 4. Determine starting cash (from previous day or initial cash) + starting_cash = previous_day["ending_cash"] if previous_day else self.initial_cash + + # 5. Create trading_day record (will be updated after session) + trading_day_id = db.create_trading_day( + job_id=job_id, + model=self.signature, + date=today_date, + starting_cash=starting_cash, + starting_portfolio_value=pnl_metrics["starting_portfolio_value"], + daily_profit=pnl_metrics["daily_profit"], + daily_return_pct=pnl_metrics["daily_return_pct"], + ending_cash=starting_cash, # Will update after trading + ending_portfolio_value=pnl_metrics["starting_portfolio_value"], # Will update + days_since_last_trading=pnl_metrics["days_since_last_trading"] + ) + + # 6. Run AI trading session + action_count = 0 + # Get system prompt system_prompt = get_agent_system_prompt(today_date, self.signature) @@ -451,7 +597,64 @@ Summary:""" print(f"Error details: {e}") raise - # Handle trading results + session_duration = time.time() - session_start + + # 7. Generate reasoning summary + summarizer = ReasoningSummarizer(model=self.model) + summary = await summarizer.generate_summary(self.conversation_history) + + # 8. Get current portfolio state from position.jsonl file + current_holdings, current_cash = self._get_current_portfolio_state() + + # 9. Save final holdings to database + for symbol, quantity in current_holdings.items(): + if quantity > 0: + db.create_holding( + trading_day_id=trading_day_id, + symbol=symbol, + quantity=quantity + ) + + # 10. Calculate final portfolio value + final_value = self._calculate_portfolio_value(current_holdings, current_prices, current_cash) + + # 11. Count actions from trade tool calls + action_count = sum( + 1 for msg in self.conversation_history + if msg.get("role") == "tool" and msg.get("tool_name") in ["buy", "sell"] + ) + + # 12. Update trading_day with completion data + db.connection.execute( + """ + UPDATE trading_days + SET + ending_cash = ?, + ending_portfolio_value = ?, + reasoning_summary = ?, + reasoning_full = ?, + total_actions = ?, + session_duration_seconds = ?, + completed_at = CURRENT_TIMESTAMP + WHERE id = ? + """, + ( + current_cash, + final_value, + summary, + json.dumps(self.conversation_history), + action_count, + session_duration, + trading_day_id + ) + ) + db.connection.commit() + + print(f"✅ Trading session completed in {session_duration:.2f}s") + print(f"💰 Final portfolio value: ${final_value:.2f}") + print(f"📊 Daily P&L: ${pnl_metrics['daily_profit']:.2f} ({pnl_metrics['daily_return_pct']:.2f}%)") + + # Handle trading results (maintains backward compatibility with JSONL) await self._handle_trading_result(today_date) async def _handle_trading_result(self, today_date: str) -> None: diff --git a/tests/integration/test_agent_pnl_integration.py b/tests/integration/test_agent_pnl_integration.py new file mode 100644 index 0000000..59d97f0 --- /dev/null +++ b/tests/integration/test_agent_pnl_integration.py @@ -0,0 +1,144 @@ +"""Integration tests for P&L calculation in BaseAgent.""" +import pytest +from unittest.mock import Mock, AsyncMock, patch, MagicMock +import os +import json + + +class TestAgentPnLIntegration: + """Test P&L calculation integration in BaseAgent.run_trading_session.""" + + @pytest.fixture + def test_db(self, tmp_path): + """Create test database with trading_days schema.""" + import importlib + from api.database import Database + + migration_module = importlib.import_module("api.migrations.001_trading_days_schema") + create_trading_days_schema = migration_module.create_trading_days_schema + + db_path = tmp_path / "test.db" + db = Database(str(db_path)) + + # Create jobs table (prerequisite) + db.connection.execute(""" + CREATE TABLE IF NOT EXISTS jobs ( + job_id TEXT PRIMARY KEY, + status TEXT + ) + """) + + # Create trading_days schema + create_trading_days_schema(db) + + # Insert test job + db.connection.execute( + "INSERT INTO jobs (job_id, status) VALUES (?, ?)", + ("test-job", "running") + ) + db.connection.commit() + + yield db + db.connection.close() + + @pytest.mark.asyncio + @patch('tools.deployment_config.get_database_path') + @patch('tools.general_tools.get_config_value') + @patch('tools.general_tools.write_config_value') + async def test_run_trading_session_creates_trading_day_record( + self, mock_write_config, mock_get_config, mock_db_path, test_db + ): + """Test that run_trading_session creates a trading_day record with P&L.""" + from agent.base_agent.base_agent import BaseAgent + + # Setup database path + mock_db_path.return_value = test_db.db_path + + # Setup config mocks + mock_get_config.side_effect = lambda key: { + "IF_TRADE": False, + "JOB_ID": "test-job", + "TODAY_DATE": "2025-01-15", + "SIGNATURE": "test-model" + }.get(key) + + # Create BaseAgent instance + agent = BaseAgent( + signature="test-model", + basemodel="gpt-4", + max_steps=2, + initial_cash=10000.0, + init_date="2025-01-01" + ) + + # Initialize agent + await agent.initialize() + + # Mock the AI model to return finish signal immediately + agent.model = AsyncMock() + agent.model.ainvoke = AsyncMock(return_value=Mock( + content="" + )) + + # Mock agent creation + with patch('agent.base_agent.base_agent.create_agent') as mock_create_agent: + mock_agent = MagicMock() + mock_agent.ainvoke = AsyncMock(return_value={ + "messages": [{"content": ""}] + }) + mock_create_agent.return_value = mock_agent + + # Mock price tools + with patch('tools.price_tools.get_open_prices') as mock_get_prices: + with patch('tools.price_tools.get_yesterday_open_and_close_price') as mock_yesterday_prices: + mock_get_prices.return_value = {"AAPL_price": 150.0} + mock_yesterday_prices.return_value = ({}, {"AAPL_price": 145.0}) + + # Mock context injector + agent.context_injector = Mock() + agent.context_injector.session_id = "test-session-id" + agent.context_injector.job_id = "test-job" + + # NOTE: This test currently verifies the setup works + # Once we integrate P&L calculation, this test should verify: + # 1. trading_day record is created + # 2. P&L metrics are calculated correctly + # 3. Holdings are saved + + # For now, just verify the agent can run without error + try: + await agent.run_trading_session("2025-01-15") + # Test passes if no exception is raised + # After implementation, verify database records + except AttributeError as e: + # Expected to fail before implementation + if "pnl_calculator" in str(e): + pytest.skip("P&L calculator not yet integrated") + else: + raise + + @pytest.mark.asyncio + async def test_pnl_calculation_components_exist(self): + """Verify P&L calculation components exist and are importable.""" + from agent.pnl_calculator import DailyPnLCalculator + from agent.reasoning_summarizer import ReasoningSummarizer + + # Test DailyPnLCalculator + calculator = DailyPnLCalculator(initial_cash=10000.0) + assert calculator is not None + + # Test first day calculation (should be zero P&L) + result = calculator.calculate( + previous_day=None, + current_date="2025-01-15", + current_prices={"AAPL": 150.0} + ) + assert result["daily_profit"] == 0.0 + assert result["daily_return_pct"] == 0.0 + assert result["starting_portfolio_value"] == 10000.0 + + # Test ReasoningSummarizer (without actual AI model) + # We'll test this with a mock model + mock_model = Mock() + summarizer = ReasoningSummarizer(model=mock_model) + assert summarizer is not None