mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-09 12:17:24 -04:00
feat: add LangChain-compatible mock chat model wrapper
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
"""Mock AI provider for development mode testing"""
|
"""Mock AI provider for development mode testing"""
|
||||||
from .mock_ai_provider import MockAIProvider
|
from .mock_ai_provider import MockAIProvider
|
||||||
|
from .mock_langchain_model import MockChatModel
|
||||||
|
|
||||||
__all__ = ["MockAIProvider"]
|
__all__ = ["MockAIProvider", "MockChatModel"]
|
||||||
|
|||||||
108
agent/mock_provider/mock_langchain_model.py
Normal file
108
agent/mock_provider/mock_langchain_model.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""
|
||||||
|
Mock LangChain-compatible chat model for development mode
|
||||||
|
|
||||||
|
Wraps MockAIProvider to work with LangChain's agent framework.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, List, Optional, Dict
|
||||||
|
from langchain_core.language_models import BaseChatModel
|
||||||
|
from langchain_core.messages import AIMessage, BaseMessage
|
||||||
|
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||||
|
from .mock_ai_provider import MockAIProvider
|
||||||
|
|
||||||
|
|
||||||
|
class MockChatModel(BaseChatModel):
|
||||||
|
"""
|
||||||
|
Mock chat model compatible with LangChain's agent framework
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
date: Current trading date for response generation
|
||||||
|
step_counter: Tracks reasoning steps within a trading session
|
||||||
|
provider: MockAIProvider instance
|
||||||
|
"""
|
||||||
|
|
||||||
|
date: str = "2025-01-01"
|
||||||
|
step_counter: int = 0
|
||||||
|
provider: Optional[MockAIProvider] = None
|
||||||
|
|
||||||
|
def __init__(self, date: str = "2025-01-01", **kwargs):
|
||||||
|
"""
|
||||||
|
Initialize mock chat model
|
||||||
|
|
||||||
|
Args:
|
||||||
|
date: Trading date for mock responses
|
||||||
|
**kwargs: Additional LangChain model parameters
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self.date = date
|
||||||
|
self.step_counter = 0
|
||||||
|
self.provider = MockAIProvider()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _llm_type(self) -> str:
|
||||||
|
"""Return identifier for this LLM type"""
|
||||||
|
return "mock-chat-model"
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[Any] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""
|
||||||
|
Generate mock response (synchronous)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Input messages (ignored in mock)
|
||||||
|
stop: Stop sequences (ignored in mock)
|
||||||
|
run_manager: LangChain run manager
|
||||||
|
**kwargs: Additional generation parameters
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ChatResult with mock AI response
|
||||||
|
"""
|
||||||
|
response_text = self.provider.generate_response(self.date, self.step_counter)
|
||||||
|
self.step_counter += 1
|
||||||
|
|
||||||
|
message = AIMessage(
|
||||||
|
content=response_text,
|
||||||
|
response_metadata={"finish_reason": "stop"}
|
||||||
|
)
|
||||||
|
|
||||||
|
generation = ChatGeneration(message=message)
|
||||||
|
return ChatResult(generations=[generation])
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self,
|
||||||
|
messages: List[BaseMessage],
|
||||||
|
stop: Optional[List[str]] = None,
|
||||||
|
run_manager: Optional[Any] = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> ChatResult:
|
||||||
|
"""
|
||||||
|
Generate mock response (asynchronous)
|
||||||
|
|
||||||
|
Same as _generate but async-compatible for LangChain agents.
|
||||||
|
"""
|
||||||
|
return self._generate(messages, stop, run_manager, **kwargs)
|
||||||
|
|
||||||
|
def invoke(self, input: Any, **kwargs) -> AIMessage:
|
||||||
|
"""Synchronous invoke (LangChain compatibility)"""
|
||||||
|
if isinstance(input, list):
|
||||||
|
messages = input
|
||||||
|
else:
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
result = self._generate(messages, **kwargs)
|
||||||
|
return result.generations[0].message
|
||||||
|
|
||||||
|
async def ainvoke(self, input: Any, **kwargs) -> AIMessage:
|
||||||
|
"""Asynchronous invoke (LangChain compatibility)"""
|
||||||
|
if isinstance(input, list):
|
||||||
|
messages = input
|
||||||
|
else:
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
result = await self._agenerate(messages, **kwargs)
|
||||||
|
return result.generations[0].message
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
import pytest
|
import pytest
|
||||||
|
import asyncio
|
||||||
from agent.mock_provider.mock_ai_provider import MockAIProvider
|
from agent.mock_provider.mock_ai_provider import MockAIProvider
|
||||||
|
from agent.mock_provider.mock_langchain_model import MockChatModel
|
||||||
|
|
||||||
|
|
||||||
def test_mock_provider_rotates_stocks():
|
def test_mock_provider_rotates_stocks():
|
||||||
@@ -32,3 +34,41 @@ def test_mock_provider_valid_json_tool_calls():
|
|||||||
provider = MockAIProvider()
|
provider = MockAIProvider()
|
||||||
response = provider.generate_response("2025-01-01", step=0)
|
response = provider.generate_response("2025-01-01", step=0)
|
||||||
assert "[calls tool_get_price" in response or "get_price" in response.lower()
|
assert "[calls tool_get_price" in response or "get_price" in response.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_chat_model_invoke():
|
||||||
|
"""Test synchronous invoke returns proper message format"""
|
||||||
|
model = MockChatModel(date="2025-01-01")
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": "Analyze the market"}]
|
||||||
|
response = model.invoke(messages)
|
||||||
|
|
||||||
|
assert hasattr(response, "content")
|
||||||
|
assert "AAPL" in response.content
|
||||||
|
assert "<FINISH_SIGNAL>" in response.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_chat_model_ainvoke():
|
||||||
|
"""Test asynchronous invoke returns proper message format"""
|
||||||
|
async def run_test():
|
||||||
|
model = MockChatModel(date="2025-01-02")
|
||||||
|
messages = [{"role": "user", "content": "Analyze the market"}]
|
||||||
|
response = await model.ainvoke(messages)
|
||||||
|
|
||||||
|
assert hasattr(response, "content")
|
||||||
|
assert "MSFT" in response.content
|
||||||
|
assert "<FINISH_SIGNAL>" in response.content
|
||||||
|
|
||||||
|
asyncio.run(run_test())
|
||||||
|
|
||||||
|
|
||||||
|
def test_mock_chat_model_different_dates():
|
||||||
|
"""Test that different dates produce different responses"""
|
||||||
|
model1 = MockChatModel(date="2025-01-01")
|
||||||
|
model2 = MockChatModel(date="2025-01-02")
|
||||||
|
|
||||||
|
msg = [{"role": "user", "content": "Trade"}]
|
||||||
|
response1 = model1.invoke(msg)
|
||||||
|
response2 = model2.invoke(msg)
|
||||||
|
|
||||||
|
assert response1.content != response2.content
|
||||||
|
|||||||
Reference in New Issue
Block a user