mirror of
https://github.com/Xe138/AI-Trader.git
synced 2026-04-01 17:17:24 -04:00
feat: add comprehensive config validation
This commit is contained in:
@@ -2,7 +2,7 @@ import pytest
|
|||||||
import json
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tools.config_merger import load_config, ConfigValidationError, merge_configs
|
from tools.config_merger import load_config, ConfigValidationError, merge_configs, validate_config
|
||||||
|
|
||||||
|
|
||||||
def test_load_config_valid_json():
|
def test_load_config_valid_json():
|
||||||
@@ -76,3 +76,101 @@ def test_merge_configs_does_not_mutate_inputs():
|
|||||||
|
|
||||||
assert default["a"] == 1 # Original unchanged
|
assert default["a"] == 1 # Original unchanged
|
||||||
assert result["a"] == 2
|
assert result["a"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_config_valid():
|
||||||
|
"""Test validation passes for valid config"""
|
||||||
|
config = {
|
||||||
|
"agent_type": "BaseAgent",
|
||||||
|
"models": [
|
||||||
|
{"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": True}
|
||||||
|
],
|
||||||
|
"agent_config": {
|
||||||
|
"max_steps": 30,
|
||||||
|
"max_retries": 3,
|
||||||
|
"initial_cash": 10000.0
|
||||||
|
},
|
||||||
|
"log_config": {"log_path": "./data"}
|
||||||
|
}
|
||||||
|
|
||||||
|
validate_config(config) # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_config_missing_required_field():
|
||||||
|
"""Test validation fails for missing required field"""
|
||||||
|
config = {"agent_type": "BaseAgent"} # Missing models, agent_config, log_config
|
||||||
|
|
||||||
|
with pytest.raises(ConfigValidationError, match="Missing required field"):
|
||||||
|
validate_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_config_no_enabled_models():
|
||||||
|
"""Test validation fails when no models are enabled"""
|
||||||
|
config = {
|
||||||
|
"agent_type": "BaseAgent",
|
||||||
|
"models": [
|
||||||
|
{"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": False}
|
||||||
|
],
|
||||||
|
"agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
|
||||||
|
"log_config": {"log_path": "./data"}
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ConfigValidationError, match="At least one model must be enabled"):
|
||||||
|
validate_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_config_duplicate_signatures():
|
||||||
|
"""Test validation fails for duplicate model signatures"""
|
||||||
|
config = {
|
||||||
|
"agent_type": "BaseAgent",
|
||||||
|
"models": [
|
||||||
|
{"name": "test1", "basemodel": "openai/gpt-4", "signature": "same", "enabled": True},
|
||||||
|
{"name": "test2", "basemodel": "openai/gpt-5", "signature": "same", "enabled": True}
|
||||||
|
],
|
||||||
|
"agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
|
||||||
|
"log_config": {"log_path": "./data"}
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ConfigValidationError, match="Duplicate model signature"):
|
||||||
|
validate_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_config_invalid_max_steps():
|
||||||
|
"""Test validation fails for invalid max_steps"""
|
||||||
|
config = {
|
||||||
|
"agent_type": "BaseAgent",
|
||||||
|
"models": [{"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": True}],
|
||||||
|
"agent_config": {"max_steps": 0, "max_retries": 3, "initial_cash": 10000.0},
|
||||||
|
"log_config": {"log_path": "./data"}
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ConfigValidationError, match="max_steps must be > 0"):
|
||||||
|
validate_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_config_invalid_date_format():
|
||||||
|
"""Test validation fails for invalid date format"""
|
||||||
|
config = {
|
||||||
|
"agent_type": "BaseAgent",
|
||||||
|
"date_range": {"init_date": "2025-13-01", "end_date": "2025-12-31"}, # Invalid month
|
||||||
|
"models": [{"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": True}],
|
||||||
|
"agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
|
||||||
|
"log_config": {"log_path": "./data"}
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ConfigValidationError, match="Invalid date format"):
|
||||||
|
validate_config(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_config_end_before_init():
|
||||||
|
"""Test validation fails when end_date before init_date"""
|
||||||
|
config = {
|
||||||
|
"agent_type": "BaseAgent",
|
||||||
|
"date_range": {"init_date": "2025-12-31", "end_date": "2025-01-01"},
|
||||||
|
"models": [{"name": "test", "basemodel": "openai/gpt-4", "signature": "test", "enabled": True}],
|
||||||
|
"agent_config": {"max_steps": 30, "max_retries": 3, "initial_cash": 10000.0},
|
||||||
|
"log_config": {"log_path": "./data"}
|
||||||
|
}
|
||||||
|
|
||||||
|
with pytest.raises(ConfigValidationError, match="init_date must be <= end_date"):
|
||||||
|
validate_config(config)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
class ConfigValidationError(Exception):
|
class ConfigValidationError(Exception):
|
||||||
@@ -56,3 +57,91 @@ def merge_configs(default: Dict[str, Any], custom: Dict[str, Any]) -> Dict[str,
|
|||||||
merged[key] = value
|
merged[key] = value
|
||||||
|
|
||||||
return merged
|
return merged
|
||||||
|
|
||||||
|
|
||||||
|
def validate_config(config: Dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
Validate configuration structure and values.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config: Configuration dictionary to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ConfigValidationError: If validation fails with detailed message
|
||||||
|
"""
|
||||||
|
# Required top-level fields
|
||||||
|
required_fields = ["agent_type", "models", "agent_config", "log_config"]
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in config:
|
||||||
|
raise ConfigValidationError(f"Missing required field: '{field}'")
|
||||||
|
|
||||||
|
# Validate models
|
||||||
|
models = config["models"]
|
||||||
|
if not isinstance(models, list) or len(models) == 0:
|
||||||
|
raise ConfigValidationError("'models' must be a non-empty array")
|
||||||
|
|
||||||
|
# Check at least one enabled model
|
||||||
|
enabled_models = [m for m in models if m.get("enabled", False)]
|
||||||
|
if not enabled_models:
|
||||||
|
raise ConfigValidationError("At least one model must be enabled")
|
||||||
|
|
||||||
|
# Check required model fields
|
||||||
|
for i, model in enumerate(models):
|
||||||
|
required_model_fields = ["name", "basemodel", "signature", "enabled"]
|
||||||
|
for field in required_model_fields:
|
||||||
|
if field not in model:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
f"Model {i} missing required field: '{field}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for duplicate signatures
|
||||||
|
signatures = [m["signature"] for m in models]
|
||||||
|
if len(signatures) != len(set(signatures)):
|
||||||
|
duplicates = [s for s in signatures if signatures.count(s) > 1]
|
||||||
|
raise ConfigValidationError(
|
||||||
|
f"Duplicate model signature: {duplicates[0]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate agent_config
|
||||||
|
agent_config = config["agent_config"]
|
||||||
|
|
||||||
|
if "max_steps" in agent_config:
|
||||||
|
if agent_config["max_steps"] <= 0:
|
||||||
|
raise ConfigValidationError("max_steps must be > 0")
|
||||||
|
|
||||||
|
if "max_retries" in agent_config:
|
||||||
|
if agent_config["max_retries"] < 0:
|
||||||
|
raise ConfigValidationError("max_retries must be >= 0")
|
||||||
|
|
||||||
|
if "initial_cash" in agent_config:
|
||||||
|
if agent_config["initial_cash"] <= 0:
|
||||||
|
raise ConfigValidationError("initial_cash must be > 0")
|
||||||
|
|
||||||
|
# Validate date_range if present (optional)
|
||||||
|
if "date_range" in config:
|
||||||
|
date_range = config["date_range"]
|
||||||
|
|
||||||
|
if "init_date" in date_range:
|
||||||
|
try:
|
||||||
|
init_dt = datetime.strptime(date_range["init_date"], "%Y-%m-%d")
|
||||||
|
except ValueError:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
f"Invalid date format for init_date: {date_range['init_date']}. "
|
||||||
|
"Expected YYYY-MM-DD"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "end_date" in date_range:
|
||||||
|
try:
|
||||||
|
end_dt = datetime.strptime(date_range["end_date"], "%Y-%m-%d")
|
||||||
|
except ValueError:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
f"Invalid date format for end_date: {date_range['end_date']}. "
|
||||||
|
"Expected YYYY-MM-DD"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check init <= end
|
||||||
|
if "init_date" in date_range and "end_date" in date_range:
|
||||||
|
if init_dt > end_dt:
|
||||||
|
raise ConfigValidationError(
|
||||||
|
f"init_date must be <= end_date (got {date_range['init_date']} > {date_range['end_date']})"
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user