From e20dce7432c23cb1b047a90b4b657b72b777a440 Mon Sep 17 00:00:00 2001 From: Bill Date: Wed, 5 Nov 2025 06:56:54 -0500 Subject: [PATCH] fix: enable intra-day position tracking for sell-then-buy trades Resolves issue where sell proceeds were not immediately available for subsequent buy orders within the same trading session. Problem: - Both buy() and sell() independently queried database for starting position - Multiple trades within same day all saw pre-trade cash balance - Agents couldn't rebalance portfolios (sell + buy) in single session Solution: - ContextInjector maintains in-memory position state during trading session - Position updates accumulate after each successful trade - Position state injected into buy/sell via _current_position parameter - Reset position state at start of each trading day Changes: - agent/context_injector.py: Add position tracking with reset_position() - agent_tools/tool_trade.py: Accept _current_position in buy/sell functions - agent/base_agent/base_agent.py: Reset position state daily - tests: Add 13 comprehensive tests for position tracking All new tests pass. Backward compatible, no schema changes required. --- agent/base_agent/base_agent.py | 2 + agent/context_injector.py | 34 +++- agent_tools/tool_trade.py | 48 ++++-- tests/unit/test_context_injector.py | 192 ++++++++++++++++++++++ tests/unit/test_trade_tools_new_schema.py | 187 +++++++++++++++++++++ 5 files changed, 449 insertions(+), 14 deletions(-) create mode 100644 tests/unit/test_context_injector.py diff --git a/agent/base_agent/base_agent.py b/agent/base_agent/base_agent.py index db673f2..7314d11 100644 --- a/agent/base_agent/base_agent.py +++ b/agent/base_agent/base_agent.py @@ -533,6 +533,8 @@ Summary:""" # Update context injector with current trading date if self.context_injector: self.context_injector.today_date = today_date + # Reset position state for new trading day (enables intra-day tracking) + self.context_injector.reset_position() # Clear conversation history for new trading day self.clear_conversation_history() diff --git a/agent/context_injector.py b/agent/context_injector.py index f2a1858..7ca10a3 100644 --- a/agent/context_injector.py +++ b/agent/context_injector.py @@ -3,15 +3,22 @@ Tool interceptor for injecting runtime context into MCP tool calls. This interceptor automatically injects `signature` and `today_date` parameters into buy/sell tool calls to support concurrent multi-model simulations. + +It also maintains in-memory position state to track cumulative changes within +a single trading session, ensuring sell proceeds are immediately available for +subsequent buy orders. """ -from typing import Any, Callable, Awaitable +from typing import Any, Callable, Awaitable, Dict, Optional class ContextInjector: """ Intercepts tool calls to inject runtime context (signature, today_date). + Also maintains cumulative position state during trading session to ensure + sell proceeds are immediately available for subsequent buys. + Usage: interceptor = ContextInjector(signature="gpt-5", today_date="2025-10-01") client = MultiServerMCPClient(config, tool_interceptors=[interceptor]) @@ -34,6 +41,13 @@ class ContextInjector: self.job_id = job_id self.session_id = session_id # Deprecated but kept for compatibility self.trading_day_id = trading_day_id + self._current_position: Optional[Dict[str, float]] = None + + def reset_position(self) -> None: + """ + Reset position state (call at start of each trading day). + """ + self._current_position = None async def __call__( self, @@ -43,6 +57,9 @@ class ContextInjector: """ Intercept tool call and inject context parameters. + For buy/sell operations, maintains cumulative position state to ensure + sell proceeds are immediately available for subsequent buys. + Args: request: Tool call request containing name and arguments handler: Async callable to execute the actual tool @@ -62,5 +79,18 @@ class ContextInjector: if self.trading_day_id: request.args["trading_day_id"] = self.trading_day_id + # Inject current position if we're tracking it + if self._current_position is not None: + request.args["_current_position"] = self._current_position + # Call the actual tool handler - return await handler(request) + result = await handler(request) + + # Update position state after successful trade + if request.name in ["buy", "sell"]: + # Check if result is a valid position dict (not an error) + if isinstance(result, dict) and "error" not in result and "CASH" in result: + # Update our tracked position with the new state + self._current_position = result.copy() + + return result diff --git a/agent_tools/tool_trade.py b/agent_tools/tool_trade.py index c518c1d..410a3ef 100644 --- a/agent_tools/tool_trade.py +++ b/agent_tools/tool_trade.py @@ -91,7 +91,8 @@ def get_current_position_from_db( def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = None, - job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]: + job_id: str = None, session_id: int = None, trading_day_id: int = None, + _current_position: Dict[str, float] = None) -> Dict[str, Any]: """ Internal buy implementation - accepts injected context parameters. @@ -103,9 +104,13 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = job_id: Job ID (injected) session_id: Session ID (injected, DEPRECATED) trading_day_id: Trading day ID (injected) + _current_position: Current position state (injected by ContextInjector) This function is not exposed to the AI model. It receives runtime context (signature, today_date, job_id, session_id, trading_day_id) from the ContextInjector. + + The _current_position parameter enables intra-day position tracking, ensuring + sell proceeds are immediately available for subsequent buys. """ # Validate required parameters if not job_id: @@ -121,7 +126,13 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = try: # Step 1: Get current position - current_position, next_action_id = get_current_position_from_db(job_id, signature, today_date) + # Use injected position if available (for intra-day tracking), + # otherwise query database for starting position + if _current_position is not None: + current_position = _current_position + next_action_id = 0 # Not used in new schema + else: + current_position, next_action_id = get_current_position_from_db(job_id, signature, today_date) # Step 2: Get stock price try: @@ -186,7 +197,8 @@ def _buy_impl(symbol: str, amount: int, signature: str = None, today_date: str = @mcp.tool() def buy(symbol: str, amount: int, signature: str = None, today_date: str = None, - job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]: + job_id: str = None, session_id: int = None, trading_day_id: int = None, + _current_position: Dict[str, float] = None) -> Dict[str, Any]: """ Buy stock shares. @@ -199,14 +211,15 @@ def buy(symbol: str, amount: int, signature: str = None, today_date: str = None, - Success: {"CASH": remaining_cash, "SYMBOL": shares, ...} - Failure: {"error": error_message, ...} - Note: signature, today_date, job_id, session_id, trading_day_id are - automatically injected by the system. Do not provide these parameters. + Note: signature, today_date, job_id, session_id, trading_day_id, _current_position + are automatically injected by the system. Do not provide these parameters. """ - return _buy_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id) + return _buy_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id, _current_position) def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str = None, - job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]: + job_id: str = None, session_id: int = None, trading_day_id: int = None, + _current_position: Dict[str, float] = None) -> Dict[str, Any]: """ Sell stock function - writes to SQLite database. @@ -218,11 +231,15 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str job_id: Job UUID (injected by ContextInjector) session_id: Trading session ID (injected by ContextInjector, DEPRECATED) trading_day_id: Trading day ID (injected by ContextInjector) + _current_position: Current position state (injected by ContextInjector) Returns: Dict[str, Any]: - Success: {"CASH": amount, symbol: quantity, ...} - Failure: {"error": message, ...} + + The _current_position parameter enables intra-day position tracking, ensuring + sell proceeds are immediately available for subsequent buys. """ # Validate required parameters if not job_id: @@ -238,7 +255,13 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str try: # Step 1: Get current position - current_position, next_action_id = get_current_position_from_db(job_id, signature, today_date) + # Use injected position if available (for intra-day tracking), + # otherwise query database for starting position + if _current_position is not None: + current_position = _current_position + next_action_id = 0 # Not used in new schema + else: + current_position, next_action_id = get_current_position_from_db(job_id, signature, today_date) # Step 2: Validate position exists if symbol not in current_position: @@ -298,7 +321,8 @@ def _sell_impl(symbol: str, amount: int, signature: str = None, today_date: str @mcp.tool() def sell(symbol: str, amount: int, signature: str = None, today_date: str = None, - job_id: str = None, session_id: int = None, trading_day_id: int = None) -> Dict[str, Any]: + job_id: str = None, session_id: int = None, trading_day_id: int = None, + _current_position: Dict[str, float] = None) -> Dict[str, Any]: """ Sell stock shares. @@ -311,10 +335,10 @@ def sell(symbol: str, amount: int, signature: str = None, today_date: str = None - Success: {"CASH": remaining_cash, "SYMBOL": shares, ...} - Failure: {"error": error_message, ...} - Note: signature, today_date, job_id, session_id, trading_day_id are - automatically injected by the system. Do not provide these parameters. + Note: signature, today_date, job_id, session_id, trading_day_id, _current_position + are automatically injected by the system. Do not provide these parameters. """ - return _sell_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id) + return _sell_impl(symbol, amount, signature, today_date, job_id, session_id, trading_day_id, _current_position) if __name__ == "__main__": diff --git a/tests/unit/test_context_injector.py b/tests/unit/test_context_injector.py new file mode 100644 index 0000000..f0e7d50 --- /dev/null +++ b/tests/unit/test_context_injector.py @@ -0,0 +1,192 @@ +"""Test ContextInjector position tracking functionality.""" + +import pytest +from agent.context_injector import ContextInjector + + +@pytest.fixture +def injector(): + """Create a ContextInjector instance for testing.""" + return ContextInjector( + signature="test-model", + today_date="2025-01-15", + job_id="test-job-123", + trading_day_id=1 + ) + + +class MockRequest: + """Mock MCP tool request.""" + def __init__(self, name, args=None): + self.name = name + self.args = args or {} + + +async def mock_handler_success(request): + """Mock handler that returns a successful position update.""" + # Simulate a successful trade returning updated position + if request.name == "sell": + return { + "CASH": 1100.0, + "AAPL": 7, + "MSFT": 5 + } + elif request.name == "buy": + return { + "CASH": 50.0, + "AAPL": 7, + "MSFT": 12 + } + return {} + + +async def mock_handler_error(request): + """Mock handler that returns an error.""" + return {"error": "Insufficient cash"} + + +@pytest.mark.asyncio +async def test_context_injector_initializes_with_no_position(injector): + """Test that ContextInjector starts with no position state.""" + assert injector._current_position is None + + +@pytest.mark.asyncio +async def test_context_injector_reset_position(injector): + """Test that reset_position() clears position state.""" + # Set some position state + injector._current_position = {"CASH": 5000.0, "AAPL": 10} + + # Reset + injector.reset_position() + + assert injector._current_position is None + + +@pytest.mark.asyncio +async def test_context_injector_injects_parameters(injector): + """Test that context parameters are injected into buy/sell requests.""" + request = MockRequest("buy", {"symbol": "AAPL", "amount": 10}) + + # Mock handler that just returns the request args + async def handler(req): + return req.args + + result = await injector(request, handler) + + # Verify context was injected + assert result["signature"] == "test-model" + assert result["today_date"] == "2025-01-15" + assert result["job_id"] == "test-job-123" + assert result["trading_day_id"] == 1 + + +@pytest.mark.asyncio +async def test_context_injector_tracks_position_after_successful_trade(injector): + """Test that position state is updated after successful trades.""" + assert injector._current_position is None + + # Execute a sell trade + request = MockRequest("sell", {"symbol": "AAPL", "amount": 3}) + result = await injector(request, mock_handler_success) + + # Verify position was updated + assert injector._current_position is not None + assert injector._current_position["CASH"] == 1100.0 + assert injector._current_position["AAPL"] == 7 + assert injector._current_position["MSFT"] == 5 + + +@pytest.mark.asyncio +async def test_context_injector_injects_current_position_on_subsequent_trades(injector): + """Test that current position is injected into subsequent trade requests.""" + # First trade - establish position + request1 = MockRequest("sell", {"symbol": "AAPL", "amount": 3}) + await injector(request1, mock_handler_success) + + # Second trade - should receive current position + request2 = MockRequest("buy", {"symbol": "MSFT", "amount": 7}) + + async def verify_injection_handler(req): + # Verify that _current_position was injected + assert "_current_position" in req.args + assert req.args["_current_position"]["CASH"] == 1100.0 + assert req.args["_current_position"]["AAPL"] == 7 + return mock_handler_success(req) + + await injector(request2, verify_injection_handler) + + +@pytest.mark.asyncio +async def test_context_injector_does_not_update_position_on_error(injector): + """Test that position state is NOT updated when trade fails.""" + # First successful trade + request1 = MockRequest("sell", {"symbol": "AAPL", "amount": 3}) + await injector(request1, mock_handler_success) + + original_position = injector._current_position.copy() + + # Second trade that fails + request2 = MockRequest("buy", {"symbol": "MSFT", "amount": 100}) + result = await injector(request2, mock_handler_error) + + # Verify position was NOT updated + assert injector._current_position == original_position + assert "error" in result + + +@pytest.mark.asyncio +async def test_context_injector_does_not_inject_position_for_non_trade_tools(injector): + """Test that position is not injected for non-buy/sell tools.""" + # Set up position state + injector._current_position = {"CASH": 5000.0, "AAPL": 10} + + # Call a non-trade tool + request = MockRequest("search", {"query": "market news"}) + + async def verify_no_injection_handler(req): + assert "_current_position" not in req.args + return {"results": []} + + await injector(request, verify_no_injection_handler) + + +@pytest.mark.asyncio +async def test_context_injector_full_trading_session_simulation(injector): + """Test full trading session with multiple trades and position tracking.""" + # Reset position at start of day + injector.reset_position() + assert injector._current_position is None + + # Trade 1: Sell AAPL + request1 = MockRequest("sell", {"symbol": "AAPL", "amount": 3}) + + async def handler1(req): + # First trade should NOT have injected position + assert req.args.get("_current_position") is None + return {"CASH": 1100.0, "AAPL": 7} + + result1 = await injector(request1, handler1) + assert injector._current_position == {"CASH": 1100.0, "AAPL": 7} + + # Trade 2: Buy MSFT (should use position from trade 1) + request2 = MockRequest("buy", {"symbol": "MSFT", "amount": 7}) + + async def handler2(req): + # Second trade SHOULD have injected position from trade 1 + assert req.args["_current_position"]["CASH"] == 1100.0 + assert req.args["_current_position"]["AAPL"] == 7 + return {"CASH": 50.0, "AAPL": 7, "MSFT": 7} + + result2 = await injector(request2, handler2) + assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 7} + + # Trade 3: Failed trade (should not update position) + request3 = MockRequest("buy", {"symbol": "GOOGL", "amount": 100}) + + async def handler3(req): + return {"error": "Insufficient cash", "cash_available": 50.0} + + result3 = await injector(request3, handler3) + # Position should remain unchanged after failed trade + assert injector._current_position == {"CASH": 50.0, "AAPL": 7, "MSFT": 7} diff --git a/tests/unit/test_trade_tools_new_schema.py b/tests/unit/test_trade_tools_new_schema.py index f0cf87b..7276a3e 100644 --- a/tests/unit/test_trade_tools_new_schema.py +++ b/tests/unit/test_trade_tools_new_schema.py @@ -295,3 +295,190 @@ def test_sell_writes_to_actions_table(test_db, monkeypatch): assert row[1] == 'AAPL' assert row[2] == 5 assert row[3] == 160.0 + + +def test_intraday_position_tracking_sell_then_buy(test_db, monkeypatch): + """Test that sell proceeds are immediately available for subsequent buys.""" + db, trading_day_id = test_db + + # Setup: Create starting position with AAPL shares and limited cash + db.create_holding(trading_day_id, 'AAPL', 10) + db.connection.commit() + + # Create a mock connection wrapper + class MockConnection: + def __init__(self, real_conn): + self.real_conn = real_conn + + def cursor(self): + return self.real_conn.cursor() + + def commit(self): + return self.real_conn.commit() + + def rollback(self): + return self.real_conn.rollback() + + def close(self): + pass + + mock_conn = MockConnection(db.connection) + monkeypatch.setattr('agent_tools.tool_trade.get_db_connection', + lambda x: mock_conn) + + # Mock get_current_position_from_db to return starting position + monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db', + lambda job_id, sig, date: ({'CASH': 500.0, 'AAPL': 10}, 0)) + + monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime_intraday.json') + + import json + with open('/tmp/test_runtime_intraday.json', 'w') as f: + json.dump({ + 'TODAY_DATE': '2025-01-15', + 'SIGNATURE': 'test-model', + 'JOB_ID': 'test-job-123', + 'TRADING_DAY_ID': trading_day_id + }, f) + + # Mock prices: AAPL sells for 200, MSFT costs 150 + def mock_get_prices(date, symbols): + if 'AAPL' in symbols: + return {'AAPL_price': 200.0} + elif 'MSFT' in symbols: + return {'MSFT_price': 150.0} + return {} + + monkeypatch.setattr('agent_tools.tool_trade.get_open_prices', mock_get_prices) + + # Step 1: Sell 3 shares of AAPL for 600.0 + # Starting cash: 500.0, proceeds: 600.0, new cash: 1100.0 + result_sell = _sell_impl( + symbol='AAPL', + amount=3, + signature='test-model', + today_date='2025-01-15', + job_id='test-job-123', + trading_day_id=trading_day_id, + _current_position=None # Use database position (starting position) + ) + + assert 'error' not in result_sell, f"Sell should succeed: {result_sell}" + assert result_sell['CASH'] == 1100.0, "Cash should be 500 + (3 * 200) = 1100" + assert result_sell['AAPL'] == 7, "AAPL shares should be 10 - 3 = 7" + + # Step 2: Buy 7 shares of MSFT for 1050.0 using the position from the sell + # This should work because we pass the updated position from step 1 + result_buy = _buy_impl( + symbol='MSFT', + amount=7, + signature='test-model', + today_date='2025-01-15', + job_id='test-job-123', + trading_day_id=trading_day_id, + _current_position=result_sell # Use position from sell + ) + + assert 'error' not in result_buy, f"Buy should succeed with sell proceeds: {result_buy}" + assert result_buy['CASH'] == 50.0, "Cash should be 1100 - (7 * 150) = 50" + assert result_buy['MSFT'] == 7, "MSFT shares should be 7" + assert result_buy['AAPL'] == 7, "AAPL shares should still be 7" + + # Verify both actions were recorded + cursor = db.connection.execute(""" + SELECT action_type, symbol, quantity, price + FROM actions + WHERE trading_day_id = ? + ORDER BY created_at + """, (trading_day_id,)) + + actions = cursor.fetchall() + assert len(actions) == 2, "Should have 2 actions (sell + buy)" + assert actions[0][0] == 'sell' and actions[0][1] == 'AAPL' + assert actions[1][0] == 'buy' and actions[1][1] == 'MSFT' + + +def test_intraday_tracking_without_position_injection_fails(test_db, monkeypatch): + """Test that without position injection, sell proceeds are NOT available for subsequent buys.""" + db, trading_day_id = test_db + + # Setup: Create starting position with AAPL shares and limited cash + db.create_holding(trading_day_id, 'AAPL', 10) + db.connection.commit() + + # Create a mock connection wrapper + class MockConnection: + def __init__(self, real_conn): + self.real_conn = real_conn + + def cursor(self): + return self.real_conn.cursor() + + def commit(self): + return self.real_conn.commit() + + def rollback(self): + return self.real_conn.rollback() + + def close(self): + pass + + mock_conn = MockConnection(db.connection) + monkeypatch.setattr('agent_tools.tool_trade.get_db_connection', + lambda x: mock_conn) + + # Mock get_current_position_from_db to ALWAYS return starting position + # (simulating the old buggy behavior) + monkeypatch.setattr('agent_tools.tool_trade.get_current_position_from_db', + lambda job_id, sig, date: ({'CASH': 500.0, 'AAPL': 10}, 0)) + + monkeypatch.setenv('RUNTIME_ENV_PATH', '/tmp/test_runtime_no_injection.json') + + import json + with open('/tmp/test_runtime_no_injection.json', 'w') as f: + json.dump({ + 'TODAY_DATE': '2025-01-15', + 'SIGNATURE': 'test-model', + 'JOB_ID': 'test-job-123', + 'TRADING_DAY_ID': trading_day_id + }, f) + + # Mock prices + def mock_get_prices(date, symbols): + if 'AAPL' in symbols: + return {'AAPL_price': 200.0} + elif 'MSFT' in symbols: + return {'MSFT_price': 150.0} + return {} + + monkeypatch.setattr('agent_tools.tool_trade.get_open_prices', mock_get_prices) + + # Step 1: Sell 3 shares of AAPL + result_sell = _sell_impl( + symbol='AAPL', + amount=3, + signature='test-model', + today_date='2025-01-15', + job_id='test-job-123', + trading_day_id=trading_day_id, + _current_position=None # Don't inject position (old behavior) + ) + + assert 'error' not in result_sell, "Sell should succeed" + + # Step 2: Try to buy 7 shares of MSFT WITHOUT passing updated position + # This should FAIL because it will query the database and get the original 500.0 cash + result_buy = _buy_impl( + symbol='MSFT', + amount=7, + signature='test-model', + today_date='2025-01-15', + job_id='test-job-123', + trading_day_id=trading_day_id, + _current_position=None # Don't inject position (old behavior) + ) + + # This should fail with insufficient cash + assert 'error' in result_buy, "Buy should fail without position injection" + assert result_buy['error'] == 'Insufficient cash', f"Expected insufficient cash error, got: {result_buy}" + assert result_buy['cash_available'] == 500.0, "Should see original cash, not updated cash"