mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-02 09:37:23 -04:00
Compare commits
11 Commits
v0.4.0
...
v0.4.2-alp
| Author | SHA1 | Date | |
|---|---|---|---|
| e8939be04e | |||
| 2e0cf4d507 | |||
| 7b35394ce7 | |||
| 2d41717b2b | |||
| 7c4874715b | |||
| 6d30244fc9 | |||
| 0641ce554a | |||
| 0c6de5b74b | |||
| 0f49977700 | |||
| 27a824f4a6 | |||
| 3e50868a4d |
63
CHANGELOG.md
63
CHANGELOG.md
@@ -7,7 +7,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.4.2] - 2025-11-06
|
||||
### Fixed
|
||||
- Fixed Pydantic validation errors when using DeepSeek models via OpenRouter
|
||||
- Root cause: DeepSeek returns tool_calls in non-standard format with `args` field directly, bypassing LangChain's `parse_tool_call()`
|
||||
- Solution: Added `ToolCallArgsParsingWrapper` that normalizes non-standard tool_call format to OpenAI standard before LangChain processing
|
||||
- Wrapper converts `{name, args, id}` → `{function: {name, arguments}, id}` format
|
||||
- Includes diagnostic logging to identify format inconsistencies across providers
|
||||
|
||||
## [0.4.1] - 2025-11-06
|
||||
|
||||
### Fixed
|
||||
- Fixed "No trading" message always displaying despite trading activity by initializing `IF_TRADE` to `True` (trades expected by default)
|
||||
- Root cause: `IF_TRADE` was initialized to `False` in runtime config but never updated when trades executed
|
||||
|
||||
### Note
|
||||
- ChatDeepSeek integration was reverted as it conflicts with OpenRouter unified gateway architecture
|
||||
- System uses `OPENAI_API_BASE` (OpenRouter) with single `OPENAI_API_KEY` for all providers
|
||||
- Sporadic DeepSeek validation errors appear to be transient and do not require code changes
|
||||
|
||||
## [0.4.0] - 2025-11-05
|
||||
|
||||
### BREAKING CHANGES
|
||||
|
||||
@@ -130,6 +148,49 @@ New `/results?reasoning=full` returns:
|
||||
- Test coverage increased with 36+ new comprehensive tests
|
||||
- Documentation updated with complete API reference and database schema details
|
||||
|
||||
### Fixed
|
||||
- **Critical:** Intra-day position tracking for sell-then-buy trades (e20dce7)
|
||||
- Sell proceeds now immediately available for subsequent buy orders within same trading session
|
||||
- ContextInjector maintains in-memory position state during trading sessions
|
||||
- Position updates accumulate after each successful trade
|
||||
- Enables agents to rebalance portfolios (sell + buy) in single session
|
||||
- Added 13 comprehensive tests for position tracking
|
||||
- **Critical:** Tool message extraction in conversation history (462de3a, abb9cd0)
|
||||
- Fixed bug where tool messages (buy/sell trades) were not captured when agent completed in single step
|
||||
- Tool extraction now happens BEFORE finish signal check
|
||||
- Reasoning summaries now accurately reflect actual trades executed
|
||||
- Resolves issue where summarizer saw 0 tools despite multiple trades
|
||||
- Reasoning summary generation improvements (6d126db)
|
||||
- Summaries now explicitly mention specific trades executed (symbols, quantities, actions)
|
||||
- Added TRADES EXECUTED section highlighting tool calls
|
||||
- Example: 'sold 1 GOOGL and 1 AMZN to reduce exposure' instead of 'maintain core holdings'
|
||||
- Final holdings calculation accuracy (a8d912b)
|
||||
- Final positions now calculated from actions instead of querying incomplete database records
|
||||
- Correctly handles first trading day with multiple trades
|
||||
- New `_calculate_final_position_from_actions()` method applies all trades to calculate final state
|
||||
- Holdings now persist correctly across all trading days
|
||||
- Added 3 comprehensive tests for final position calculation
|
||||
- Holdings persistence between trading days (aa16480)
|
||||
- Query now retrieves previous day's ending position as current day's starting position
|
||||
- Changed query from `date <=` to `date <` to prevent returning incomplete current-day records
|
||||
- Fixes empty starting_position/final_position in API responses despite successful trades
|
||||
- Updated tests to verify correct previous-day retrieval
|
||||
- Context injector trading_day_id synchronization (05620fa)
|
||||
- ContextInjector now updated with trading_day_id after record creation
|
||||
- Fixes "Trade failed: trading_day_id not found in runtime config" error
|
||||
- MCP tools now correctly receive trading_day_id via context injection
|
||||
- Schema migration compatibility fixes (7c71a04)
|
||||
- Updated position queries to use new trading_days schema instead of obsolete positions table
|
||||
- Removed obsolete add_no_trade_record_to_db function calls
|
||||
- Fixes "no such table: positions" error
|
||||
- Simplified _handle_trading_result logic
|
||||
- Database referential integrity (9da65c2)
|
||||
- Corrected Database default path from "data/trading.db" to "data/jobs.db"
|
||||
- Ensures all components use same database file
|
||||
- Fixes FOREIGN KEY constraint failures when creating trading_day records
|
||||
- Debug logging cleanup (1e7bdb5)
|
||||
- Removed verbose debug logging from ContextInjector for cleaner output
|
||||
|
||||
## [0.3.1] - 2025-11-03
|
||||
|
||||
### Fixed
|
||||
|
||||
80
ROADMAP.md
80
ROADMAP.md
@@ -4,6 +4,78 @@ This document outlines planned features and improvements for the AI-Trader proje
|
||||
|
||||
## Release Planning
|
||||
|
||||
### v0.5.0 - Performance Metrics & Status APIs (Planned)
|
||||
|
||||
**Focus:** Enhanced observability and performance tracking
|
||||
|
||||
#### Performance Metrics API
|
||||
- **Performance Summary Endpoint** - Query model performance over date ranges
|
||||
- `GET /metrics/performance` - Aggregated performance metrics
|
||||
- Query parameters: `model`, `start_date`, `end_date`
|
||||
- Returns comprehensive performance summary:
|
||||
- Total return (dollar amount and percentage)
|
||||
- Number of trades executed (buy + sell)
|
||||
- Win rate (profitable trading days / total trading days)
|
||||
- Average daily P&L (profit and loss)
|
||||
- Best/worst trading day (highest/lowest daily P&L)
|
||||
- Final portfolio value (cash + holdings at market value)
|
||||
- Number of trading days in queried range
|
||||
- Starting vs. ending portfolio comparison
|
||||
- Use cases:
|
||||
- Compare model performance across different time periods
|
||||
- Evaluate strategy effectiveness
|
||||
- Identify top-performing models
|
||||
- Example: `GET /metrics/performance?model=gpt-4&start_date=2025-01-01&end_date=2025-01-31`
|
||||
- Filtering options:
|
||||
- Single model or all models
|
||||
- Custom date ranges
|
||||
- Exclude incomplete trading days
|
||||
- Response format: JSON with clear metric definitions
|
||||
|
||||
#### Status & Coverage Endpoint
|
||||
- **System Status Summary** - Data availability and simulation progress
|
||||
- `GET /status` - Comprehensive system status
|
||||
- Price data coverage section:
|
||||
- Available symbols (NASDAQ 100 constituents)
|
||||
- Date range of downloaded price data per symbol
|
||||
- Total trading days with complete data
|
||||
- Missing data gaps (symbols without data, date gaps)
|
||||
- Last data refresh timestamp
|
||||
- Model simulation status section:
|
||||
- List of all configured models (enabled/disabled)
|
||||
- Date ranges simulated per model (first and last trading day)
|
||||
- Total trading days completed per model
|
||||
- Most recent simulation date per model
|
||||
- Completion percentage (simulated days / available data days)
|
||||
- System health section:
|
||||
- Database connectivity status
|
||||
- MCP services status (Math, Search, Trade, LocalPrices)
|
||||
- API version and deployment mode
|
||||
- Disk space usage (database size, log size)
|
||||
- Use cases:
|
||||
- Verify data availability before triggering simulations
|
||||
- Identify which models need updates to latest data
|
||||
- Monitor system health and readiness
|
||||
- Plan data downloads for missing date ranges
|
||||
- Example: `GET /status` (no parameters required)
|
||||
- Benefits:
|
||||
- Single endpoint for complete system overview
|
||||
- No need to query multiple endpoints for status
|
||||
- Clear visibility into data gaps
|
||||
- Track simulation progress across models
|
||||
|
||||
#### Implementation Details
|
||||
- Database queries for efficient metric calculation
|
||||
- Caching for frequently accessed metrics (optional)
|
||||
- Response time target: <500ms for typical queries
|
||||
- Comprehensive error handling for missing data
|
||||
|
||||
#### Benefits
|
||||
- **Better Observability** - Clear view of system state and model performance
|
||||
- **Data-Driven Decisions** - Quantitative metrics for model comparison
|
||||
- **Proactive Monitoring** - Identify data gaps before simulations fail
|
||||
- **User Experience** - Single endpoint to check "what's available and what's been done"
|
||||
|
||||
### v1.0.0 - Production Stability & Validation (Planned)
|
||||
|
||||
**Focus:** Comprehensive testing, documentation, and production readiness
|
||||
@@ -607,11 +679,13 @@ To propose a new feature:
|
||||
|
||||
- **v0.1.0** - Initial release with batch execution
|
||||
- **v0.2.0** - Docker deployment support
|
||||
- **v0.3.0** - REST API, on-demand downloads, database storage (current)
|
||||
- **v0.3.0** - REST API, on-demand downloads, database storage
|
||||
- **v0.4.0** - Daily P&L calculation, day-centric results API, reasoning summaries (current)
|
||||
- **v0.5.0** - Performance metrics & status APIs (planned)
|
||||
- **v1.0.0** - Production stability & validation (planned)
|
||||
- **v1.1.0** - API authentication & security (planned)
|
||||
- **v1.2.0** - Position history & analytics (planned)
|
||||
- **v1.3.0** - Performance metrics & analytics (planned)
|
||||
- **v1.3.0** - Advanced performance metrics & analytics (planned)
|
||||
- **v1.4.0** - Data management API (planned)
|
||||
- **v1.5.0** - Web dashboard UI (planned)
|
||||
- **v1.6.0** - Advanced configuration & customization (planned)
|
||||
@@ -619,4 +693,4 @@ To propose a new feature:
|
||||
|
||||
---
|
||||
|
||||
Last updated: 2025-11-01
|
||||
Last updated: 2025-11-06
|
||||
|
||||
@@ -33,6 +33,7 @@ from tools.deployment_config import (
|
||||
from agent.context_injector import ContextInjector
|
||||
from agent.pnl_calculator import DailyPnLCalculator
|
||||
from agent.reasoning_summarizer import ReasoningSummarizer
|
||||
from agent.chat_model_wrapper import ToolCallArgsParsingWrapper
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv()
|
||||
@@ -211,14 +212,16 @@ class BaseAgent:
|
||||
self.model = MockChatModel(date="2025-01-01") # Date will be updated per session
|
||||
print(f"🤖 Using MockChatModel (DEV mode)")
|
||||
else:
|
||||
self.model = ChatOpenAI(
|
||||
base_model = ChatOpenAI(
|
||||
model=self.basemodel,
|
||||
base_url=self.openai_base_url,
|
||||
api_key=self.openai_api_key,
|
||||
max_retries=3,
|
||||
timeout=30
|
||||
)
|
||||
print(f"🤖 Using {self.basemodel} (PROD mode)")
|
||||
# Wrap model with diagnostic wrapper
|
||||
self.model = ToolCallArgsParsingWrapper(model=base_model)
|
||||
print(f"🤖 Using {self.basemodel} (PROD mode) with diagnostic wrapper")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"❌ Failed to initialize AI model: {e}")
|
||||
|
||||
|
||||
141
agent/chat_model_wrapper.py
Normal file
141
agent/chat_model_wrapper.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Chat model wrapper to fix tool_calls args parsing issues.
|
||||
|
||||
DeepSeek and other providers return tool_calls.args as JSON strings, which need
|
||||
to be parsed to dicts before AIMessage construction.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Optional, Dict
|
||||
from functools import wraps
|
||||
|
||||
|
||||
class ToolCallArgsParsingWrapper:
    """
    Wrapper around a chat model that adds diagnostic logging and normalizes
    tool_calls payloads to the standard OpenAI format.

    Two normalization layers are applied:
      - Raw provider responses (via a monkey-patched ``_create_chat_result``):
        non-standard ``{name, args, id}`` tool_calls are rewritten to
        ``{function: {name, arguments}, id}``, and dict-valued args on
        ``invalid_tool_calls`` are serialized to JSON strings.
      - Parsed messages (via :meth:`_fix_tool_calls`): ``tool_calls[*].args``
        delivered as JSON strings are decoded back to dicts.

    Diagnostic prints expose format inconsistencies across providers
    (e.g. DeepSeek via OpenRouter — see CHANGELOG 0.4.0/0.4.1 notes).
    """

    def __init__(self, model: Any, **kwargs):
        """
        Initialize wrapper around a chat model.

        Args:
            model: The chat model to wrap
            **kwargs: Additional parameters (ignored, for compatibility)
        """
        self.wrapped_model = model
        self._patch_model()

    # ------------------------------------------------------------------
    # Message-level normalization
    # ------------------------------------------------------------------
    def _fix_tool_calls(self, message: Any) -> Any:
        """
        Decode JSON-string ``args`` in a message's tool_calls to dicts.

        Args that are already dicts are left untouched; strings that are not
        valid JSON (or do not decode to a dict) are kept as-is rather than
        raising, so a malformed provider payload degrades gracefully.

        Args:
            message: An AIMessage-like object; mutated in place.

        Returns:
            The same message object, with args normalized where possible.
        """
        tool_calls = getattr(message, "tool_calls", None)
        if not tool_calls:
            return message
        for tool_call in tool_calls:
            if not isinstance(tool_call, dict):
                continue
            args = tool_call.get("args")
            if isinstance(args, str):
                try:
                    parsed = json.loads(args)
                except (ValueError, TypeError):
                    continue  # invalid JSON: leave the original string in place
                if isinstance(parsed, dict):
                    tool_call["args"] = parsed
        return message

    # ------------------------------------------------------------------
    # Response-level normalization (raw provider payloads)
    # ------------------------------------------------------------------
    def _patch_model(self):
        """Monkey-patch the model's _create_chat_result to add diagnostics.

        Models without that hook (e.g. MockChatModel) are left untouched.
        """
        if not hasattr(self.wrapped_model, '_create_chat_result'):
            # Model doesn't have this method (e.g., MockChatModel), skip patching
            return

        original_create_chat_result = self.wrapped_model._create_chat_result

        @wraps(original_create_chat_result)
        def patched_create_chat_result(response: Any, generation_info: Optional[Dict] = None):
            """Patched version with diagnostic logging and args parsing."""
            response_dict = response if isinstance(response, dict) else response.model_dump()
            self._log_response_diagnostics(response_dict)
            self._normalize_response_dict(response_dict)
            # Call original method with fixed response
            return original_create_chat_result(response_dict, generation_info)

        # Replace the method
        self.wrapped_model._create_chat_result = patched_create_chat_result

    def _log_response_diagnostics(self, response_dict: Dict) -> None:
        """Print the raw response structure to identify provider format drift."""
        print(f"\n[DIAGNOSTIC] Response structure:")
        print(f"  Response keys: {list(response_dict.keys())}")

        choices = response_dict.get('choices') or []
        if not choices:
            return
        choice = choices[0]
        print(f"  Choice keys: {list(choice.keys())}")

        message = choice.get('message')
        if not isinstance(message, dict):
            return
        print(f"  Message keys: {list(message.keys())}")

        if 'tool_calls' not in message:
            return
        tool_calls_value = message['tool_calls']
        print(f"  message['tool_calls'] type: {type(tool_calls_value)}")
        if not tool_calls_value:
            return

        print(f"  tool_calls count: {len(tool_calls_value)}")
        for i, tc in enumerate(tool_calls_value):  # Show ALL
            print(f"  tool_calls[{i}] type: {type(tc)}")
            print(f"  tool_calls[{i}] keys: {list(tc.keys()) if isinstance(tc, dict) else 'N/A'}")
            if not isinstance(tc, dict):
                continue
            if 'function' in tc:
                print(f"    function keys: {list(tc['function'].keys())}")
                if 'arguments' in tc['function']:
                    args = tc['function']['arguments']
                    print(f"    function.arguments type: {type(args).__name__}")
                    print(f"    function.arguments value: {str(args)[:100]}")
            if 'args' in tc:
                print(f"    ALSO HAS 'args' KEY: type={type(tc['args']).__name__}")
                print(f"    args value: {str(tc['args'])[:100]}")

    def _normalize_response_dict(self, response_dict: Dict) -> None:
        """Rewrite non-standard tool_calls in-place to the OpenAI format.

        Converts ``{name, args, id}`` → ``{function: {name, arguments}, id}``
        and serializes dict-valued ``invalid_tool_calls[*].args`` to strings.
        """
        for choice in response_dict.get('choices', []):
            message = choice.get('message')
            if not isinstance(message, dict):
                continue

            tool_calls = message.get('tool_calls')
            if tool_calls:
                print(f"[DIAGNOSTIC] Processing {len(tool_calls)} tool_calls...")
                for idx, tool_call in enumerate(tool_calls):
                    # Non-standard format: 'args' present directly, no 'function'
                    if 'args' in tool_call and 'function' not in tool_call:
                        print(f"[DIAGNOSTIC] tool_calls[{idx}] has non-standard format (direct args)")
                        args = tool_call.pop('args')
                        tool_call['function'] = {
                            'name': tool_call.pop('name', ''),
                            # OpenAI expects arguments as a JSON string
                            'arguments': args if isinstance(args, str) else json.dumps(args)
                        }
                        print(f"[DIAGNOSTIC] Converted tool_calls[{idx}] to standard OpenAI format")

            invalid_calls = message.get('invalid_tool_calls')
            if invalid_calls:
                print(f"[DIAGNOSTIC] Checking invalid_tool_calls for dict-to-string conversion...")
                for idx, invalid_call in enumerate(invalid_calls):
                    args = invalid_call.get('args')
                    if isinstance(args, dict):
                        try:
                            invalid_call['args'] = json.dumps(args)
                            print(f"[DIAGNOSTIC] Converted invalid_tool_calls[{idx}].args from dict to string")
                        except (TypeError, ValueError) as e:
                            print(f"[DIAGNOSTIC] Failed to serialize invalid_tool_calls[{idx}].args: {e}")
                            # Keep as-is if serialization fails

    # ------------------------------------------------------------------
    # Generation entry points: delegate, then normalize resulting messages
    # ------------------------------------------------------------------
    def _generate(self, *args, **kwargs):
        """Delegate to the wrapped model and fix tool_calls on each generation."""
        result = self.wrapped_model._generate(*args, **kwargs)
        for generation in result.generations:
            self._fix_tool_calls(generation.message)
        return result

    async def _agenerate(self, *args, **kwargs):
        """Async variant of :meth:`_generate`."""
        result = await self.wrapped_model._agenerate(*args, **kwargs)
        for generation in result.generations:
            self._fix_tool_calls(generation.message)
        return result

    def invoke(self, *args, **kwargs):
        """Invoke the wrapped model and normalize the returned message."""
        return self._fix_tool_calls(self.wrapped_model.invoke(*args, **kwargs))

    async def ainvoke(self, *args, **kwargs):
        """Async variant of :meth:`invoke`."""
        return self._fix_tool_calls(await self.wrapped_model.ainvoke(*args, **kwargs))

    # ------------------------------------------------------------------
    # LangChain plumbing
    # ------------------------------------------------------------------
    @property
    def _llm_type(self) -> str:
        """Return identifier for this LLM type."""
        if hasattr(self.wrapped_model, '_llm_type'):
            return f"wrapped-{self.wrapped_model._llm_type}"
        return "wrapped-chat-model"

    def __getattr__(self, name: str):
        """Proxy all other attributes/methods to the wrapped model."""
        if name == "wrapped_model":
            # Guard against infinite recursion if called before __init__
            # completes (e.g. during copy/unpickle).
            raise AttributeError(name)
        return getattr(self.wrapped_model, name)

    def bind_tools(self, tools: Any, **kwargs):
        """Bind tools to the wrapped model, keeping the wrapper in place.

        FIX: previously returned the raw bound model, silently dropping the
        wrapper (and its message-level normalization) after tool binding.
        """
        return ToolCallArgsParsingWrapper(model=self.wrapped_model.bind_tools(tools, **kwargs))

    def bind(self, **kwargs):
        """Bind settings to the wrapped model, keeping the wrapper in place."""
        return ToolCallArgsParsingWrapper(model=self.wrapped_model.bind(**kwargs))
|
||||
@@ -80,7 +80,7 @@ class RuntimeConfigManager:
|
||||
initial_config = {
|
||||
"TODAY_DATE": date,
|
||||
"SIGNATURE": model_sig,
|
||||
"IF_TRADE": False,
|
||||
"IF_TRADE": True, # FIX: Trades are expected by default
|
||||
"JOB_ID": job_id,
|
||||
"TRADING_DAY_ID": trading_day_id
|
||||
}
|
||||
|
||||
216
tests/unit/test_chat_model_wrapper.py
Normal file
216
tests/unit/test_chat_model_wrapper.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Unit tests for ToolCallArgsParsingWrapper - tool_calls args parsing fix
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.outputs import ChatResult, ChatGeneration
|
||||
|
||||
from agent.chat_model_wrapper import ToolCallArgsParsingWrapper
|
||||
|
||||
|
||||
class TestToolCallArgsParsingWrapper:
    """Behavioral tests for ToolCallArgsParsingWrapper.

    Covers JSON-string -> dict normalization of tool_call args, pass-through
    of already-valid payloads, and wrapper preservation across bind()/bind_tools().
    """

    @staticmethod
    def _tool_call_message(name, args, call_id):
        """Build an AIMessage carrying exactly one tool call."""
        return AIMessage(
            content="",
            tool_calls=[{"name": name, "args": args, "id": call_id}],
        )

    @pytest.fixture
    def mock_model(self):
        """A stand-in chat model."""
        stub = Mock()
        stub._llm_type = "mock-model"
        return stub

    @pytest.fixture
    def wrapper(self, mock_model):
        """The wrapper under test, around the stub model."""
        return ToolCallArgsParsingWrapper(model=mock_model)

    def test_fix_tool_calls_with_string_args(self, wrapper):
        """JSON-string args must be decoded into a dict."""
        msg = self._tool_call_message("buy", '{"symbol": "AAPL", "amount": 10}', "call_123")

        fixed = wrapper._fix_tool_calls(msg)

        assert isinstance(fixed.tool_calls[0]["args"], dict)
        assert fixed.tool_calls[0]["args"] == {"symbol": "AAPL", "amount": 10}

    def test_fix_tool_calls_with_dict_args(self, wrapper):
        """Args that are already dicts pass through untouched."""
        msg = self._tool_call_message("buy", {"symbol": "AAPL", "amount": 10}, "call_123")

        fixed = wrapper._fix_tool_calls(msg)

        assert isinstance(fixed.tool_calls[0]["args"], dict)
        assert fixed.tool_calls[0]["args"] == {"symbol": "AAPL", "amount": 10}

    def test_fix_tool_calls_with_invalid_json(self, wrapper):
        """Undecodable strings are preserved rather than raising."""
        msg = self._tool_call_message("buy", 'invalid json {', "call_123")

        fixed = wrapper._fix_tool_calls(msg)

        assert isinstance(fixed.tool_calls[0]["args"], str)
        assert fixed.tool_calls[0]["args"] == 'invalid json {'

    def test_fix_tool_calls_no_tool_calls(self, wrapper):
        """A message without tool calls comes back unchanged."""
        msg = AIMessage(content="Hello, world!")

        assert wrapper._fix_tool_calls(msg) == msg

    def test_generate_with_string_args(self, wrapper, mock_model):
        """_generate output messages get their string args decoded."""
        raw = self._tool_call_message("buy", '{"symbol": "MSFT", "amount": 5}', "call_456")
        mock_model._generate.return_value = ChatResult(
            generations=[ChatGeneration(message=raw)]
        )

        outcome = wrapper._generate(messages=[], stop=None, run_manager=None)

        fixed = outcome.generations[0].message
        assert isinstance(fixed.tool_calls[0]["args"], dict)
        assert fixed.tool_calls[0]["args"] == {"symbol": "MSFT", "amount": 5}

    @pytest.mark.asyncio
    async def test_agenerate_with_string_args(self, wrapper, mock_model):
        """_agenerate output messages get their string args decoded."""
        raw = self._tool_call_message("sell", '{"symbol": "GOOGL", "amount": 3}', "call_789")
        mock_model._agenerate = AsyncMock(
            return_value=ChatResult(generations=[ChatGeneration(message=raw)])
        )

        outcome = await wrapper._agenerate(messages=[], stop=None, run_manager=None)

        fixed = outcome.generations[0].message
        assert isinstance(fixed.tool_calls[0]["args"], dict)
        assert fixed.tool_calls[0]["args"] == {"symbol": "GOOGL", "amount": 3}

    def test_invoke_with_string_args(self, wrapper, mock_model):
        """invoke() normalizes the returned message's tool_call args."""
        mock_model.invoke.return_value = self._tool_call_message(
            "buy", '{"symbol": "NVDA", "amount": 20}', "call_999"
        )

        outcome = wrapper.invoke(input=[])

        assert isinstance(outcome.tool_calls[0]["args"], dict)
        assert outcome.tool_calls[0]["args"] == {"symbol": "NVDA", "amount": 20}

    @pytest.mark.asyncio
    async def test_ainvoke_with_string_args(self, wrapper, mock_model):
        """ainvoke() normalizes the returned message's tool_call args."""
        mock_model.ainvoke = AsyncMock(
            return_value=self._tool_call_message(
                "sell", '{"symbol": "TSLA", "amount": 15}', "call_111"
            )
        )

        outcome = await wrapper.ainvoke(input=[])

        assert isinstance(outcome.tool_calls[0]["args"], dict)
        assert outcome.tool_calls[0]["args"] == {"symbol": "TSLA", "amount": 15}

    def test_bind_tools_returns_wrapper(self, wrapper, mock_model):
        """bind_tools() must re-wrap the bound model."""
        bound = Mock()
        mock_model.bind_tools.return_value = bound

        outcome = wrapper.bind_tools(tools=[], strict=True)

        assert isinstance(outcome, ToolCallArgsParsingWrapper)
        assert outcome.wrapped_model == bound

    def test_bind_returns_wrapper(self, wrapper, mock_model):
        """bind() must re-wrap the bound model."""
        bound = Mock()
        mock_model.bind.return_value = bound

        outcome = wrapper.bind(max_tokens=100)

        assert isinstance(outcome, ToolCallArgsParsingWrapper)
        assert outcome.wrapped_model == bound
|
||||
@@ -63,7 +63,7 @@ class TestRuntimeConfigCreation:
|
||||
|
||||
assert config["TODAY_DATE"] == "2025-01-16"
|
||||
assert config["SIGNATURE"] == "gpt-5"
|
||||
assert config["IF_TRADE"] is False
|
||||
assert config["IF_TRADE"] is True
|
||||
assert config["JOB_ID"] == "test-job-123"
|
||||
|
||||
def test_create_runtime_config_unique_paths(self):
|
||||
@@ -108,6 +108,32 @@ class TestRuntimeConfigCreation:
|
||||
# Config file should exist
|
||||
assert os.path.exists(config_path)
|
||||
|
||||
def test_create_runtime_config_if_trade_defaults_true(self):
|
||||
"""Test that IF_TRADE initializes to True (trades expected by default)"""
|
||||
from api.runtime_manager import RuntimeConfigManager
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
manager = RuntimeConfigManager(data_dir=temp_dir)
|
||||
|
||||
config_path = manager.create_runtime_config(
|
||||
job_id="test-job-123",
|
||||
model_sig="test-model",
|
||||
date="2025-01-16",
|
||||
trading_day_id=1
|
||||
)
|
||||
|
||||
try:
|
||||
# Read the config file
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
# Verify IF_TRADE is True by default
|
||||
assert config["IF_TRADE"] is True, "IF_TRADE should initialize to True"
|
||||
finally:
|
||||
# Cleanup
|
||||
if os.path.exists(config_path):
|
||||
os.remove(config_path)
|
||||
|
||||
|
||||
@pytest.mark.unit
|
||||
class TestRuntimeConfigCleanup:
|
||||
|
||||
Reference in New Issue
Block a user