mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-01 17:17:24 -04:00
feat: add LangChain-compatible mock chat model wrapper
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
"""Mock AI provider for development mode testing"""
|
||||
from .mock_ai_provider import MockAIProvider
|
||||
from .mock_langchain_model import MockChatModel
|
||||
|
||||
__all__ = ["MockAIProvider"]
|
||||
__all__ = ["MockAIProvider", "MockChatModel"]
|
||||
|
||||
108
agent/mock_provider/mock_langchain_model.py
Normal file
108
agent/mock_provider/mock_langchain_model.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Mock LangChain-compatible chat model for development mode
|
||||
|
||||
Wraps MockAIProvider to work with LangChain's agent framework.
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional, Dict
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import AIMessage, BaseMessage
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
from .mock_ai_provider import MockAIProvider
|
||||
|
||||
|
||||
class MockChatModel(BaseChatModel):
    """
    Mock chat model compatible with LangChain's agent framework.

    Wraps MockAIProvider so deterministic canned responses can be used
    wherever LangChain expects a chat model (development-mode testing only).

    Attributes:
        date: Current trading date used to select mock responses.
        step_counter: Tracks reasoning steps within a trading session.
        provider: MockAIProvider instance that produces the response text.
    """

    # Declared as pydantic fields — BaseChatModel is a pydantic model.
    date: str = "2025-01-01"
    step_counter: int = 0
    provider: Optional[MockAIProvider] = None

    def __init__(self, date: str = "2025-01-01", **kwargs):
        """
        Initialize mock chat model.

        Args:
            date: Trading date for mock responses.
            **kwargs: Additional LangChain model parameters.
        """
        # Pass declared fields through the pydantic constructor rather than
        # assigning after super().__init__(), so field validation applies.
        super().__init__(date=date, step_counter=0, **kwargs)
        self.provider = MockAIProvider()

    def _ensure_provider(self) -> MockAIProvider:
        """Return the provider, creating one lazily if the instance was
        built without __init__ (e.g. via pydantic model_construct)."""
        if self.provider is None:
            self.provider = MockAIProvider()
        return self.provider

    @property
    def _llm_type(self) -> str:
        """Return identifier for this LLM type"""
        return "mock-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Generate mock response (synchronous).

        Args:
            messages: Input messages (ignored in mock)
            stop: Stop sequences (ignored in mock)
            run_manager: LangChain run manager
            **kwargs: Additional generation parameters

        Returns:
            ChatResult with mock AI response
        """
        response_text = self._ensure_provider().generate_response(
            self.date, self.step_counter
        )
        self.step_counter += 1

        message = AIMessage(
            content=response_text,
            response_metadata={"finish_reason": "stop"}
        )

        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Generate mock response (asynchronous).

        Same as _generate but async-compatible for LangChain agents.
        """
        return self._generate(messages, stop, run_manager, **kwargs)

    def invoke(self, input: Any, config: Optional[Any] = None, **kwargs) -> AIMessage:
        """
        Synchronous invoke (LangChain compatibility).

        Accepts ``config`` to match Runnable.invoke(input, config=None, ...),
        which agent executors may pass positionally; it is ignored by the mock.
        Fixes a TypeError the previous signature raised in that case.
        """
        messages = input if isinstance(input, list) else []
        result = self._generate(messages, **kwargs)
        return result.generations[0].message

    async def ainvoke(self, input: Any, config: Optional[Any] = None, **kwargs) -> AIMessage:
        """
        Asynchronous invoke (LangChain compatibility).

        ``config`` is accepted for Runnable signature compatibility and ignored.
        """
        messages = input if isinstance(input, list) else []
        result = await self._agenerate(messages, **kwargs)
        return result.generations[0].message
|
||||
@@ -1,5 +1,7 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from agent.mock_provider.mock_ai_provider import MockAIProvider
|
||||
from agent.mock_provider.mock_langchain_model import MockChatModel
|
||||
|
||||
|
||||
def test_mock_provider_rotates_stocks():
|
||||
@@ -32,3 +34,41 @@ def test_mock_provider_valid_json_tool_calls():
|
||||
provider = MockAIProvider()
|
||||
response = provider.generate_response("2025-01-01", step=0)
|
||||
assert "[calls tool_get_price" in response or "get_price" in response.lower()
|
||||
|
||||
|
||||
def test_mock_chat_model_invoke():
    """Synchronous invoke yields an AI message carrying the mock response."""
    chat = MockChatModel(date="2025-01-01")
    prompt = [{"role": "user", "content": "Analyze the market"}]

    reply = chat.invoke(prompt)

    assert hasattr(reply, "content")
    # The canned response for this date must mention the stock and terminate.
    for expected in ("AAPL", "<FINISH_SIGNAL>"):
        assert expected in reply.content
|
||||
|
||||
|
||||
def test_mock_chat_model_ainvoke():
    """Asynchronous invoke yields the same message format as invoke."""
    async def exercise():
        chat = MockChatModel(date="2025-01-02")
        prompt = [{"role": "user", "content": "Analyze the market"}]
        reply = await chat.ainvoke(prompt)

        assert hasattr(reply, "content")
        assert "MSFT" in reply.content
        assert "<FINISH_SIGNAL>" in reply.content

    asyncio.run(exercise())
|
||||
|
||||
|
||||
def test_mock_chat_model_different_dates():
    """Models seeded with different dates must not produce identical output."""
    prompt = [{"role": "user", "content": "Trade"}]

    first = MockChatModel(date="2025-01-01").invoke(prompt)
    second = MockChatModel(date="2025-01-02").invoke(prompt)

    assert first.content != second.content
|
||||
|
||||
Reference in New Issue
Block a user