From e689a78b3f64886118f40c6417ab40d0c913654f Mon Sep 17 00:00:00 2001
From: Bill
Date: Thu, 6 Nov 2025 07:47:56 -0500
Subject: [PATCH] feat: add model factory for provider-specific chat models

Implements factory pattern to create appropriate chat model based on
provider prefix in basemodel string. Supported providers:
- deepseek/*: Uses ChatDeepSeek (native tool calling)
- openai/*: Uses ChatOpenAI
- others: Fall back to ChatOpenAI (OpenAI-compatible)

This enables native DeepSeek integration while maintaining backward
compatibility with existing OpenAI-compatible providers.
---
 agent/model_factory.py           |  68 ++++++++++++++++++++
 tests/unit/test_model_factory.py | 108 +++++++++++++++++++++++++++++++
 2 files changed, 176 insertions(+)
 create mode 100644 agent/model_factory.py
 create mode 100644 tests/unit/test_model_factory.py

diff --git a/agent/model_factory.py b/agent/model_factory.py
new file mode 100644
index 0000000..b90be60
--- /dev/null
+++ b/agent/model_factory.py
@@ -0,0 +1,68 @@
+"""
+Model factory for creating provider-specific chat models.
+
+Supports multiple AI providers with native integrations where available:
+- DeepSeek: Uses ChatDeepSeek for native tool calling support
+- OpenAI: Uses ChatOpenAI
+- Others: Fall back to ChatOpenAI (OpenAI-compatible endpoints)
+"""
+
+from typing import Any
+from langchain_openai import ChatOpenAI
+from langchain_deepseek import ChatDeepSeek
+
+
+def create_model(
+    basemodel: str,
+    api_key: str,
+    base_url: str,
+    temperature: float,
+    timeout: int
+) -> Any:
+    """
+    Create appropriate chat model based on provider.
+
+    Args:
+        basemodel: Model identifier (e.g., "deepseek/deepseek-chat", "openai/gpt-4")
+        api_key: API key for the provider
+        base_url: Base URL for API endpoint
+        temperature: Sampling temperature (0-1)
+        timeout: Request timeout in seconds
+
+    Returns:
+        Provider-specific chat model instance
+
+    Examples:
+        >>> model = create_model("deepseek/deepseek-chat", "key", "url", 0.7, 30)
+        >>> isinstance(model, ChatDeepSeek)
+        True
+
+        >>> model = create_model("openai/gpt-4", "key", "url", 0.7, 30)
+        >>> isinstance(model, ChatOpenAI)
+        True
+    """
+    # Extract provider from basemodel (format: "provider/model-name")
+    provider = basemodel.split("/")[0].lower() if "/" in basemodel else "unknown"
+
+    if provider == "deepseek":
+        # Use native ChatDeepSeek for DeepSeek models
+        # Extract model name without provider prefix
+        model_name = basemodel.split("/", 1)[1] if "/" in basemodel else basemodel
+
+        return ChatDeepSeek(
+            model=model_name,
+            api_key=api_key,
+            base_url=base_url,
+            temperature=temperature,
+            timeout=timeout
+        )
+    else:
+        # Use ChatOpenAI for OpenAI and OpenAI-compatible endpoints
+        # (Anthropic, Google, Qwen, etc. via compatibility layer)
+        return ChatOpenAI(
+            model=basemodel,
+            api_key=api_key,
+            base_url=base_url,
+            temperature=temperature,
+            timeout=timeout
+        )
diff --git a/tests/unit/test_model_factory.py b/tests/unit/test_model_factory.py
new file mode 100644
index 0000000..056bd9b
--- /dev/null
+++ b/tests/unit/test_model_factory.py
@@ -0,0 +1,108 @@
+"""Unit tests for model factory - provider-specific model creation"""
+
+import pytest
+from unittest.mock import Mock, patch
+from agent.model_factory import create_model
+
+
+class TestModelFactory:
+    """Tests for create_model factory function"""
+
+    @patch('agent.model_factory.ChatDeepSeek')
+    def test_create_model_deepseek(self, mock_deepseek_class):
+        """Test that DeepSeek models use ChatDeepSeek"""
+        mock_model = Mock()
+        mock_deepseek_class.return_value = mock_model
+
+        result = create_model(
+            basemodel="deepseek/deepseek-chat",
+            api_key="test-key",
+            base_url="https://api.deepseek.com",
+            temperature=0.7,
+            timeout=30
+        )
+
+        # Verify ChatDeepSeek was called with correct params
+        mock_deepseek_class.assert_called_once_with(
+            model="deepseek-chat",  # Extracted from "deepseek/deepseek-chat"
+            api_key="test-key",
+            base_url="https://api.deepseek.com",
+            temperature=0.7,
+            timeout=30
+        )
+        assert result == mock_model
+
+    @patch('agent.model_factory.ChatOpenAI')
+    def test_create_model_openai(self, mock_openai_class):
+        """Test that OpenAI models use ChatOpenAI"""
+        mock_model = Mock()
+        mock_openai_class.return_value = mock_model
+
+        result = create_model(
+            basemodel="openai/gpt-4",
+            api_key="test-key",
+            base_url="https://api.openai.com/v1",
+            temperature=0.7,
+            timeout=30
+        )
+
+        # Verify ChatOpenAI was called with correct params
+        mock_openai_class.assert_called_once_with(
+            model="openai/gpt-4",
+            api_key="test-key",
+            base_url="https://api.openai.com/v1",
+            temperature=0.7,
+            timeout=30
+        )
+        assert result == mock_model
+
+    @patch('agent.model_factory.ChatOpenAI')
+    def test_create_model_anthropic(self, mock_openai_class):
+        """Test that Anthropic models use ChatOpenAI (via compatibility)"""
+        mock_model = Mock()
+        mock_openai_class.return_value = mock_model
+
+        result = create_model(
+            basemodel="anthropic/claude-sonnet-4.5",
+            api_key="test-key",
+            base_url="https://api.anthropic.com/v1",
+            temperature=0.7,
+            timeout=30
+        )
+
+        # Verify ChatOpenAI was used (Anthropic via OpenAI-compatible endpoint)
+        mock_openai_class.assert_called_once()
+        assert result == mock_model
+
+    @patch('agent.model_factory.ChatOpenAI')
+    def test_create_model_generic_provider(self, mock_openai_class):
+        """Test that unknown providers default to ChatOpenAI"""
+        mock_model = Mock()
+        mock_openai_class.return_value = mock_model
+
+        result = create_model(
+            basemodel="custom/custom-model",
+            api_key="test-key",
+            base_url="https://api.custom.com",
+            temperature=0.7,
+            timeout=30
+        )
+
+        # Should fall back to ChatOpenAI for unknown providers
+        mock_openai_class.assert_called_once()
+        assert result == mock_model
+
+    def test_create_model_deepseek_extracts_model_name(self):
+        """Test that DeepSeek model name is extracted correctly"""
+        with patch('agent.model_factory.ChatDeepSeek') as mock_class:
+            create_model(
+                basemodel="deepseek/deepseek-chat-v3.1",
+                api_key="key",
+                base_url="url",
+                temperature=0,
+                timeout=30
+            )
+
+            # Check that model param is just "deepseek-chat-v3.1"
+            call_kwargs = mock_class.call_args[1]
+            assert call_kwargs['model'] == "deepseek-chat-v3.1"