mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-09 20:27:25 -04:00
feat: add model factory for provider-specific chat models
Implements a factory pattern that creates the appropriate chat model based on the provider prefix in the `basemodel` string. Supported providers:

- `deepseek/*`: uses ChatDeepSeek (native tool calling)
- `openai/*`: uses ChatOpenAI
- others: fall back to ChatOpenAI (OpenAI-compatible)

This enables native DeepSeek integration while maintaining backward compatibility with existing OpenAI-compatible providers.
This commit is contained in:
68
agent/model_factory.py
Normal file
68
agent/model_factory.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Model factory for creating provider-specific chat models.
|
||||
|
||||
Supports multiple AI providers with native integrations where available:
|
||||
- DeepSeek: Uses ChatDeepSeek for native tool calling support
|
||||
- OpenAI: Uses ChatOpenAI
|
||||
- Others: Fall back to ChatOpenAI (OpenAI-compatible endpoints)
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
|
||||
def create_model(
    basemodel: str,
    api_key: str,
    base_url: str,
    temperature: float,
    timeout: int
) -> Any:
    """
    Build a chat-model instance matching the provider encoded in *basemodel*.

    The provider is the segment before the first "/" in ``basemodel``
    (e.g. "deepseek/deepseek-chat" -> provider "deepseek").  DeepSeek
    models get the native ``ChatDeepSeek`` client (native tool-calling
    support); every other provider — OpenAI itself as well as
    OpenAI-compatible gateways — falls back to ``ChatOpenAI``.

    Args:
        basemodel: Model identifier in "provider/model-name" form
            (e.g. "deepseek/deepseek-chat", "openai/gpt-4")
        api_key: API key for the provider
        base_url: Base URL for the API endpoint
        temperature: Sampling temperature (0-1)
        timeout: Request timeout in seconds

    Returns:
        Provider-specific chat model instance

    Examples:
        >>> model = create_model("deepseek/deepseek-chat", "key", "url", 0.7, 30)
        >>> isinstance(model, ChatDeepSeek)
        True

        >>> model = create_model("openai/gpt-4", "key", "url", 0.7, 30)
        >>> isinstance(model, ChatOpenAI)
        True
    """
    # "provider/model" -> ("provider", "/", "model"); an empty separator
    # means basemodel contained no "/" at all.
    prefix, sep, remainder = basemodel.partition("/")
    provider = prefix.lower() if sep else "unknown"

    if provider == "deepseek":
        # Native DeepSeek client.  Strip the "deepseek/" prefix: the
        # DeepSeek API expects the bare model name.
        return ChatDeepSeek(
            model=remainder if sep else basemodel,
            api_key=api_key,
            base_url=base_url,
            temperature=temperature,
            timeout=timeout
        )

    # OpenAI and OpenAI-compatible endpoints (Anthropic, Google, Qwen,
    # etc. via a compatibility layer) receive the identifier unchanged,
    # provider prefix included.
    return ChatOpenAI(
        model=basemodel,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        timeout=timeout
    )
|
||||
Reference in New Issue
Block a user