mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-01 17:17:24 -04:00
feat: add model factory for provider-specific chat models
Implements factory pattern to create appropriate chat model based on provider prefix in basemodel string. Supported providers: - deepseek/*: Uses ChatDeepSeek (native tool calling) - openai/*: Uses ChatOpenAI - others: Fall back to ChatOpenAI (OpenAI-compatible) This enables native DeepSeek integration while maintaining backward compatibility with existing OpenAI-compatible providers.
This commit is contained in:
68
agent/model_factory.py
Normal file
68
agent/model_factory.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""
|
||||||
|
Model factory for creating provider-specific chat models.
|
||||||
|
|
||||||
|
Supports multiple AI providers with native integrations where available:
|
||||||
|
- DeepSeek: Uses ChatDeepSeek for native tool calling support
|
||||||
|
- OpenAI: Uses ChatOpenAI
|
||||||
|
- Others: Fall back to ChatOpenAI (OpenAI-compatible endpoints)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_deepseek import ChatDeepSeek
|
||||||
|
|
||||||
|
|
||||||
|
def create_model(
    basemodel: str,
    api_key: str,
    base_url: str,
    temperature: float,
    timeout: int
) -> Any:
    """
    Instantiate the chat model that matches the provider prefix.

    Args:
        basemodel: Model identifier (e.g., "deepseek/deepseek-chat", "openai/gpt-4")
        api_key: API key for the provider
        base_url: Base URL for API endpoint
        temperature: Sampling temperature (0-1)
        timeout: Request timeout in seconds

    Returns:
        Provider-specific chat model instance

    Examples:
        >>> model = create_model("deepseek/deepseek-chat", "key", "url", 0.7, 30)
        >>> isinstance(model, ChatDeepSeek)
        True

        >>> model = create_model("openai/gpt-4", "key", "url", 0.7, 30)
        >>> isinstance(model, ChatOpenAI)
        True
    """
    # Identifiers follow the "provider/model-name" convention; a missing
    # slash means the provider cannot be determined.
    prefix, sep, remainder = basemodel.partition("/")
    provider = prefix.lower() if sep else "unknown"

    # Connection settings are identical for every backend.
    shared_kwargs = dict(
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        timeout=timeout,
    )

    if provider == "deepseek":
        # DeepSeek gets its native integration (native tool calling);
        # only the part after the prefix is passed as the model name.
        return ChatDeepSeek(model=remainder if sep else basemodel, **shared_kwargs)

    # Everything else (OpenAI, Anthropic, Google, Qwen, ...) is reached
    # through an OpenAI-compatible endpoint; the full identifier is kept.
    return ChatOpenAI(model=basemodel, **shared_kwargs)
|
||||||
108
tests/unit/test_model_factory.py
Normal file
108
tests/unit/test_model_factory.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Unit tests for model factory - provider-specific model creation"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
from agent.model_factory import create_model
|
||||||
|
|
||||||
|
|
||||||
|
class TestModelFactory:
    """Tests for create_model factory function"""

    @patch('agent.model_factory.ChatDeepSeek')
    def test_create_model_deepseek(self, deepseek_cls):
        """Test that DeepSeek models use ChatDeepSeek"""
        fake_model = Mock()
        deepseek_cls.return_value = fake_model

        created = create_model(
            basemodel="deepseek/deepseek-chat",
            api_key="test-key",
            base_url="https://api.deepseek.com",
            temperature=0.7,
            timeout=30
        )

        # The "deepseek/" prefix must be stripped before the name is
        # handed to the native client.
        deepseek_cls.assert_called_once_with(
            model="deepseek-chat",
            api_key="test-key",
            base_url="https://api.deepseek.com",
            temperature=0.7,
            timeout=30
        )
        assert created == fake_model

    @patch('agent.model_factory.ChatOpenAI')
    def test_create_model_openai(self, openai_cls):
        """Test that OpenAI models use ChatOpenAI"""
        fake_model = Mock()
        openai_cls.return_value = fake_model

        created = create_model(
            basemodel="openai/gpt-4",
            api_key="test-key",
            base_url="https://api.openai.com/v1",
            temperature=0.7,
            timeout=30
        )

        # Unlike DeepSeek, the full "provider/model" identifier is kept.
        openai_cls.assert_called_once_with(
            model="openai/gpt-4",
            api_key="test-key",
            base_url="https://api.openai.com/v1",
            temperature=0.7,
            timeout=30
        )
        assert created == fake_model

    def test_create_model_anthropic(self):
        """Test that Anthropic models use ChatOpenAI (via compatibility)"""
        with patch('agent.model_factory.ChatOpenAI') as openai_cls:
            fake_model = Mock()
            openai_cls.return_value = fake_model

            created = create_model(
                basemodel="anthropic/claude-sonnet-4.5",
                api_key="test-key",
                base_url="https://api.anthropic.com/v1",
                temperature=0.7,
                timeout=30
            )

            # Anthropic is served through an OpenAI-compatible endpoint.
            openai_cls.assert_called_once()
            assert created == fake_model

    def test_create_model_generic_provider(self):
        """Test that unknown providers default to ChatOpenAI"""
        with patch('agent.model_factory.ChatOpenAI') as openai_cls:
            fake_model = Mock()
            openai_cls.return_value = fake_model

            created = create_model(
                basemodel="custom/custom-model",
                api_key="test-key",
                base_url="https://api.custom.com",
                temperature=0.7,
                timeout=30
            )

            # Anything unrecognized falls back to the OpenAI client.
            openai_cls.assert_called_once()
            assert created == fake_model

    @patch('agent.model_factory.ChatDeepSeek')
    def test_create_model_deepseek_extracts_model_name(self, deepseek_cls):
        """Test that DeepSeek model name is extracted correctly"""
        create_model(
            basemodel="deepseek/deepseek-chat-v3.1",
            api_key="key",
            base_url="url",
            temperature=0,
            timeout=30
        )

        # Only the part after the slash may reach the client as the model name.
        assert deepseek_cls.call_args[1]['model'] == "deepseek-chat-v3.1"
|
||||||
Reference in New Issue
Block a user