diff --git a/api/main.py b/api/main.py index e6123b2..bd45121 100644 --- a/api/main.py +++ b/api/main.py @@ -23,6 +23,7 @@ from api.simulation_worker import SimulationWorker from api.database import get_db_connection from api.price_data_manager import PriceDataManager from api.date_utils import validate_date_range, expand_date_range, get_max_simulation_days +from tools.deployment_config import get_deployment_mode_dict import threading import time @@ -62,6 +63,9 @@ class SimulateTriggerResponse(BaseModel): status: str total_model_days: int message: str + deployment_mode: str + is_dev_mode: bool + preserve_dev_data: Optional[bool] = None class JobProgress(BaseModel): @@ -85,6 +89,9 @@ class JobStatusResponse(BaseModel): total_duration_seconds: Optional[float] = None error: Optional[str] = None details: List[Dict[str, Any]] + deployment_mode: str + is_dev_mode: bool + preserve_dev_data: Optional[bool] = None class HealthResponse(BaseModel): @@ -92,6 +99,9 @@ class HealthResponse(BaseModel): status: str database: str timestamp: str + deployment_mode: str + is_dev_mode: bool + preserve_dev_data: Optional[bool] = None def create_app( @@ -263,11 +273,15 @@ def create_app( if download_info and download_info["rate_limited"]: message += " (rate limit reached - partial data)" + # Get deployment mode info + deployment_info = get_deployment_mode_dict() + response = SimulateTriggerResponse( job_id=job_id, status="pending", total_model_days=len(available_dates) * len(models_to_run), - message=message + message=message, + **deployment_info ) # Add download info if we downloaded @@ -317,6 +331,9 @@ def create_app( # Calculate pending (total - completed - failed) pending = progress["total_model_days"] - progress["completed"] - progress["failed"] + # Get deployment mode info + deployment_info = get_deployment_mode_dict() + return JobStatusResponse( job_id=job["job_id"], status=job["status"], @@ -333,7 +350,8 @@ def create_app( completed_at=job.get("completed_at"), total_duration_seconds=job.get("total_duration_seconds"), error=job.get("error"), - details=details + details=details, + **deployment_info ) except HTTPException: @@ -469,10 +487,14 @@ def create_app( logger.error(f"Database health check failed: {e}") database_status = "disconnected" + # Get deployment mode info + deployment_info = get_deployment_mode_dict() + return HealthResponse( status="healthy" if database_status == "connected" else "unhealthy", database=database_status, - timestamp=datetime.utcnow().isoformat() + "Z" + timestamp=datetime.utcnow().isoformat() + "Z", + **deployment_info ) return app diff --git a/tests/integration/test_api_deployment_flag.py b/tests/integration/test_api_deployment_flag.py new file mode 100644 index 0000000..2bfe917 --- /dev/null +++ b/tests/integration/test_api_deployment_flag.py @@ -0,0 +1,41 @@ +import os +import pytest +from fastapi.testclient import TestClient + + +def test_api_includes_deployment_mode_flag(): + """Test API responses include deployment_mode field""" + os.environ["DEPLOYMENT_MODE"] = "DEV" + + from api.main import app + client = TestClient(app) + + # Test GET /health endpoint (should include deployment info) + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + + assert "deployment_mode" in data + assert data["deployment_mode"] == "DEV" + + +def test_job_response_includes_deployment_mode(): + """Test job creation response includes deployment mode""" + os.environ["DEPLOYMENT_MODE"] = "PROD" + + from api.main import app + client = TestClient(app) + + # Create a test job + config = { + "agent_type": "BaseAgent", + "date_range": {"init_date": "2025-01-01", "end_date": "2025-01-02"}, + "models": [{"name": "test", "basemodel": "mock/test", "signature": "test", "enabled": True}] + } + + response = client.post("/run", json={"config": config}) + + if response.status_code == 200: + data = response.json() + assert "deployment_mode" in data + assert data["deployment_mode"] == "PROD"