fix: use pure ASGI app for SSE transport compatibility

- Replace Starlette routing with direct ASGI dispatcher to avoid
  double-response issues with SSE transport
- Simplify integration test fixtures by removing async client fixture
- Consolidate integration tests into single test functions per file
  to prevent SSE connection cleanup issues between tests
- Fix add_records assertion to expect 'inserted_ids' (actual API response)
This commit is contained in:
2025-12-30 15:05:32 -05:00
parent 987b6d087a
commit c57e71b92a
4 changed files with 297 additions and 318 deletions

View File

@@ -4,8 +4,6 @@ import time
import httpx
import pytest
from mcp import ClientSession
from mcp.client.sse import sse_client
GRIST_MCP_URL = "http://localhost:3000"
@@ -35,26 +33,3 @@ def services_ready():
if not wait_for_service(GRIST_MCP_URL):
pytest.fail(f"grist-mcp server not ready at {GRIST_MCP_URL}")
return True
@pytest.fixture
async def mcp_client(services_ready):
"""Create an MCP client connected to grist-mcp via SSE."""
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
@pytest.fixture
def mock_grist_client(services_ready):
"""HTTP client for interacting with mock Grist test endpoints."""
with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client:
yield client
@pytest.fixture(autouse=True)
def clear_mock_grist_log(mock_grist_client):
"""Clear the mock Grist request log before each test."""
mock_grist_client.post("/_test/requests/clear")
yield

View File

@@ -1,57 +1,61 @@
"""Test MCP protocol compliance over SSE transport."""
from contextlib import asynccontextmanager
import pytest
from mcp import ClientSession
from mcp.client.sse import sse_client
GRIST_MCP_URL = "http://localhost:3000"
@asynccontextmanager
async def create_mcp_session():
"""Create and yield an MCP session."""
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
@pytest.mark.asyncio
async def test_mcp_connection_initializes(mcp_client):
"""Test that MCP client can connect and initialize."""
# If we get here, connection and initialization succeeded
assert mcp_client is not None
async def test_mcp_protocol_compliance(services_ready):
"""Test MCP protocol compliance - connection, tools, descriptions, schemas."""
async with create_mcp_session() as client:
# Test 1: Connection initializes
assert client is not None
# Test 2: list_tools returns all expected tools
result = await client.list_tools()
tool_names = [tool.name for tool in result.tools]
@pytest.mark.asyncio
async def test_list_tools_returns_all_tools(mcp_client):
"""Test that list_tools returns all expected tools."""
result = await mcp_client.list_tools()
tool_names = [tool.name for tool in result.tools]
expected_tools = [
"list_documents",
"list_tables",
"describe_table",
"get_records",
"sql_query",
"add_records",
"update_records",
"delete_records",
"create_table",
"add_column",
"modify_column",
"delete_column",
]
expected_tools = [
"list_documents",
"list_tables",
"describe_table",
"get_records",
"sql_query",
"add_records",
"update_records",
"delete_records",
"create_table",
"add_column",
"modify_column",
"delete_column",
]
for expected in expected_tools:
assert expected in tool_names, f"Missing tool: {expected}"
for expected in expected_tools:
assert expected in tool_names, f"Missing tool: {expected}"
assert len(result.tools) == 12, f"Expected 12 tools, got {len(result.tools)}"
assert len(result.tools) == 12
# Test 3: All tools have descriptions
for tool in result.tools:
assert tool.description, f"Tool {tool.name} has no description"
assert len(tool.description) > 10, f"Tool {tool.name} description too short"
@pytest.mark.asyncio
async def test_list_tools_has_descriptions(mcp_client):
"""Test that all tools have descriptions."""
result = await mcp_client.list_tools()
for tool in result.tools:
assert tool.description, f"Tool {tool.name} has no description"
assert len(tool.description) > 10, f"Tool {tool.name} description too short"
@pytest.mark.asyncio
async def test_list_tools_has_input_schemas(mcp_client):
"""Test that all tools have input schemas."""
result = await mcp_client.list_tools()
for tool in result.tools:
assert tool.inputSchema is not None, f"Tool {tool.name} has no inputSchema"
assert "type" in tool.inputSchema, f"Tool {tool.name} schema missing type"
# Test 4: All tools have input schemas
for tool in result.tools:
assert tool.inputSchema is not None, f"Tool {tool.name} has no inputSchema"
assert "type" in tool.inputSchema, f"Tool {tool.name} schema missing type"

View File

@@ -1,252 +1,222 @@
"""Test tool calls through MCP client to verify Grist API interactions."""
import json
from contextlib import asynccontextmanager
import httpx
import pytest
from mcp import ClientSession
from mcp.client.sse import sse_client
GRIST_MCP_URL = "http://localhost:3000"
MOCK_GRIST_URL = "http://localhost:8484"
@asynccontextmanager
async def create_mcp_session():
"""Create and yield an MCP session."""
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
async with ClientSession(read_stream, write_stream) as session:
await session.initialize()
yield session
def get_mock_request_log():
"""Get the request log from mock Grist server."""
with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client:
return client.get("/_test/requests").json()
def clear_mock_request_log():
"""Clear the mock Grist request log."""
with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client:
client.post("/_test/requests/clear")
@pytest.mark.asyncio
async def test_list_documents(mcp_client):
"""Test list_documents returns accessible documents."""
result = await mcp_client.call_tool("list_documents", {})
async def test_all_tools(services_ready):
"""Test all MCP tools - reads, writes, schema ops, and auth errors."""
async with create_mcp_session() as client:
# ===== READ TOOLS =====
assert len(result.content) == 1
data = json.loads(result.content[0].text)
# Test list_documents
clear_mock_request_log()
result = await client.call_tool("list_documents", {})
assert len(result.content) == 1
data = json.loads(result.content[0].text)
assert "documents" in data
assert len(data["documents"]) == 1
assert data["documents"][0]["name"] == "test-doc"
assert "read" in data["documents"][0]["permissions"]
assert "documents" in data
assert len(data["documents"]) == 1
assert data["documents"][0]["name"] == "test-doc"
assert "read" in data["documents"][0]["permissions"]
# Test list_tables
clear_mock_request_log()
result = await client.call_tool("list_tables", {"document": "test-doc"})
data = json.loads(result.content[0].text)
assert "tables" in data
assert "People" in data["tables"]
assert "Tasks" in data["tables"]
log = get_mock_request_log()
assert any("/tables" in entry["path"] for entry in log)
# Test describe_table
clear_mock_request_log()
result = await client.call_tool(
"describe_table",
{"document": "test-doc", "table": "People"}
)
data = json.loads(result.content[0].text)
assert "columns" in data
column_ids = [c["id"] for c in data["columns"]]
assert "Name" in column_ids
assert "Age" in column_ids
log = get_mock_request_log()
assert any("/columns" in entry["path"] for entry in log)
@pytest.mark.asyncio
async def test_list_tables(mcp_client, mock_grist_client):
"""Test list_tables calls correct Grist API endpoint."""
result = await mcp_client.call_tool("list_tables", {"document": "test-doc"})
# Test get_records
clear_mock_request_log()
result = await client.call_tool(
"get_records",
{"document": "test-doc", "table": "People"}
)
data = json.loads(result.content[0].text)
assert "records" in data
assert len(data["records"]) == 2
assert data["records"][0]["Name"] == "Alice"
log = get_mock_request_log()
assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log)
# Check response
data = json.loads(result.content[0].text)
assert "tables" in data
assert "People" in data["tables"]
assert "Tasks" in data["tables"]
# Test sql_query
clear_mock_request_log()
result = await client.call_tool(
"sql_query",
{"document": "test-doc", "query": "SELECT Name, Age FROM People"}
)
data = json.loads(result.content[0].text)
assert "records" in data
assert len(data["records"]) >= 1
log = get_mock_request_log()
assert any("/sql" in entry["path"] for entry in log)
# Verify mock received correct request
log = mock_grist_client.get("/_test/requests").json()
assert len(log) >= 1
assert log[-1]["method"] == "GET"
assert "/tables" in log[-1]["path"]
# ===== WRITE TOOLS =====
# Test add_records
clear_mock_request_log()
new_records = [
{"Name": "Charlie", "Age": 35, "Email": "charlie@example.com"}
]
result = await client.call_tool(
"add_records",
{"document": "test-doc", "table": "People", "records": new_records}
)
data = json.loads(result.content[0].text)
assert "inserted_ids" in data
assert len(data["inserted_ids"]) == 1
log = get_mock_request_log()
post_requests = [e for e in log if e["method"] == "POST" and "/records" in e["path"]]
assert len(post_requests) >= 1
assert post_requests[-1]["body"]["records"][0]["fields"]["Name"] == "Charlie"
@pytest.mark.asyncio
async def test_describe_table(mcp_client, mock_grist_client):
"""Test describe_table returns column information."""
result = await mcp_client.call_tool(
"describe_table",
{"document": "test-doc", "table": "People"}
)
# Test update_records
clear_mock_request_log()
updates = [{"id": 1, "fields": {"Age": 31}}]
result = await client.call_tool(
"update_records",
{"document": "test-doc", "table": "People", "records": updates}
)
data = json.loads(result.content[0].text)
assert "updated" in data
log = get_mock_request_log()
patch_requests = [e for e in log if e["method"] == "PATCH" and "/records" in e["path"]]
assert len(patch_requests) >= 1
data = json.loads(result.content[0].text)
assert "columns" in data
# Test delete_records
clear_mock_request_log()
result = await client.call_tool(
"delete_records",
{"document": "test-doc", "table": "People", "record_ids": [1, 2]}
)
data = json.loads(result.content[0].text)
assert "deleted" in data
log = get_mock_request_log()
delete_requests = [e for e in log if "/data/delete" in e["path"]]
assert len(delete_requests) >= 1
assert delete_requests[-1]["body"] == [1, 2]
column_ids = [c["id"] for c in data["columns"]]
assert "Name" in column_ids
assert "Age" in column_ids
# ===== SCHEMA TOOLS =====
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
assert any("/columns" in entry["path"] for entry in log)
# Test create_table
clear_mock_request_log()
columns = [
{"id": "Title", "type": "Text"},
{"id": "Count", "type": "Int"},
]
result = await client.call_tool(
"create_table",
{"document": "test-doc", "table_id": "NewTable", "columns": columns}
)
data = json.loads(result.content[0].text)
assert "table_id" in data
log = get_mock_request_log()
post_tables = [e for e in log if e["method"] == "POST" and e["path"].endswith("/tables")]
assert len(post_tables) >= 1
# Test add_column
clear_mock_request_log()
result = await client.call_tool(
"add_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Phone",
"column_type": "Text",
}
)
data = json.loads(result.content[0].text)
assert "column_id" in data
log = get_mock_request_log()
post_cols = [e for e in log if e["method"] == "POST" and "/columns" in e["path"]]
assert len(post_cols) >= 1
@pytest.mark.asyncio
async def test_get_records(mcp_client, mock_grist_client):
"""Test get_records fetches records from table."""
result = await mcp_client.call_tool(
"get_records",
{"document": "test-doc", "table": "People"}
)
# Test modify_column
clear_mock_request_log()
result = await client.call_tool(
"modify_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Age",
"type": "Numeric",
}
)
data = json.loads(result.content[0].text)
assert "modified" in data
log = get_mock_request_log()
patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]]
assert len(patch_cols) >= 1
data = json.loads(result.content[0].text)
assert "records" in data
assert len(data["records"]) == 2
assert data["records"][0]["Name"] == "Alice"
# Test delete_column
clear_mock_request_log()
result = await client.call_tool(
"delete_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Email",
}
)
data = json.loads(result.content[0].text)
assert "deleted" in data
log = get_mock_request_log()
delete_cols = [e for e in log if e["method"] == "DELETE" and "/columns/" in e["path"]]
assert len(delete_cols) >= 1
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log)
# ===== AUTHORIZATION =====
@pytest.mark.asyncio
async def test_sql_query(mcp_client, mock_grist_client):
"""Test sql_query executes SQL and returns results."""
result = await mcp_client.call_tool(
"sql_query",
{"document": "test-doc", "query": "SELECT Name, Age FROM People"}
)
data = json.loads(result.content[0].text)
assert "records" in data
assert len(data["records"]) >= 1
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
assert any("/sql" in entry["path"] for entry in log)
@pytest.mark.asyncio
async def test_add_records(mcp_client, mock_grist_client):
"""Test add_records sends correct payload to Grist."""
new_records = [
{"Name": "Charlie", "Age": 35, "Email": "charlie@example.com"}
]
result = await mcp_client.call_tool(
"add_records",
{"document": "test-doc", "table": "People", "records": new_records}
)
data = json.loads(result.content[0].text)
assert "record_ids" in data
assert len(data["record_ids"]) == 1
# Verify API call body
log = mock_grist_client.get("/_test/requests").json()
post_requests = [e for e in log if e["method"] == "POST" and "/records" in e["path"]]
assert len(post_requests) >= 1
assert post_requests[-1]["body"]["records"][0]["fields"]["Name"] == "Charlie"
@pytest.mark.asyncio
async def test_update_records(mcp_client, mock_grist_client):
"""Test update_records sends correct payload to Grist."""
updates = [
{"id": 1, "fields": {"Age": 31}}
]
result = await mcp_client.call_tool(
"update_records",
{"document": "test-doc", "table": "People", "records": updates}
)
data = json.loads(result.content[0].text)
assert "updated" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
patch_requests = [e for e in log if e["method"] == "PATCH" and "/records" in e["path"]]
assert len(patch_requests) >= 1
@pytest.mark.asyncio
async def test_delete_records(mcp_client, mock_grist_client):
"""Test delete_records sends correct IDs to Grist."""
result = await mcp_client.call_tool(
"delete_records",
{"document": "test-doc", "table": "People", "record_ids": [1, 2]}
)
data = json.loads(result.content[0].text)
assert "deleted" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
delete_requests = [e for e in log if "/data/delete" in e["path"]]
assert len(delete_requests) >= 1
assert delete_requests[-1]["body"] == [1, 2]
@pytest.mark.asyncio
async def test_create_table(mcp_client, mock_grist_client):
"""Test create_table sends correct schema to Grist."""
columns = [
{"id": "Title", "type": "Text"},
{"id": "Count", "type": "Int"},
]
result = await mcp_client.call_tool(
"create_table",
{"document": "test-doc", "table_id": "NewTable", "columns": columns}
)
data = json.loads(result.content[0].text)
assert "table_id" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
post_tables = [e for e in log if e["method"] == "POST" and e["path"].endswith("/tables")]
assert len(post_tables) >= 1
@pytest.mark.asyncio
async def test_add_column(mcp_client, mock_grist_client):
"""Test add_column sends correct column definition."""
result = await mcp_client.call_tool(
"add_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Phone",
"column_type": "Text",
}
)
data = json.loads(result.content[0].text)
assert "column_id" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
post_cols = [e for e in log if e["method"] == "POST" and "/columns" in e["path"]]
assert len(post_cols) >= 1
@pytest.mark.asyncio
async def test_modify_column(mcp_client, mock_grist_client):
"""Test modify_column sends correct update."""
result = await mcp_client.call_tool(
"modify_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Age",
"type": "Numeric",
}
)
data = json.loads(result.content[0].text)
assert "modified" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]]
assert len(patch_cols) >= 1
@pytest.mark.asyncio
async def test_delete_column(mcp_client, mock_grist_client):
"""Test delete_column calls correct endpoint."""
result = await mcp_client.call_tool(
"delete_column",
{
"document": "test-doc",
"table": "People",
"column_id": "Email",
}
)
data = json.loads(result.content[0].text)
assert "deleted" in data
# Verify API call
log = mock_grist_client.get("/_test/requests").json()
delete_cols = [e for e in log if e["method"] == "DELETE" and "/columns/" in e["path"]]
assert len(delete_cols) >= 1
@pytest.mark.asyncio
async def test_unauthorized_document_fails(mcp_client):
"""Test that accessing unauthorized document returns error."""
result = await mcp_client.call_tool(
"list_tables",
{"document": "unauthorized-doc"}
)
assert "error" in result.content[0].text.lower() or "authorization" in result.content[0].text.lower()
# Test unauthorized document fails
result = await client.call_tool(
"list_tables",
{"document": "unauthorized-doc"}
)
assert "error" in result.content[0].text.lower() or "authorization" in result.content[0].text.lower()