diff --git a/pyproject.toml b/pyproject.toml index 84aa6df..0b2e74a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,3 +28,6 @@ build-backend = "hatchling.build" [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests/unit", "tests/integration"] +markers = [ + "integration: marks tests as integration tests (require Docker containers)", +] diff --git a/tests/integration/mock_grist/server.py b/tests/integration/mock_grist/server.py index 7731f14..75dee91 100644 --- a/tests/integration/mock_grist/server.py +++ b/tests/integration/mock_grist/server.py @@ -178,6 +178,15 @@ async def modify_column(request): return JSONResponse({}) +async def modify_columns(request): + """PATCH /api/docs/{doc_id}/tables/{table_id}/columns - batch modify columns""" + doc_id = request.path_params["doc_id"] + table_id = request.path_params["table_id"] + body = await request.json() + log_request("PATCH", f"/api/docs/{doc_id}/tables/{table_id}/columns", body) + return JSONResponse({}) + + async def delete_column(request): """DELETE /api/docs/{doc_id}/tables/{table_id}/columns/{col_id}""" doc_id = request.path_params["doc_id"] @@ -199,6 +208,7 @@ app = Starlette( Route("/api/docs/{doc_id}/tables", endpoint=create_tables, methods=["POST"]), Route("/api/docs/{doc_id}/tables/{table_id}/columns", endpoint=get_table_columns), Route("/api/docs/{doc_id}/tables/{table_id}/columns", endpoint=add_column, methods=["POST"]), + Route("/api/docs/{doc_id}/tables/{table_id}/columns", endpoint=modify_columns, methods=["PATCH"]), Route("/api/docs/{doc_id}/tables/{table_id}/columns/{col_id}", endpoint=modify_column, methods=["PATCH"]), Route("/api/docs/{doc_id}/tables/{table_id}/columns/{col_id}", endpoint=delete_column, methods=["DELETE"]), Route("/api/docs/{doc_id}/tables/{table_id}/records", endpoint=get_records), diff --git a/tests/integration/test_mcp_protocol.py b/tests/integration/test_mcp_protocol.py index 0e0d237..a59878d 100644 --- a/tests/integration/test_mcp_protocol.py +++ b/tests/integration/test_mcp_protocol.py @@ -9,12 +9,14 @@ from mcp.client.sse import sse_client GRIST_MCP_URL = os.environ.get("GRIST_MCP_URL", "http://localhost:3000") +GRIST_MCP_TOKEN = os.environ.get("GRIST_MCP_TOKEN", "test-token") @asynccontextmanager async def create_mcp_session(): """Create and yield an MCP session.""" - async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream): + headers = {"Authorization": f"Bearer {GRIST_MCP_TOKEN}"} + async with sse_client(f"{GRIST_MCP_URL}/sse", headers=headers) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: await session.initialize() yield session @@ -44,12 +46,14 @@ async def test_mcp_protocol_compliance(services_ready): "add_column", "modify_column", "delete_column", + "get_proxy_documentation", + "request_session_token", ] for expected in expected_tools: assert expected in tool_names, f"Missing tool: {expected}" - assert len(result.tools) == 12, f"Expected 12 tools, got {len(result.tools)}" + assert len(result.tools) == 14, f"Expected 14 tools, got {len(result.tools)}" # Test 3: All tools have descriptions for tool in result.tools: diff --git a/tests/integration/test_tools_integration.py b/tests/integration/test_tools_integration.py index e7d7ff2..022217b 100644 --- a/tests/integration/test_tools_integration.py +++ b/tests/integration/test_tools_integration.py @@ -12,12 +12,14 @@ from mcp.client.sse import sse_client GRIST_MCP_URL = os.environ.get("GRIST_MCP_URL", "http://localhost:3000") MOCK_GRIST_URL = os.environ.get("MOCK_GRIST_URL", "http://localhost:8484") +GRIST_MCP_TOKEN = os.environ.get("GRIST_MCP_TOKEN", "test-token") @asynccontextmanager async def create_mcp_session(): """Create and yield an MCP session.""" - async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream): + headers = {"Authorization": f"Bearer {GRIST_MCP_TOKEN}"} + async with sse_client(f"{GRIST_MCP_URL}/sse", headers=headers) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: await session.initialize() yield session @@ -194,7 +196,7 @@ async def test_all_tools(services_ready): data = json.loads(result.content[0].text) assert "modified" in data log = get_mock_request_log() - patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]] + patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns" in e["path"]] assert len(patch_cols) >= 1 # Test delete_column