mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-01 17:17:24 -04:00
feat: add model factory for provider-specific chat models
Implements a factory pattern that creates the appropriate chat model based on the provider prefix in the basemodel string.

Supported providers:
- deepseek/*: Uses ChatDeepSeek (native tool calling)
- openai/*: Uses ChatOpenAI
- others: Fall back to ChatOpenAI (OpenAI-compatible)

This enables native DeepSeek integration while maintaining backward compatibility with existing OpenAI-compatible providers.
This commit is contained in:
68
agent/model_factory.py
Normal file
68
agent/model_factory.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Model factory for creating provider-specific chat models.
|
||||
|
||||
Supports multiple AI providers with native integrations where available:
|
||||
- DeepSeek: Uses ChatDeepSeek for native tool calling support
|
||||
- OpenAI: Uses ChatOpenAI
|
||||
- Others: Fall back to ChatOpenAI (OpenAI-compatible endpoints)
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
|
||||
def create_model(
    basemodel: str,
    api_key: str,
    base_url: str,
    temperature: float,
    timeout: int,
) -> Any:
    """
    Create the chat model that matches the provider prefix of ``basemodel``.

    The provider is the text before the first "/" in ``basemodel``, compared
    case-insensitively. A ``basemodel`` without any "/" is treated as an
    unknown provider and handled by the ChatOpenAI fallback.

    Args:
        basemodel: Model identifier in "provider/model-name" form
            (e.g., "deepseek/deepseek-chat", "openai/gpt-4")
        api_key: API key for the provider
        base_url: Base URL for API endpoint
        temperature: Sampling temperature, passed to the model unchanged
        timeout: Request timeout in seconds

    Returns:
        A ChatDeepSeek instance when the provider is "deepseek" (built with
        the model name that follows the prefix); otherwise a ChatOpenAI
        instance built with the full, unmodified ``basemodel`` string.

    Examples:
        >>> model = create_model("deepseek/deepseek-chat", "key", "url", 0.7, 30)
        >>> isinstance(model, ChatDeepSeek)
        True

        >>> model = create_model("openai/gpt-4", "key", "url", 0.7, 30)
        >>> isinstance(model, ChatOpenAI)
        True
    """
    # Split once on the first "/": `separator` is empty when there is no
    # slash, and `model_name` keeps any later slashes intact.
    prefix, separator, model_name = basemodel.partition("/")
    provider = prefix.lower() if separator else "unknown"

    if provider == "deepseek":
        # Native DeepSeek client receives the model name without the
        # "deepseek/" prefix.
        return ChatDeepSeek(
            model=model_name,
            api_key=api_key,
            base_url=base_url,
            temperature=temperature,
            timeout=timeout,
        )

    # OpenAI and every other provider go through ChatOpenAI, relying on an
    # OpenAI-compatible endpoint; the identifier is forwarded with its prefix.
    return ChatOpenAI(
        model=basemodel,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        timeout=timeout,
    )
|
||||
108
tests/unit/test_model_factory.py
Normal file
108
tests/unit/test_model_factory.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Unit tests for model factory - provider-specific model creation"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from agent.model_factory import create_model
|
||||
|
||||
|
||||
class TestModelFactory:
    """Tests for create_model factory function.

    Each test patches the model class inside ``agent.model_factory`` so no
    real client is constructed, then checks which class was called and with
    which keyword arguments.
    """

    @patch('agent.model_factory.ChatDeepSeek')
    def test_create_model_deepseek(self, mock_deepseek_class):
        """Test that DeepSeek models use ChatDeepSeek"""
        mock_model = Mock()
        mock_deepseek_class.return_value = mock_model

        result = create_model(
            basemodel="deepseek/deepseek-chat",
            api_key="test-key",
            base_url="https://api.deepseek.com",
            temperature=0.7,
            timeout=30,
        )

        # Verify ChatDeepSeek was called with correct params
        mock_deepseek_class.assert_called_once_with(
            model="deepseek-chat",  # Extracted from "deepseek/deepseek-chat"
            api_key="test-key",
            base_url="https://api.deepseek.com",
            temperature=0.7,
            timeout=30,
        )
        assert result is mock_model

    @patch('agent.model_factory.ChatOpenAI')
    def test_create_model_openai(self, mock_openai_class):
        """Test that OpenAI models use ChatOpenAI"""
        mock_model = Mock()
        mock_openai_class.return_value = mock_model

        result = create_model(
            basemodel="openai/gpt-4",
            api_key="test-key",
            base_url="https://api.openai.com/v1",
            temperature=0.7,
            timeout=30,
        )

        # Verify ChatOpenAI was called with correct params; the provider
        # prefix is kept in the model identifier on this path.
        mock_openai_class.assert_called_once_with(
            model="openai/gpt-4",
            api_key="test-key",
            base_url="https://api.openai.com/v1",
            temperature=0.7,
            timeout=30,
        )
        assert result is mock_model

    @patch('agent.model_factory.ChatOpenAI')
    def test_create_model_anthropic(self, mock_openai_class):
        """Test that Anthropic models use ChatOpenAI (via compatibility)"""
        mock_model = Mock()
        mock_openai_class.return_value = mock_model

        result = create_model(
            basemodel="anthropic/claude-sonnet-4.5",
            api_key="test-key",
            base_url="https://api.anthropic.com/v1",
            temperature=0.7,
            timeout=30,
        )

        # Verify ChatOpenAI was used (Anthropic via OpenAI-compatible
        # endpoint) and that the full identifier was forwarded unchanged.
        mock_openai_class.assert_called_once_with(
            model="anthropic/claude-sonnet-4.5",
            api_key="test-key",
            base_url="https://api.anthropic.com/v1",
            temperature=0.7,
            timeout=30,
        )
        assert result is mock_model

    @patch('agent.model_factory.ChatOpenAI')
    def test_create_model_generic_provider(self, mock_openai_class):
        """Test that unknown providers default to ChatOpenAI"""
        mock_model = Mock()
        mock_openai_class.return_value = mock_model

        result = create_model(
            basemodel="custom/custom-model",
            api_key="test-key",
            base_url="https://api.custom.com",
            temperature=0.7,
            timeout=30,
        )

        # Should fall back to ChatOpenAI for unknown providers, passing the
        # identifier through with its prefix intact.
        mock_openai_class.assert_called_once_with(
            model="custom/custom-model",
            api_key="test-key",
            base_url="https://api.custom.com",
            temperature=0.7,
            timeout=30,
        )
        assert result is mock_model

    def test_create_model_deepseek_extracts_model_name(self):
        """Test that DeepSeek model name is extracted correctly"""
        with patch('agent.model_factory.ChatDeepSeek') as mock_class:
            create_model(
                basemodel="deepseek/deepseek-chat-v3.1",
                api_key="key",
                base_url="url",
                temperature=0,
                timeout=30,
            )

            # Guard against call_args being None before indexing into it.
            mock_class.assert_called_once()

            # Check that model param is just "deepseek-chat-v3.1"
            # (index 1 of call_args is the keyword-argument dict).
            call_kwargs = mock_class.call_args[1]
            assert call_kwargs['model'] == "deepseek-chat-v3.1"
|
||||
Reference in New Issue
Block a user