diff --git a/agent/base_agent/base_agent.py b/agent/base_agent/base_agent.py index 6d88bda..4aa1523 100644 --- a/agent/base_agent/base_agent.py +++ b/agent/base_agent/base_agent.py @@ -285,42 +285,40 @@ class BaseAgent: return current_prices - def _get_current_portfolio_state(self) -> tuple[Dict[str, int], float]: + def _get_current_portfolio_state(self, today_date: str, job_id: str) -> tuple[Dict[str, int], float]: """ - Get current portfolio state from position.jsonl file. + Get current portfolio state from database. + + Args: + today_date: Current trading date + job_id: Job ID for this trading session Returns: Tuple of (holdings dict, cash balance) """ - if not os.path.exists(self.position_file): - # No position file yet - return initial state - return {}, self.initial_cash + from agent_tools.tool_trade import get_current_position_from_db - # Read last line of position file - with open(self.position_file, "r") as f: - lines = f.readlines() - if not lines: - return {}, self.initial_cash - - last_line = lines[-1].strip() - if not last_line: - return {}, self.initial_cash - - position_data = json.loads(last_line) - positions = position_data.get("positions", {}) + try: + # Get position from database + position_dict, _ = get_current_position_from_db(job_id, self.signature, today_date) # Extract holdings (exclude CASH) holdings = { symbol: int(qty) - for symbol, qty in positions.items() + for symbol, qty in position_dict.items() if symbol != "CASH" and qty > 0 } # Extract cash - cash = float(positions.get("CASH", self.initial_cash)) + cash = float(position_dict.get("CASH", self.initial_cash)) return holdings, cash + except Exception as e: + # If no position found (first trading day), return initial state + print(f"⚠️ Could not get position from database: {e}") + return {}, self.initial_cash + def _calculate_portfolio_value( self, holdings: Dict[str, int], @@ -579,8 +577,13 @@ Summary:""" print(agent_response) break - # Extract tool messages + # Extract tool messages and count trade actions tool_msgs = extract_tool_messages(response) + for tool_msg in tool_msgs: + tool_name = getattr(tool_msg, 'name', None) or tool_msg.get('name') if isinstance(tool_msg, dict) else None + if tool_name in ['buy', 'sell']: + action_count += 1 + tool_response = '\n'.join([msg.content for msg in tool_msgs]) # Prepare new messages @@ -603,8 +606,8 @@ Summary:""" summarizer = ReasoningSummarizer(model=self.model) summary = await summarizer.generate_summary(self.conversation_history) - # 8. Get current portfolio state from position.jsonl file - current_holdings, current_cash = self._get_current_portfolio_state() + # 8. Get current portfolio state from database + current_holdings, current_cash = self._get_current_portfolio_state(today_date, job_id) # 9. Save final holdings to database for symbol, quantity in current_holdings.items(): @@ -618,13 +621,7 @@ Summary:""" # 10. Calculate final portfolio value final_value = self._calculate_portfolio_value(current_holdings, current_prices, current_cash) - # 11. Count actions from trade tool calls - action_count = sum( - 1 for msg in self.conversation_history - if msg.get("role") == "tool" and msg.get("tool_name") in ["buy", "sell"] - ) - - # 12. Update trading_day with completion data + # 11. Update trading_day with completion data db.connection.execute( """ UPDATE trading_days diff --git a/api/database.py b/api/database.py index 256f7c1..830f0c1 100644 --- a/api/database.py +++ b/api/database.py @@ -553,8 +553,8 @@ class Database: If None, uses default from deployment config. """ if db_path is None: - from tools.deployment_config import get_database_path - db_path = get_database_path() + from tools.deployment_config import get_db_path + db_path = get_db_path("data/trading.db") self.db_path = db_path self.connection = sqlite3.connect(db_path, check_same_thread=False) diff --git a/tests/integration/test_agent_pnl_integration.py b/tests/integration/test_agent_pnl_integration.py index 59d97f0..b9bb511 100644 --- a/tests/integration/test_agent_pnl_integration.py +++ b/tests/integration/test_agent_pnl_integration.py @@ -42,15 +42,19 @@ class TestAgentPnLIntegration: db.connection.close() @pytest.mark.asyncio - @patch('tools.deployment_config.get_database_path') + @patch('agent.base_agent.base_agent.is_dev_mode') + @patch('tools.deployment_config.get_db_path') @patch('tools.general_tools.get_config_value') @patch('tools.general_tools.write_config_value') async def test_run_trading_session_creates_trading_day_record( - self, mock_write_config, mock_get_config, mock_db_path, test_db + self, mock_write_config, mock_get_config, mock_db_path, mock_is_dev, test_db ): """Test that run_trading_session creates a trading_day record with P&L.""" from agent.base_agent.base_agent import BaseAgent + # Setup dev mode + mock_is_dev.return_value = True + # Setup database path mock_db_path.return_value = test_db.db_path @@ -71,8 +75,9 @@ class TestAgentPnLIntegration: init_date="2025-01-01" ) - # Initialize agent - await agent.initialize() + # Skip actual initialization - just set up mocks directly + agent.client = Mock() + agent.tools = [] # Mock the AI model to return finish signal immediately agent.model = AsyncMock() @@ -99,23 +104,41 @@ class TestAgentPnLIntegration: agent.context_injector.session_id = "test-session-id" agent.context_injector.job_id = "test-job" - # NOTE: This test currently verifies the setup works - # Once we integrate P&L calculation, this test should verify: - # 1. trading_day record is created - # 2. P&L metrics are calculated correctly - # 3. Holdings are saved + # Mock get_current_position_from_db to return initial holdings + with patch('agent_tools.tool_trade.get_current_position_from_db') as mock_get_position: + mock_get_position.return_value = ({"CASH": 10000.0}, 0) - # For now, just verify the agent can run without error - try: - await agent.run_trading_session("2025-01-15") - # Test passes if no exception is raised - # After implementation, verify database records - except AttributeError as e: - # Expected to fail before implementation - if "pnl_calculator" in str(e): - pytest.skip("P&L calculator not yet integrated") - else: - raise + # Mock add_no_trade_record_to_db to avoid FK constraint issues + with patch('tools.price_tools.add_no_trade_record_to_db') as mock_no_trade: + # Run trading session + await agent.run_trading_session("2025-01-15") + + # Verify trading_day record was created + cursor = test_db.connection.execute( + """ + SELECT id, model, date, starting_cash, ending_cash, + starting_portfolio_value, ending_portfolio_value, + daily_profit, daily_return_pct, total_actions + FROM trading_days + WHERE job_id = ? AND model = ? AND date = ? + """, + ("test-job", "test-model", "2025-01-15") + ) + row = cursor.fetchone() + + # Verify record exists + assert row is not None, "trading_day record should be created" + + # Verify basic fields + assert row[1] == "test-model" + assert row[2] == "2025-01-15" + assert row[3] == 10000.0 # starting_cash + assert row[5] == 10000.0 # starting_portfolio_value (first day) + assert row[7] == 0.0 # daily_profit (first day) + assert row[8] == 0.0 # daily_return_pct (first day) + + # Verify action count + assert row[9] == 0 # total_actions (no trades executed in test) @pytest.mark.asyncio async def test_pnl_calculation_components_exist(self):