fix: use pure ASGI app for SSE transport compatibility

- Replace Starlette routing with direct ASGI dispatcher to avoid
  double-response issues with SSE transport
- Simplify integration test fixtures by removing async client fixture
- Consolidate integration tests into single test functions per file
  to prevent SSE connection cleanup issues between tests
- Fix add_records assertion to expect 'inserted_ids' (actual API response)
This commit is contained in:
2025-12-30 15:05:32 -05:00
parent 987b6d087a
commit c57e71b92a
4 changed files with 297 additions and 318 deletions

View File

@@ -2,19 +2,22 @@
import os import os
import sys import sys
from typing import Any
import uvicorn import uvicorn
from mcp.server.sse import SseServerTransport from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from grist_mcp.server import create_server from grist_mcp.server import create_server
from grist_mcp.auth import AuthError from grist_mcp.auth import AuthError
def create_app() -> Starlette: Scope = dict[str, Any]
"""Create the Starlette ASGI application.""" Receive = Any
Send = Any
def create_app():
"""Create the ASGI application."""
config_path = os.environ.get("CONFIG_PATH", "/app/config.yaml") config_path = os.environ.get("CONFIG_PATH", "/app/config.yaml")
if not os.path.exists(config_path): if not os.path.exists(config_path):
@@ -29,27 +32,54 @@ def create_app() -> Starlette:
sse = SseServerTransport("/messages") sse = SseServerTransport("/messages")
async def handle_sse(request): async def handle_sse(scope: Scope, receive: Receive, send: Send) -> None:
async with sse.connect_sse( async with sse.connect_sse(scope, receive, send) as streams:
request.scope, request.receive, request._send
) as streams:
await server.run( await server.run(
streams[0], streams[1], server.create_initialization_options() streams[0], streams[1], server.create_initialization_options()
) )
async def handle_messages(request): async def handle_messages(scope: Scope, receive: Receive, send: Send) -> None:
await sse.handle_post_message(request.scope, request.receive, request._send) await sse.handle_post_message(scope, receive, send)
async def handle_health(request): async def handle_health(scope: Scope, receive: Receive, send: Send) -> None:
return JSONResponse({"status": "ok"}) await send({
"type": "http.response.start",
"status": 200,
"headers": [[b"content-type", b"application/json"]],
})
await send({
"type": "http.response.body",
"body": b'{"status":"ok"}',
})
return Starlette( async def handle_not_found(scope: Scope, receive: Receive, send: Send) -> None:
routes=[ await send({
Route("/health", endpoint=handle_health), "type": "http.response.start",
Route("/sse", endpoint=handle_sse), "status": 404,
Route("/messages", endpoint=handle_messages, methods=["POST"]), "headers": [[b"content-type", b"application/json"]],
] })
) await send({
"type": "http.response.body",
"body": b'{"error":"Not found"}',
})
async def app(scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
return
path = scope["path"]
method = scope["method"]
if path == "/health" and method == "GET":
await handle_health(scope, receive, send)
elif path == "/sse" and method == "GET":
await handle_sse(scope, receive, send)
elif path == "/messages" and method == "POST":
await handle_messages(scope, receive, send)
else:
await handle_not_found(scope, receive, send)
return app
def main(): def main():

View File

@@ -4,8 +4,6 @@ import time
import httpx import httpx
import pytest import pytest
from mcp import ClientSession
from mcp.client.sse import sse_client
GRIST_MCP_URL = "http://localhost:3000" GRIST_MCP_URL = "http://localhost:3000"
@@ -35,26 +33,3 @@ def services_ready():
if not wait_for_service(GRIST_MCP_URL): if not wait_for_service(GRIST_MCP_URL):
pytest.fail(f"grist-mcp server not ready at {GRIST_MCP_URL}") pytest.fail(f"grist-mcp server not ready at {GRIST_MCP_URL}")
return True return True
@pytest.fixture
async def mcp_client(services_ready):
"""Create an MCP client connected to grist-mcp via SSE."""
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
@pytest.fixture
def mock_grist_client(services_ready):
"""HTTP client for interacting with mock Grist test endpoints."""
with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client:
yield client
@pytest.fixture(autouse=True)
def clear_mock_grist_log(mock_grist_client):
"""Clear the mock Grist request log before each test."""
mock_grist_client.post("/_test/requests/clear")
yield

View File

@@ -1,57 +1,61 @@
"""Test MCP protocol compliance over SSE transport.""" """Test MCP protocol compliance over SSE transport."""
from contextlib import asynccontextmanager
import pytest import pytest
from mcp import ClientSession
from mcp.client.sse import sse_client
GRIST_MCP_URL = "http://localhost:3000"
@asynccontextmanager
async def create_mcp_session():
"""Create and yield an MCP session."""
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_mcp_connection_initializes(mcp_client): async def test_mcp_protocol_compliance(services_ready):
"""Test that MCP client can connect and initialize.""" """Test MCP protocol compliance - connection, tools, descriptions, schemas."""
# If we get here, connection and initialization succeeded async with create_mcp_session() as client:
assert mcp_client is not None # Test 1: Connection initializes
assert client is not None
# Test 2: list_tools returns all expected tools
result = await client.list_tools()
tool_names = [tool.name for tool in result.tools]
@pytest.mark.asyncio expected_tools = [
async def test_list_tools_returns_all_tools(mcp_client): "list_documents",
"""Test that list_tools returns all expected tools.""" "list_tables",
result = await mcp_client.list_tools() "describe_table",
tool_names = [tool.name for tool in result.tools] "get_records",
"sql_query",
"add_records",
"update_records",
"delete_records",
"create_table",
"add_column",
"modify_column",
"delete_column",
]
expected_tools = [ for expected in expected_tools:
"list_documents", assert expected in tool_names, f"Missing tool: {expected}"
"list_tables",
"describe_table",
"get_records",
"sql_query",
"add_records",
"update_records",
"delete_records",
"create_table",
"add_column",
"modify_column",
"delete_column",
]
for expected in expected_tools: assert len(result.tools) == 12, f"Expected 12 tools, got {len(result.tools)}"
assert expected in tool_names, f"Missing tool: {expected}"
assert len(result.tools) == 12 # Test 3: All tools have descriptions
for tool in result.tools:
assert tool.description, f"Tool {tool.name} has no description"
assert len(tool.description) > 10, f"Tool {tool.name} description too short"
# Test 4: All tools have input schemas
@pytest.mark.asyncio for tool in result.tools:
async def test_list_tools_has_descriptions(mcp_client): assert tool.inputSchema is not None, f"Tool {tool.name} has no inputSchema"
"""Test that all tools have descriptions.""" assert "type" in tool.inputSchema, f"Tool {tool.name} schema missing type"
result = await mcp_client.list_tools()
for tool in result.tools:
assert tool.description, f"Tool {tool.name} has no description"
assert len(tool.description) > 10, f"Tool {tool.name} description too short"
@pytest.mark.asyncio
async def test_list_tools_has_input_schemas(mcp_client):
"""Test that all tools have input schemas."""
result = await mcp_client.list_tools()
for tool in result.tools:
assert tool.inputSchema is not None, f"Tool {tool.name} has no inputSchema"
assert "type" in tool.inputSchema, f"Tool {tool.name} schema missing type"

View File

@@ -1,252 +1,222 @@
"""Test tool calls through MCP client to verify Grist API interactions.""" """Test tool calls through MCP client to verify Grist API interactions."""
import json import json
from contextlib import asynccontextmanager
import httpx
import pytest import pytest
from mcp import ClientSession
from mcp.client.sse import sse_client
GRIST_MCP_URL = "http://localhost:3000"
MOCK_GRIST_URL = "http://localhost:8484"
@asynccontextmanager
async def create_mcp_session():
"""Create and yield an MCP session."""
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
def get_mock_request_log():
"""Get the request log from mock Grist server."""
with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client:
return client.get("/_test/requests").json()
def clear_mock_request_log():
"""Clear the mock Grist request log."""
with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client:
client.post("/_test/requests/clear")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_list_documents(mcp_client): async def test_all_tools(services_ready):
"""Test list_documents returns accessible documents.""" """Test all MCP tools - reads, writes, schema ops, and auth errors."""
result = await mcp_client.call_tool("list_documents", {}) async with create_mcp_session() as client:
# ===== READ TOOLS =====
assert len(result.content) == 1 # Test list_documents
data = json.loads(result.content[0].text) clear_mock_request_log()
result = await client.call_tool("list_documents", {})
assert len(result.content) == 1
data = json.loads(result.content[0].text)
assert "documents" in data
assert len(data["documents"]) == 1
assert data["documents"][0]["name"] == "test-doc"
assert "read" in data["documents"][0]["permissions"]
assert "documents" in data # Test list_tables
assert len(data["documents"]) == 1 clear_mock_request_log()
assert data["documents"][0]["name"] == "test-doc" result = await client.call_tool("list_tables", {"document": "test-doc"})
assert "read" in data["documents"][0]["permissions"] data = json.loads(result.content[0].text)
assert "tables" in data
assert "People" in data["tables"]
assert "Tasks" in data["tables"]
log = get_mock_request_log()
assert any("/tables" in entry["path"] for entry in log)
# Test describe_table
clear_mock_request_log()
result = await client.call_tool(
"describe_table",
{"document": "test-doc", "table": "People"}
)
data = json.loads(result.content[0].text)
assert "columns" in data
column_ids = [c["id"] for c in data["columns"]]
assert "Name" in column_ids
assert "Age" in column_ids
log = get_mock_request_log()
assert any("/columns" in entry["path"] for entry in log)
@pytest.mark.asyncio # Test get_records
async def test_list_tables(mcp_client, mock_grist_client): clear_mock_request_log()
"""Test list_tables calls correct Grist API endpoint.""" result = await client.call_tool(
result = await mcp_client.call_tool("list_tables", {"document": "test-doc"}) "get_records",
{"document": "test-doc", "table": "People"}
)
data = json.loads(result.content[0].text)
assert "records" in data
assert len(data["records"]) == 2
assert data["records"][0]["Name"] == "Alice"
log = get_mock_request_log()
assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log)
# Check response # Test sql_query
data = json.loads(result.content[0].text) clear_mock_request_log()
assert "tables" in data result = await client.call_tool(
assert "People" in data["tables"] "sql_query",
assert "Tasks" in data["tables"] {"document": "test-doc", "query": "SELECT Name, Age FROM People"}
)
data = json.loads(result.content[0].text)
assert "records" in data
assert len(data["records"]) >= 1
log = get_mock_request_log()
assert any("/sql" in entry["path"] for entry in log)
# Verify mock received correct request # ===== WRITE TOOLS =====
log = mock_grist_client.get("/_test/requests").json()
assert len(log) >= 1
assert log[-1]["method"] == "GET"
assert "/tables" in log[-1]["path"]
# Test add_records
clear_mock_request_log()
new_records = [
{"Name": "Charlie", "Age": 35, "Email": "charlie@example.com"}
]
result = await client.call_tool(
"add_records",
{"document": "test-doc", "table": "People", "records": new_records}
)
data = json.loads(result.content[0].text)
assert "inserted_ids" in data
assert len(data["inserted_ids"]) == 1
log = get_mock_request_log()
post_requests = [e for e in log if e["method"] == "POST" and "/records" in e["path"]]
assert len(post_requests) >= 1
assert post_requests[-1]["body"]["records"][0]["fields"]["Name"] == "Charlie"
@pytest.mark.asyncio # Test update_records
async def test_describe_table(mcp_client, mock_grist_client): clear_mock_request_log()
"""Test describe_table returns column information.""" updates = [{"id": 1, "fields": {"Age": 31}}]
result = await mcp_client.call_tool( result = await client.call_tool(
"describe_table", "update_records",
{"document": "test-doc", "table": "People"} {"document": "test-doc", "table": "People", "records": updates}
) )
data = json.loads(result.content[0].text)
assert "updated" in data
log = get_mock_request_log()
patch_requests = [e for e in log if e["method"] == "PATCH" and "/records" in e["path"]]
assert len(patch_requests) >= 1
data = json.loads(result.content[0].text) # Test delete_records
assert "columns" in data clear_mock_request_log()
result = await client.call_tool(
"delete_records",
{"document": "test-doc", "table": "People", "record_ids": [1, 2]}
)
data = json.loads(result.content[0].text)
assert "deleted" in data
log = get_mock_request_log()
delete_requests = [e for e in log if "/data/delete" in e["path"]]
assert len(delete_requests) >= 1
assert delete_requests[-1]["body"] == [1, 2]
column_ids = [c["id"] for c in data["columns"]] # ===== SCHEMA TOOLS =====
assert "Name" in column_ids
assert "Age" in column_ids
# Verify API call # Test create_table
log = mock_grist_client.get("/_test/requests").json() clear_mock_request_log()
assert any("/columns" in entry["path"] for entry in log) columns = [
{"id": "Title", "type": "Text"},
{"id": "Count", "type": "Int"},
]
result = await client.call_tool(
"create_table",
{"document": "test-doc", "table_id": "NewTable", "columns": columns}
)
data = json.loads(result.content[0].text)
assert "table_id" in data
log = get_mock_request_log()
post_tables = [e for e in log if e["method"] == "POST" and e["path"].endswith("/tables")]
assert len(post_tables) >= 1
# Test add_column
clear_mock_request_log()
result = await client.call_tool(
"add_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Phone",
"column_type": "Text",
}
)
data = json.loads(result.content[0].text)
assert "column_id" in data
log = get_mock_request_log()
post_cols = [e for e in log if e["method"] == "POST" and "/columns" in e["path"]]
assert len(post_cols) >= 1
@pytest.mark.asyncio # Test modify_column
async def test_get_records(mcp_client, mock_grist_client): clear_mock_request_log()
"""Test get_records fetches records from table.""" result = await client.call_tool(
result = await mcp_client.call_tool( "modify_column",
"get_records", {
{"document": "test-doc", "table": "People"} "document": "test-doc",
) "table": "People",
"column_id": "Age",
"type": "Numeric",
}
)
data = json.loads(result.content[0].text)
assert "modified" in data
log = get_mock_request_log()
patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]]
assert len(patch_cols) >= 1
data = json.loads(result.content[0].text) # Test delete_column
assert "records" in data clear_mock_request_log()
assert len(data["records"]) == 2 result = await client.call_tool(
assert data["records"][0]["Name"] == "Alice" "delete_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Email",
}
)
data = json.loads(result.content[0].text)
assert "deleted" in data
log = get_mock_request_log()
delete_cols = [e for e in log if e["method"] == "DELETE" and "/columns/" in e["path"]]
assert len(delete_cols) >= 1
# Verify API call # ===== AUTHORIZATION =====
log = mock_grist_client.get("/_test/requests").json()
assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log)
# Test unauthorized document fails
@pytest.mark.asyncio result = await client.call_tool(
async def test_sql_query(mcp_client, mock_grist_client): "list_tables",
"""Test sql_query executes SQL and returns results.""" {"document": "unauthorized-doc"}
result = await mcp_client.call_tool( )
"sql_query", assert "error" in result.content[0].text.lower() or "authorization" in result.content[0].text.lower()
{"document": "test-doc", "query": "SELECT Name, Age FROM People"}
)
data = json.loads(result.content[0].text)
assert "records" in data
assert len(data["records"]) >= 1
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
assert any("/sql" in entry["path"] for entry in log)
@pytest.mark.asyncio
async def test_add_records(mcp_client, mock_grist_client):
"""Test add_records sends correct payload to Grist."""
new_records = [
{"Name": "Charlie", "Age": 35, "Email": "charlie@example.com"}
]
result = await mcp_client.call_tool(
"add_records",
{"document": "test-doc", "table": "People", "records": new_records}
)
data = json.loads(result.content[0].text)
assert "record_ids" in data
assert len(data["record_ids"]) == 1
# Verify API call body
log = mock_grist_client.get("/_test/requests").json()
post_requests = [e for e in log if e["method"] == "POST" and "/records" in e["path"]]
assert len(post_requests) >= 1
assert post_requests[-1]["body"]["records"][0]["fields"]["Name"] == "Charlie"
@pytest.mark.asyncio
async def test_update_records(mcp_client, mock_grist_client):
"""Test update_records sends correct payload to Grist."""
updates = [
{"id": 1, "fields": {"Age": 31}}
]
result = await mcp_client.call_tool(
"update_records",
{"document": "test-doc", "table": "People", "records": updates}
)
data = json.loads(result.content[0].text)
assert "updated" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
patch_requests = [e for e in log if e["method"] == "PATCH" and "/records" in e["path"]]
assert len(patch_requests) >= 1
@pytest.mark.asyncio
async def test_delete_records(mcp_client, mock_grist_client):
"""Test delete_records sends correct IDs to Grist."""
result = await mcp_client.call_tool(
"delete_records",
{"document": "test-doc", "table": "People", "record_ids": [1, 2]}
)
data = json.loads(result.content[0].text)
assert "deleted" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
delete_requests = [e for e in log if "/data/delete" in e["path"]]
assert len(delete_requests) >= 1
assert delete_requests[-1]["body"] == [1, 2]
@pytest.mark.asyncio
async def test_create_table(mcp_client, mock_grist_client):
"""Test create_table sends correct schema to Grist."""
columns = [
{"id": "Title", "type": "Text"},
{"id": "Count", "type": "Int"},
]
result = await mcp_client.call_tool(
"create_table",
{"document": "test-doc", "table_id": "NewTable", "columns": columns}
)
data = json.loads(result.content[0].text)
assert "table_id" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
post_tables = [e for e in log if e["method"] == "POST" and e["path"].endswith("/tables")]
assert len(post_tables) >= 1
@pytest.mark.asyncio
async def test_add_column(mcp_client, mock_grist_client):
"""Test add_column sends correct column definition."""
result = await mcp_client.call_tool(
"add_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Phone",
"column_type": "Text",
}
)
data = json.loads(result.content[0].text)
assert "column_id" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
post_cols = [e for e in log if e["method"] == "POST" and "/columns" in e["path"]]
assert len(post_cols) >= 1
@pytest.mark.asyncio
async def test_modify_column(mcp_client, mock_grist_client):
"""Test modify_column sends correct update."""
result = await mcp_client.call_tool(
"modify_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Age",
"type": "Numeric",
}
)
data = json.loads(result.content[0].text)
assert "modified" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]]
assert len(patch_cols) >= 1
@pytest.mark.asyncio
async def test_delete_column(mcp_client, mock_grist_client):
"""Test delete_column calls correct endpoint."""
result = await mcp_client.call_tool(
"delete_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Email",
}
)
data = json.loads(result.content[0].text)
assert "deleted" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
delete_cols = [e for e in log if e["method"] == "DELETE" and "/columns/" in e["path"]]
assert len(delete_cols) >= 1
@pytest.mark.asyncio
async def test_unauthorized_document_fails(mcp_client):
"""Test that accessing unauthorized document returns error."""
result = await mcp_client.call_tool(
"list_tables",
{"document": "unauthorized-doc"}
)
assert "error" in result.content[0].text.lower() or "authorization" in result.content[0].text.lower()