From 9ffd42481a35efe15fa2ba14dd02ac69ce18ff5d Mon Sep 17 00:00:00 2001 From: Bill Date: Sat, 1 Nov 2025 11:15:59 -0400 Subject: [PATCH] feat: add LangChain-compatible mock chat model wrapper MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- agent/mock_provider/__init__.py | 3 +- agent/mock_provider/mock_langchain_model.py | 108 ++++++++++++++++++++ tests/unit/test_mock_provider.py | 40 ++++++++ 3 files changed, 150 insertions(+), 1 deletion(-) create mode 100644 agent/mock_provider/mock_langchain_model.py diff --git a/agent/mock_provider/__init__.py b/agent/mock_provider/__init__.py index a4740fc..026166d 100644 --- a/agent/mock_provider/__init__.py +++ b/agent/mock_provider/__init__.py @@ -1,4 +1,5 @@ """Mock AI provider for development mode testing""" from .mock_ai_provider import MockAIProvider +from .mock_langchain_model import MockChatModel -__all__ = ["MockAIProvider"] +__all__ = ["MockAIProvider", "MockChatModel"] diff --git a/agent/mock_provider/mock_langchain_model.py b/agent/mock_provider/mock_langchain_model.py new file mode 100644 index 0000000..40ee3d9 --- /dev/null +++ b/agent/mock_provider/mock_langchain_model.py @@ -0,0 +1,108 @@ +""" +Mock LangChain-compatible chat model for development mode + +Wraps MockAIProvider to work with LangChain's agent framework. +""" + +from typing import Any, List, Optional, Dict +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.outputs import ChatResult, ChatGeneration +from .mock_ai_provider import MockAIProvider + + +class MockChatModel(BaseChatModel): + """ + Mock chat model compatible with LangChain's agent framework + + Attributes: + date: Current trading date for response generation + step_counter: Tracks reasoning steps within a trading session + provider: MockAIProvider instance + """ + + date: str = "2025-01-01" + step_counter: int = 0 + provider: Optional[MockAIProvider] = None + + def __init__(self, date: str = "2025-01-01", **kwargs): + """ + Initialize mock chat model + + Args: + date: Trading date for mock responses + **kwargs: Additional LangChain model parameters + """ + super().__init__(**kwargs) + self.date = date + self.step_counter = 0 + self.provider = MockAIProvider() + + @property + def _llm_type(self) -> str: + """Return identifier for this LLM type""" + return "mock-chat-model" + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[Any] = None, + **kwargs: Any, + ) -> ChatResult: + """ + Generate mock response (synchronous) + + Args: + messages: Input messages (ignored in mock) + stop: Stop sequences (ignored in mock) + run_manager: LangChain run manager + **kwargs: Additional generation parameters + + Returns: + ChatResult with mock AI response + """ + response_text = self.provider.generate_response(self.date, self.step_counter) + self.step_counter += 1 + + message = AIMessage( + content=response_text, + response_metadata={"finish_reason": "stop"} + ) + + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[Any] = None, + **kwargs: Any, + ) -> ChatResult: + """ + Generate mock response (asynchronous) + + Same as _generate but async-compatible for LangChain agents. + """ + return self._generate(messages, stop, run_manager, **kwargs) + + def invoke(self, input: Any, **kwargs) -> AIMessage: + """Synchronous invoke (LangChain compatibility)""" + if isinstance(input, list): + messages = input + else: + messages = [] + + result = self._generate(messages, **kwargs) + return result.generations[0].message + + async def ainvoke(self, input: Any, **kwargs) -> AIMessage: + """Asynchronous invoke (LangChain compatibility)""" + if isinstance(input, list): + messages = input + else: + messages = [] + + result = await self._agenerate(messages, **kwargs) + return result.generations[0].message diff --git a/tests/unit/test_mock_provider.py b/tests/unit/test_mock_provider.py index 1b5e4bb..28f29cf 100644 --- a/tests/unit/test_mock_provider.py +++ b/tests/unit/test_mock_provider.py @@ -1,5 +1,7 @@ import pytest +import asyncio from agent.mock_provider.mock_ai_provider import MockAIProvider +from agent.mock_provider.mock_langchain_model import MockChatModel def test_mock_provider_rotates_stocks(): @@ -32,3 +34,41 @@ def test_mock_provider_valid_json_tool_calls(): provider = MockAIProvider() response = provider.generate_response("2025-01-01", step=0) assert "[calls tool_get_price" in response or "get_price" in response.lower() + + +def test_mock_chat_model_invoke(): + """Test synchronous invoke returns proper message format""" + model = MockChatModel(date="2025-01-01") + + messages = [{"role": "user", "content": "Analyze the market"}] + response = model.invoke(messages) + + assert hasattr(response, "content") + assert "AAPL" in response.content + assert "" in response.content + + +def test_mock_chat_model_ainvoke(): + """Test asynchronous invoke returns proper message format""" + async def run_test(): + model = MockChatModel(date="2025-01-02") + messages = [{"role": "user", "content": "Analyze the market"}] + response = await model.ainvoke(messages) + + assert hasattr(response, "content") + assert "MSFT" in response.content + assert "" in response.content + + asyncio.run(run_test()) + + +def test_mock_chat_model_different_dates(): + """Test that different dates produce different responses""" + model1 = MockChatModel(date="2025-01-01") + model2 = MockChatModel(date="2025-01-02") + + msg = [{"role": "user", "content": "Trade"}] + response1 = model1.invoke(msg) + response2 = model2.invoke(msg) + + assert response1.content != response2.content