diff --git a/src/grist_mcp/main.py b/src/grist_mcp/main.py index e4f0c1b..8d545ad 100644 --- a/src/grist_mcp/main.py +++ b/src/grist_mcp/main.py @@ -2,19 +2,22 @@ import os import sys +from typing import Any import uvicorn from mcp.server.sse import SseServerTransport -from starlette.applications import Starlette -from starlette.responses import JSONResponse -from starlette.routing import Route from grist_mcp.server import create_server from grist_mcp.auth import AuthError -def create_app() -> Starlette: - """Create the Starlette ASGI application.""" +Scope = dict[str, Any] +Receive = Any +Send = Any + + +def create_app(): + """Create the ASGI application.""" config_path = os.environ.get("CONFIG_PATH", "/app/config.yaml") if not os.path.exists(config_path): @@ -29,27 +32,54 @@ def create_app() -> Starlette: sse = SseServerTransport("/messages") - async def handle_sse(request): - async with sse.connect_sse( - request.scope, request.receive, request._send - ) as streams: + async def handle_sse(scope: Scope, receive: Receive, send: Send) -> None: + async with sse.connect_sse(scope, receive, send) as streams: await server.run( streams[0], streams[1], server.create_initialization_options() ) - async def handle_messages(request): - await sse.handle_post_message(request.scope, request.receive, request._send) + async def handle_messages(scope: Scope, receive: Receive, send: Send) -> None: + await sse.handle_post_message(scope, receive, send) - async def handle_health(request): - return JSONResponse({"status": "ok"}) + async def handle_health(scope: Scope, receive: Receive, send: Send) -> None: + await send({ + "type": "http.response.start", + "status": 200, + "headers": [[b"content-type", b"application/json"]], + }) + await send({ + "type": "http.response.body", + "body": b'{"status":"ok"}', + }) - return Starlette( - routes=[ - Route("/health", endpoint=handle_health), - Route("/sse", endpoint=handle_sse), - Route("/messages", endpoint=handle_messages, methods=["POST"]), - ] - ) + async def handle_not_found(scope: Scope, receive: Receive, send: Send) -> None: + await send({ + "type": "http.response.start", + "status": 404, + "headers": [[b"content-type", b"application/json"]], + }) + await send({ + "type": "http.response.body", + "body": b'{"error":"Not found"}', + }) + + async def app(scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] != "http": + return + + path = scope["path"] + method = scope["method"] + + if path == "/health" and method == "GET": + await handle_health(scope, receive, send) + elif path == "/sse" and method == "GET": + await handle_sse(scope, receive, send) + elif path == "/messages" and method == "POST": + await handle_messages(scope, receive, send) + else: + await handle_not_found(scope, receive, send) + + return app def main(): diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 973f42b..5560e55 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -4,8 +4,6 @@ import time import httpx import pytest -from mcp import ClientSession -from mcp.client.sse import sse_client GRIST_MCP_URL = "http://localhost:3000" @@ -35,26 +33,3 @@ def services_ready(): if not wait_for_service(GRIST_MCP_URL): pytest.fail(f"grist-mcp server not ready at {GRIST_MCP_URL}") return True - - -@pytest.fixture -async def mcp_client(services_ready): - """Create an MCP client connected to grist-mcp via SSE.""" - async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - await session.initialize() - yield session - - -@pytest.fixture -def mock_grist_client(services_ready): - """HTTP client for interacting with mock Grist test endpoints.""" - with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client: - yield client - - -@pytest.fixture(autouse=True) -def clear_mock_grist_log(mock_grist_client): - """Clear the mock Grist request log before each test.""" - mock_grist_client.post("/_test/requests/clear") - yield diff --git a/tests/integration/test_mcp_protocol.py b/tests/integration/test_mcp_protocol.py index 327374a..0471b90 100644 --- a/tests/integration/test_mcp_protocol.py +++ b/tests/integration/test_mcp_protocol.py @@ -1,57 +1,61 @@ """Test MCP protocol compliance over SSE transport.""" +from contextlib import asynccontextmanager + import pytest +from mcp import ClientSession +from mcp.client.sse import sse_client + + +GRIST_MCP_URL = "http://localhost:3000" + + +@asynccontextmanager +async def create_mcp_session(): + """Create and yield an MCP session.""" + async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + yield session @pytest.mark.asyncio -async def test_mcp_connection_initializes(mcp_client): - """Test that MCP client can connect and initialize.""" - # If we get here, connection and initialization succeeded - assert mcp_client is not None +async def test_mcp_protocol_compliance(services_ready): + """Test MCP protocol compliance - connection, tools, descriptions, schemas.""" + async with create_mcp_session() as client: + # Test 1: Connection initializes + assert client is not None + # Test 2: list_tools returns all expected tools + result = await client.list_tools() + tool_names = [tool.name for tool in result.tools] -@pytest.mark.asyncio -async def test_list_tools_returns_all_tools(mcp_client): - """Test that list_tools returns all expected tools.""" - result = await mcp_client.list_tools() - tool_names = [tool.name for tool in result.tools] + expected_tools = [ + "list_documents", + "list_tables", + "describe_table", + "get_records", + "sql_query", + "add_records", + "update_records", + "delete_records", + "create_table", + "add_column", + "modify_column", + "delete_column", + ] - expected_tools = [ - "list_documents", - "list_tables", - "describe_table", - "get_records", - "sql_query", - "add_records", - "update_records", - "delete_records", - "create_table", - "add_column", - "modify_column", - "delete_column", - ] + for expected in expected_tools: + assert expected in tool_names, f"Missing tool: {expected}" - for expected in expected_tools: - assert expected in tool_names, f"Missing tool: {expected}" + assert len(result.tools) == 12, f"Expected 12 tools, got {len(result.tools)}" - assert len(result.tools) == 12 + # Test 3: All tools have descriptions + for tool in result.tools: + assert tool.description, f"Tool {tool.name} has no description" + assert len(tool.description) > 10, f"Tool {tool.name} description too short" - -@pytest.mark.asyncio -async def test_list_tools_has_descriptions(mcp_client): - """Test that all tools have descriptions.""" - result = await mcp_client.list_tools() - - for tool in result.tools: - assert tool.description, f"Tool {tool.name} has no description" - assert len(tool.description) > 10, f"Tool {tool.name} description too short" - - -@pytest.mark.asyncio -async def test_list_tools_has_input_schemas(mcp_client): - """Test that all tools have input schemas.""" - result = await mcp_client.list_tools() - - for tool in result.tools: - assert tool.inputSchema is not None, f"Tool {tool.name} has no inputSchema" - assert "type" in tool.inputSchema, f"Tool {tool.name} schema missing type" + # Test 4: All tools have input schemas + for tool in result.tools: + assert tool.inputSchema is not None, f"Tool {tool.name} has no inputSchema" + assert "type" in tool.inputSchema, f"Tool {tool.name} schema missing type" diff --git a/tests/integration/test_tools_integration.py b/tests/integration/test_tools_integration.py index 26e9290..2f3e6ad 100644 --- a/tests/integration/test_tools_integration.py +++ b/tests/integration/test_tools_integration.py @@ -1,252 +1,222 @@ """Test tool calls through MCP client to verify Grist API interactions.""" import json +from contextlib import asynccontextmanager +import httpx import pytest +from mcp import ClientSession +from mcp.client.sse import sse_client + + +GRIST_MCP_URL = "http://localhost:3000" +MOCK_GRIST_URL = "http://localhost:8484" + + +@asynccontextmanager +async def create_mcp_session(): + """Create and yield an MCP session.""" + async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + await session.initialize() + yield session + + +def get_mock_request_log(): + """Get the request log from mock Grist server.""" + with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client: + return client.get("/_test/requests").json() + + +def clear_mock_request_log(): + """Clear the mock Grist request log.""" + with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client: + client.post("/_test/requests/clear") @pytest.mark.asyncio -async def test_list_documents(mcp_client): - """Test list_documents returns accessible documents.""" - result = await mcp_client.call_tool("list_documents", {}) +async def test_all_tools(services_ready): + """Test all MCP tools - reads, writes, schema ops, and auth errors.""" + async with create_mcp_session() as client: + # ===== READ TOOLS ===== - assert len(result.content) == 1 - data = json.loads(result.content[0].text) + # Test list_documents + clear_mock_request_log() + result = await client.call_tool("list_documents", {}) + assert len(result.content) == 1 + data = json.loads(result.content[0].text) + assert "documents" in data + assert len(data["documents"]) == 1 + assert data["documents"][0]["name"] == "test-doc" + assert "read" in data["documents"][0]["permissions"] - assert "documents" in data - assert len(data["documents"]) == 1 - assert data["documents"][0]["name"] == "test-doc" - assert "read" in data["documents"][0]["permissions"] + # Test list_tables + clear_mock_request_log() + result = await client.call_tool("list_tables", {"document": "test-doc"}) + data = json.loads(result.content[0].text) + assert "tables" in data + assert "People" in data["tables"] + assert "Tasks" in data["tables"] + log = get_mock_request_log() + assert any("/tables" in entry["path"] for entry in log) + # Test describe_table + clear_mock_request_log() + result = await client.call_tool( + "describe_table", + {"document": "test-doc", "table": "People"} + ) + data = json.loads(result.content[0].text) + assert "columns" in data + column_ids = [c["id"] for c in data["columns"]] + assert "Name" in column_ids + assert "Age" in column_ids + log = get_mock_request_log() + assert any("/columns" in entry["path"] for entry in log) -@pytest.mark.asyncio -async def test_list_tables(mcp_client, mock_grist_client): - """Test list_tables calls correct Grist API endpoint.""" - result = await mcp_client.call_tool("list_tables", {"document": "test-doc"}) + # Test get_records + clear_mock_request_log() + result = await client.call_tool( + "get_records", + {"document": "test-doc", "table": "People"} + ) + data = json.loads(result.content[0].text) + assert "records" in data + assert len(data["records"]) == 2 + assert data["records"][0]["Name"] == "Alice" + log = get_mock_request_log() + assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log) - # Check response - data = json.loads(result.content[0].text) - assert "tables" in data - assert "People" in data["tables"] - assert "Tasks" in data["tables"] + # Test sql_query + clear_mock_request_log() + result = await client.call_tool( + "sql_query", + {"document": "test-doc", "query": "SELECT Name, Age FROM People"} + ) + data = json.loads(result.content[0].text) + assert "records" in data + assert len(data["records"]) >= 1 + log = get_mock_request_log() + assert any("/sql" in entry["path"] for entry in log) - # Verify mock received correct request - log = mock_grist_client.get("/_test/requests").json() - assert len(log) >= 1 - assert log[-1]["method"] == "GET" - assert "/tables" in log[-1]["path"] + # ===== WRITE TOOLS ===== + # Test add_records + clear_mock_request_log() + new_records = [ + {"Name": "Charlie", "Age": 35, "Email": "charlie@example.com"} + ] + result = await client.call_tool( + "add_records", + {"document": "test-doc", "table": "People", "records": new_records} + ) + data = json.loads(result.content[0].text) + assert "inserted_ids" in data + assert len(data["inserted_ids"]) == 1 + log = get_mock_request_log() + post_requests = [e for e in log if e["method"] == "POST" and "/records" in e["path"]] + assert len(post_requests) >= 1 + assert post_requests[-1]["body"]["records"][0]["fields"]["Name"] == "Charlie" -@pytest.mark.asyncio -async def test_describe_table(mcp_client, mock_grist_client): - """Test describe_table returns column information.""" - result = await mcp_client.call_tool( - "describe_table", - {"document": "test-doc", "table": "People"} - ) + # Test update_records + clear_mock_request_log() + updates = [{"id": 1, "fields": {"Age": 31}}] + result = await client.call_tool( + "update_records", + {"document": "test-doc", "table": "People", "records": updates} + ) + data = json.loads(result.content[0].text) + assert "updated" in data + log = get_mock_request_log() + patch_requests = [e for e in log if e["method"] == "PATCH" and "/records" in e["path"]] + assert len(patch_requests) >= 1 - data = json.loads(result.content[0].text) - assert "columns" in data + # Test delete_records + clear_mock_request_log() + result = await client.call_tool( + "delete_records", + {"document": "test-doc", "table": "People", "record_ids": [1, 2]} + ) + data = json.loads(result.content[0].text) + assert "deleted" in data + log = get_mock_request_log() + delete_requests = [e for e in log if "/data/delete" in e["path"]] + assert len(delete_requests) >= 1 + assert delete_requests[-1]["body"] == [1, 2] - column_ids = [c["id"] for c in data["columns"]] - assert "Name" in column_ids - assert "Age" in column_ids + # ===== SCHEMA TOOLS ===== - # Verify API call - log = mock_grist_client.get("/_test/requests").json() - assert any("/columns" in entry["path"] for entry in log) + # Test create_table + clear_mock_request_log() + columns = [ + {"id": "Title", "type": "Text"}, + {"id": "Count", "type": "Int"}, + ] + result = await client.call_tool( + "create_table", + {"document": "test-doc", "table_id": "NewTable", "columns": columns} + ) + data = json.loads(result.content[0].text) + assert "table_id" in data + log = get_mock_request_log() + post_tables = [e for e in log if e["method"] == "POST" and e["path"].endswith("/tables")] + assert len(post_tables) >= 1 + # Test add_column + clear_mock_request_log() + result = await client.call_tool( + "add_column", + { + "document": "test-doc", + "table": "People", + "column_id": "Phone", + "column_type": "Text", + } + ) + data = json.loads(result.content[0].text) + assert "column_id" in data + log = get_mock_request_log() + post_cols = [e for e in log if e["method"] == "POST" and "/columns" in e["path"]] + assert len(post_cols) >= 1 -@pytest.mark.asyncio -async def test_get_records(mcp_client, mock_grist_client): - """Test get_records fetches records from table.""" - result = await mcp_client.call_tool( - "get_records", - {"document": "test-doc", "table": "People"} - ) + # Test modify_column + clear_mock_request_log() + result = await client.call_tool( + "modify_column", + { + "document": "test-doc", + "table": "People", + "column_id": "Age", + "type": "Numeric", + } + ) + data = json.loads(result.content[0].text) + assert "modified" in data + log = get_mock_request_log() + patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]] + assert len(patch_cols) >= 1 - data = json.loads(result.content[0].text) - assert "records" in data - assert len(data["records"]) == 2 - assert data["records"][0]["Name"] == "Alice" + # Test delete_column + clear_mock_request_log() + result = await client.call_tool( + "delete_column", + { + "document": "test-doc", + "table": "People", + "column_id": "Email", + } + ) + data = json.loads(result.content[0].text) + assert "deleted" in data + log = get_mock_request_log() + delete_cols = [e for e in log if e["method"] == "DELETE" and "/columns/" in e["path"]] + assert len(delete_cols) >= 1 - # Verify API call - log = mock_grist_client.get("/_test/requests").json() - assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log) + # ===== AUTHORIZATION ===== - -@pytest.mark.asyncio -async def test_sql_query(mcp_client, mock_grist_client): - """Test sql_query executes SQL and returns results.""" - result = await mcp_client.call_tool( - "sql_query", - {"document": "test-doc", "query": "SELECT Name, Age FROM People"} - ) - - data = json.loads(result.content[0].text) - assert "records" in data - assert len(data["records"]) >= 1 - - # Verify API call - log = mock_grist_client.get("/_test/requests").json() - assert any("/sql" in entry["path"] for entry in log) - - -@pytest.mark.asyncio -async def test_add_records(mcp_client, mock_grist_client): - """Test add_records sends correct payload to Grist.""" - new_records = [ - {"Name": "Charlie", "Age": 35, "Email": "charlie@example.com"} - ] - - result = await mcp_client.call_tool( - "add_records", - {"document": "test-doc", "table": "People", "records": new_records} - ) - - data = json.loads(result.content[0].text) - assert "record_ids" in data - assert len(data["record_ids"]) == 1 - - # Verify API call body - log = mock_grist_client.get("/_test/requests").json() - post_requests = [e for e in log if e["method"] == "POST" and "/records" in e["path"]] - assert len(post_requests) >= 1 - assert post_requests[-1]["body"]["records"][0]["fields"]["Name"] == "Charlie" - - -@pytest.mark.asyncio -async def test_update_records(mcp_client, mock_grist_client): - """Test update_records sends correct payload to Grist.""" - updates = [ - {"id": 1, "fields": {"Age": 31}} - ] - - result = await mcp_client.call_tool( - "update_records", - {"document": "test-doc", "table": "People", "records": updates} - ) - - data = json.loads(result.content[0].text) - assert "updated" in data - - # Verify API call - log = mock_grist_client.get("/_test/requests").json() - patch_requests = [e for e in log if e["method"] == "PATCH" and "/records" in e["path"]] - assert len(patch_requests) >= 1 - - -@pytest.mark.asyncio -async def test_delete_records(mcp_client, mock_grist_client): - """Test delete_records sends correct IDs to Grist.""" - result = await mcp_client.call_tool( - "delete_records", - {"document": "test-doc", "table": "People", "record_ids": [1, 2]} - ) - - data = json.loads(result.content[0].text) - assert "deleted" in data - - # Verify API call - log = mock_grist_client.get("/_test/requests").json() - delete_requests = [e for e in log if "/data/delete" in e["path"]] - assert len(delete_requests) >= 1 - assert delete_requests[-1]["body"] == [1, 2] - - -@pytest.mark.asyncio -async def test_create_table(mcp_client, mock_grist_client): - """Test create_table sends correct schema to Grist.""" - columns = [ - {"id": "Title", "type": "Text"}, - {"id": "Count", "type": "Int"}, - ] - - result = await mcp_client.call_tool( - "create_table", - {"document": "test-doc", "table_id": "NewTable", "columns": columns} - ) - - data = json.loads(result.content[0].text) - assert "table_id" in data - - # Verify API call - log = mock_grist_client.get("/_test/requests").json() - post_tables = [e for e in log if e["method"] == "POST" and e["path"].endswith("/tables")] - assert len(post_tables) >= 1 - - -@pytest.mark.asyncio -async def test_add_column(mcp_client, mock_grist_client): - """Test add_column sends correct column definition.""" - result = await mcp_client.call_tool( - "add_column", - { - "document": "test-doc", - "table": "People", - "column_id": "Phone", - "column_type": "Text", - } - ) - - data = json.loads(result.content[0].text) - assert "column_id" in data - - # Verify API call - log = mock_grist_client.get("/_test/requests").json() - post_cols = [e for e in log if e["method"] == "POST" and "/columns" in e["path"]] - assert len(post_cols) >= 1 - - -@pytest.mark.asyncio -async def test_modify_column(mcp_client, mock_grist_client): - """Test modify_column sends correct update.""" - result = await mcp_client.call_tool( - "modify_column", - { - "document": "test-doc", - "table": "People", - "column_id": "Age", - "type": "Numeric", - } - ) - - data = json.loads(result.content[0].text) - assert "modified" in data - - # Verify API call - log = mock_grist_client.get("/_test/requests").json() - patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]] - assert len(patch_cols) >= 1 - - -@pytest.mark.asyncio -async def test_delete_column(mcp_client, mock_grist_client): - """Test delete_column calls correct endpoint.""" - result = await mcp_client.call_tool( - "delete_column", - { - "document": "test-doc", - "table": "People", - "column_id": "Email", - } - ) - - data = json.loads(result.content[0].text) - assert "deleted" in data - - # Verify API call - log = mock_grist_client.get("/_test/requests").json() - delete_cols = [e for e in log if e["method"] == "DELETE" and "/columns/" in e["path"]] - assert len(delete_cols) >= 1 - - -@pytest.mark.asyncio -async def test_unauthorized_document_fails(mcp_client): - """Test that accessing unauthorized document returns error.""" - result = await mcp_client.call_tool( - "list_tables", - {"document": "unauthorized-doc"} - ) - - assert "error" in result.content[0].text.lower() or "authorization" in result.content[0].text.lower() + # Test unauthorized document fails + result = await client.call_tool( + "list_tables", + {"document": "unauthorized-doc"} + ) + assert "error" in result.content[0].text.lower() or "authorization" in result.content[0].text.lower()