fix: use pure ASGI app for SSE transport compatibility
- Replace Starlette routing with direct ASGI dispatcher to avoid double-response issues with SSE transport - Simplify integration test fixtures by removing async client fixture - Consolidate integration tests into single test functions per file to prevent SSE connection cleanup issues between tests - Fix add_records assertion to expect 'inserted_ids' (actual API response)
This commit is contained in:
@@ -2,19 +2,22 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from mcp.server.sse import SseServerTransport
|
from mcp.server.sse import SseServerTransport
|
||||||
from starlette.applications import Starlette
|
|
||||||
from starlette.responses import JSONResponse
|
|
||||||
from starlette.routing import Route
|
|
||||||
|
|
||||||
from grist_mcp.server import create_server
|
from grist_mcp.server import create_server
|
||||||
from grist_mcp.auth import AuthError
|
from grist_mcp.auth import AuthError
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> Starlette:
|
Scope = dict[str, Any]
|
||||||
"""Create the Starlette ASGI application."""
|
Receive = Any
|
||||||
|
Send = Any
|
||||||
|
|
||||||
|
|
||||||
|
def create_app():
|
||||||
|
"""Create the ASGI application."""
|
||||||
config_path = os.environ.get("CONFIG_PATH", "/app/config.yaml")
|
config_path = os.environ.get("CONFIG_PATH", "/app/config.yaml")
|
||||||
|
|
||||||
if not os.path.exists(config_path):
|
if not os.path.exists(config_path):
|
||||||
@@ -29,27 +32,54 @@ def create_app() -> Starlette:
|
|||||||
|
|
||||||
sse = SseServerTransport("/messages")
|
sse = SseServerTransport("/messages")
|
||||||
|
|
||||||
async def handle_sse(request):
|
async def handle_sse(scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
async with sse.connect_sse(
|
async with sse.connect_sse(scope, receive, send) as streams:
|
||||||
request.scope, request.receive, request._send
|
|
||||||
) as streams:
|
|
||||||
await server.run(
|
await server.run(
|
||||||
streams[0], streams[1], server.create_initialization_options()
|
streams[0], streams[1], server.create_initialization_options()
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_messages(request):
|
async def handle_messages(scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
await sse.handle_post_message(request.scope, request.receive, request._send)
|
await sse.handle_post_message(scope, receive, send)
|
||||||
|
|
||||||
async def handle_health(request):
|
async def handle_health(scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
return JSONResponse({"status": "ok"})
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": 200,
|
||||||
|
"headers": [[b"content-type", b"application/json"]],
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": b'{"status":"ok"}',
|
||||||
|
})
|
||||||
|
|
||||||
return Starlette(
|
async def handle_not_found(scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
routes=[
|
await send({
|
||||||
Route("/health", endpoint=handle_health),
|
"type": "http.response.start",
|
||||||
Route("/sse", endpoint=handle_sse),
|
"status": 404,
|
||||||
Route("/messages", endpoint=handle_messages, methods=["POST"]),
|
"headers": [[b"content-type", b"application/json"]],
|
||||||
]
|
})
|
||||||
)
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": b'{"error":"Not found"}',
|
||||||
|
})
|
||||||
|
|
||||||
|
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
if scope["type"] != "http":
|
||||||
|
return
|
||||||
|
|
||||||
|
path = scope["path"]
|
||||||
|
method = scope["method"]
|
||||||
|
|
||||||
|
if path == "/health" and method == "GET":
|
||||||
|
await handle_health(scope, receive, send)
|
||||||
|
elif path == "/sse" and method == "GET":
|
||||||
|
await handle_sse(scope, receive, send)
|
||||||
|
elif path == "/messages" and method == "POST":
|
||||||
|
await handle_messages(scope, receive, send)
|
||||||
|
else:
|
||||||
|
await handle_not_found(scope, receive, send)
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -4,8 +4,6 @@ import time
|
|||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
from mcp import ClientSession
|
|
||||||
from mcp.client.sse import sse_client
|
|
||||||
|
|
||||||
|
|
||||||
GRIST_MCP_URL = "http://localhost:3000"
|
GRIST_MCP_URL = "http://localhost:3000"
|
||||||
@@ -35,26 +33,3 @@ def services_ready():
|
|||||||
if not wait_for_service(GRIST_MCP_URL):
|
if not wait_for_service(GRIST_MCP_URL):
|
||||||
pytest.fail(f"grist-mcp server not ready at {GRIST_MCP_URL}")
|
pytest.fail(f"grist-mcp server not ready at {GRIST_MCP_URL}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def mcp_client(services_ready):
|
|
||||||
"""Create an MCP client connected to grist-mcp via SSE."""
|
|
||||||
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
|
|
||||||
async with ClientSession(read_stream, write_stream) as session:
|
|
||||||
await session.initialize()
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_grist_client(services_ready):
|
|
||||||
"""HTTP client for interacting with mock Grist test endpoints."""
|
|
||||||
with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client:
|
|
||||||
yield client
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def clear_mock_grist_log(mock_grist_client):
|
|
||||||
"""Clear the mock Grist request log before each test."""
|
|
||||||
mock_grist_client.post("/_test/requests/clear")
|
|
||||||
yield
|
|
||||||
|
|||||||
@@ -1,57 +1,61 @@
|
|||||||
"""Test MCP protocol compliance over SSE transport."""
|
"""Test MCP protocol compliance over SSE transport."""
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from mcp import ClientSession
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
|
||||||
|
|
||||||
|
GRIST_MCP_URL = "http://localhost:3000"
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_mcp_session():
|
||||||
|
"""Create and yield an MCP session."""
|
||||||
|
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
await session.initialize()
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_mcp_connection_initializes(mcp_client):
|
async def test_mcp_protocol_compliance(services_ready):
|
||||||
"""Test that MCP client can connect and initialize."""
|
"""Test MCP protocol compliance - connection, tools, descriptions, schemas."""
|
||||||
# If we get here, connection and initialization succeeded
|
async with create_mcp_session() as client:
|
||||||
assert mcp_client is not None
|
# Test 1: Connection initializes
|
||||||
|
assert client is not None
|
||||||
|
|
||||||
|
# Test 2: list_tools returns all expected tools
|
||||||
|
result = await client.list_tools()
|
||||||
|
tool_names = [tool.name for tool in result.tools]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
expected_tools = [
|
||||||
async def test_list_tools_returns_all_tools(mcp_client):
|
"list_documents",
|
||||||
"""Test that list_tools returns all expected tools."""
|
"list_tables",
|
||||||
result = await mcp_client.list_tools()
|
"describe_table",
|
||||||
tool_names = [tool.name for tool in result.tools]
|
"get_records",
|
||||||
|
"sql_query",
|
||||||
|
"add_records",
|
||||||
|
"update_records",
|
||||||
|
"delete_records",
|
||||||
|
"create_table",
|
||||||
|
"add_column",
|
||||||
|
"modify_column",
|
||||||
|
"delete_column",
|
||||||
|
]
|
||||||
|
|
||||||
expected_tools = [
|
for expected in expected_tools:
|
||||||
"list_documents",
|
assert expected in tool_names, f"Missing tool: {expected}"
|
||||||
"list_tables",
|
|
||||||
"describe_table",
|
|
||||||
"get_records",
|
|
||||||
"sql_query",
|
|
||||||
"add_records",
|
|
||||||
"update_records",
|
|
||||||
"delete_records",
|
|
||||||
"create_table",
|
|
||||||
"add_column",
|
|
||||||
"modify_column",
|
|
||||||
"delete_column",
|
|
||||||
]
|
|
||||||
|
|
||||||
for expected in expected_tools:
|
assert len(result.tools) == 12, f"Expected 12 tools, got {len(result.tools)}"
|
||||||
assert expected in tool_names, f"Missing tool: {expected}"
|
|
||||||
|
|
||||||
assert len(result.tools) == 12
|
# Test 3: All tools have descriptions
|
||||||
|
for tool in result.tools:
|
||||||
|
assert tool.description, f"Tool {tool.name} has no description"
|
||||||
|
assert len(tool.description) > 10, f"Tool {tool.name} description too short"
|
||||||
|
|
||||||
|
# Test 4: All tools have input schemas
|
||||||
@pytest.mark.asyncio
|
for tool in result.tools:
|
||||||
async def test_list_tools_has_descriptions(mcp_client):
|
assert tool.inputSchema is not None, f"Tool {tool.name} has no inputSchema"
|
||||||
"""Test that all tools have descriptions."""
|
assert "type" in tool.inputSchema, f"Tool {tool.name} schema missing type"
|
||||||
result = await mcp_client.list_tools()
|
|
||||||
|
|
||||||
for tool in result.tools:
|
|
||||||
assert tool.description, f"Tool {tool.name} has no description"
|
|
||||||
assert len(tool.description) > 10, f"Tool {tool.name} description too short"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools_has_input_schemas(mcp_client):
|
|
||||||
"""Test that all tools have input schemas."""
|
|
||||||
result = await mcp_client.list_tools()
|
|
||||||
|
|
||||||
for tool in result.tools:
|
|
||||||
assert tool.inputSchema is not None, f"Tool {tool.name} has no inputSchema"
|
|
||||||
assert "type" in tool.inputSchema, f"Tool {tool.name} schema missing type"
|
|
||||||
|
|||||||
@@ -1,252 +1,222 @@
|
|||||||
"""Test tool calls through MCP client to verify Grist API interactions."""
|
"""Test tool calls through MCP client to verify Grist API interactions."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
import httpx
|
||||||
import pytest
|
import pytest
|
||||||
|
from mcp import ClientSession
|
||||||
|
from mcp.client.sse import sse_client
|
||||||
|
|
||||||
|
|
||||||
|
GRIST_MCP_URL = "http://localhost:3000"
|
||||||
|
MOCK_GRIST_URL = "http://localhost:8484"
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def create_mcp_session():
|
||||||
|
"""Create and yield an MCP session."""
|
||||||
|
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
|
||||||
|
async with ClientSession(read_stream, write_stream) as session:
|
||||||
|
await session.initialize()
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
def get_mock_request_log():
|
||||||
|
"""Get the request log from mock Grist server."""
|
||||||
|
with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client:
|
||||||
|
return client.get("/_test/requests").json()
|
||||||
|
|
||||||
|
|
||||||
|
def clear_mock_request_log():
|
||||||
|
"""Clear the mock Grist request log."""
|
||||||
|
with httpx.Client(base_url=MOCK_GRIST_URL, timeout=10.0) as client:
|
||||||
|
client.post("/_test/requests/clear")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_documents(mcp_client):
|
async def test_all_tools(services_ready):
|
||||||
"""Test list_documents returns accessible documents."""
|
"""Test all MCP tools - reads, writes, schema ops, and auth errors."""
|
||||||
result = await mcp_client.call_tool("list_documents", {})
|
async with create_mcp_session() as client:
|
||||||
|
# ===== READ TOOLS =====
|
||||||
|
|
||||||
assert len(result.content) == 1
|
# Test list_documents
|
||||||
data = json.loads(result.content[0].text)
|
clear_mock_request_log()
|
||||||
|
result = await client.call_tool("list_documents", {})
|
||||||
|
assert len(result.content) == 1
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "documents" in data
|
||||||
|
assert len(data["documents"]) == 1
|
||||||
|
assert data["documents"][0]["name"] == "test-doc"
|
||||||
|
assert "read" in data["documents"][0]["permissions"]
|
||||||
|
|
||||||
assert "documents" in data
|
# Test list_tables
|
||||||
assert len(data["documents"]) == 1
|
clear_mock_request_log()
|
||||||
assert data["documents"][0]["name"] == "test-doc"
|
result = await client.call_tool("list_tables", {"document": "test-doc"})
|
||||||
assert "read" in data["documents"][0]["permissions"]
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "tables" in data
|
||||||
|
assert "People" in data["tables"]
|
||||||
|
assert "Tasks" in data["tables"]
|
||||||
|
log = get_mock_request_log()
|
||||||
|
assert any("/tables" in entry["path"] for entry in log)
|
||||||
|
|
||||||
|
# Test describe_table
|
||||||
|
clear_mock_request_log()
|
||||||
|
result = await client.call_tool(
|
||||||
|
"describe_table",
|
||||||
|
{"document": "test-doc", "table": "People"}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "columns" in data
|
||||||
|
column_ids = [c["id"] for c in data["columns"]]
|
||||||
|
assert "Name" in column_ids
|
||||||
|
assert "Age" in column_ids
|
||||||
|
log = get_mock_request_log()
|
||||||
|
assert any("/columns" in entry["path"] for entry in log)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
# Test get_records
|
||||||
async def test_list_tables(mcp_client, mock_grist_client):
|
clear_mock_request_log()
|
||||||
"""Test list_tables calls correct Grist API endpoint."""
|
result = await client.call_tool(
|
||||||
result = await mcp_client.call_tool("list_tables", {"document": "test-doc"})
|
"get_records",
|
||||||
|
{"document": "test-doc", "table": "People"}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "records" in data
|
||||||
|
assert len(data["records"]) == 2
|
||||||
|
assert data["records"][0]["Name"] == "Alice"
|
||||||
|
log = get_mock_request_log()
|
||||||
|
assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log)
|
||||||
|
|
||||||
# Check response
|
# Test sql_query
|
||||||
data = json.loads(result.content[0].text)
|
clear_mock_request_log()
|
||||||
assert "tables" in data
|
result = await client.call_tool(
|
||||||
assert "People" in data["tables"]
|
"sql_query",
|
||||||
assert "Tasks" in data["tables"]
|
{"document": "test-doc", "query": "SELECT Name, Age FROM People"}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "records" in data
|
||||||
|
assert len(data["records"]) >= 1
|
||||||
|
log = get_mock_request_log()
|
||||||
|
assert any("/sql" in entry["path"] for entry in log)
|
||||||
|
|
||||||
# Verify mock received correct request
|
# ===== WRITE TOOLS =====
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
assert len(log) >= 1
|
|
||||||
assert log[-1]["method"] == "GET"
|
|
||||||
assert "/tables" in log[-1]["path"]
|
|
||||||
|
|
||||||
|
# Test add_records
|
||||||
|
clear_mock_request_log()
|
||||||
|
new_records = [
|
||||||
|
{"Name": "Charlie", "Age": 35, "Email": "charlie@example.com"}
|
||||||
|
]
|
||||||
|
result = await client.call_tool(
|
||||||
|
"add_records",
|
||||||
|
{"document": "test-doc", "table": "People", "records": new_records}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "inserted_ids" in data
|
||||||
|
assert len(data["inserted_ids"]) == 1
|
||||||
|
log = get_mock_request_log()
|
||||||
|
post_requests = [e for e in log if e["method"] == "POST" and "/records" in e["path"]]
|
||||||
|
assert len(post_requests) >= 1
|
||||||
|
assert post_requests[-1]["body"]["records"][0]["fields"]["Name"] == "Charlie"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
# Test update_records
|
||||||
async def test_describe_table(mcp_client, mock_grist_client):
|
clear_mock_request_log()
|
||||||
"""Test describe_table returns column information."""
|
updates = [{"id": 1, "fields": {"Age": 31}}]
|
||||||
result = await mcp_client.call_tool(
|
result = await client.call_tool(
|
||||||
"describe_table",
|
"update_records",
|
||||||
{"document": "test-doc", "table": "People"}
|
{"document": "test-doc", "table": "People", "records": updates}
|
||||||
)
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "updated" in data
|
||||||
|
log = get_mock_request_log()
|
||||||
|
patch_requests = [e for e in log if e["method"] == "PATCH" and "/records" in e["path"]]
|
||||||
|
assert len(patch_requests) >= 1
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
# Test delete_records
|
||||||
assert "columns" in data
|
clear_mock_request_log()
|
||||||
|
result = await client.call_tool(
|
||||||
|
"delete_records",
|
||||||
|
{"document": "test-doc", "table": "People", "record_ids": [1, 2]}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "deleted" in data
|
||||||
|
log = get_mock_request_log()
|
||||||
|
delete_requests = [e for e in log if "/data/delete" in e["path"]]
|
||||||
|
assert len(delete_requests) >= 1
|
||||||
|
assert delete_requests[-1]["body"] == [1, 2]
|
||||||
|
|
||||||
column_ids = [c["id"] for c in data["columns"]]
|
# ===== SCHEMA TOOLS =====
|
||||||
assert "Name" in column_ids
|
|
||||||
assert "Age" in column_ids
|
|
||||||
|
|
||||||
# Verify API call
|
# Test create_table
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
clear_mock_request_log()
|
||||||
assert any("/columns" in entry["path"] for entry in log)
|
columns = [
|
||||||
|
{"id": "Title", "type": "Text"},
|
||||||
|
{"id": "Count", "type": "Int"},
|
||||||
|
]
|
||||||
|
result = await client.call_tool(
|
||||||
|
"create_table",
|
||||||
|
{"document": "test-doc", "table_id": "NewTable", "columns": columns}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "table_id" in data
|
||||||
|
log = get_mock_request_log()
|
||||||
|
post_tables = [e for e in log if e["method"] == "POST" and e["path"].endswith("/tables")]
|
||||||
|
assert len(post_tables) >= 1
|
||||||
|
|
||||||
|
# Test add_column
|
||||||
|
clear_mock_request_log()
|
||||||
|
result = await client.call_tool(
|
||||||
|
"add_column",
|
||||||
|
{
|
||||||
|
"document": "test-doc",
|
||||||
|
"table": "People",
|
||||||
|
"column_id": "Phone",
|
||||||
|
"column_type": "Text",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "column_id" in data
|
||||||
|
log = get_mock_request_log()
|
||||||
|
post_cols = [e for e in log if e["method"] == "POST" and "/columns" in e["path"]]
|
||||||
|
assert len(post_cols) >= 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
# Test modify_column
|
||||||
async def test_get_records(mcp_client, mock_grist_client):
|
clear_mock_request_log()
|
||||||
"""Test get_records fetches records from table."""
|
result = await client.call_tool(
|
||||||
result = await mcp_client.call_tool(
|
"modify_column",
|
||||||
"get_records",
|
{
|
||||||
{"document": "test-doc", "table": "People"}
|
"document": "test-doc",
|
||||||
)
|
"table": "People",
|
||||||
|
"column_id": "Age",
|
||||||
|
"type": "Numeric",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "modified" in data
|
||||||
|
log = get_mock_request_log()
|
||||||
|
patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]]
|
||||||
|
assert len(patch_cols) >= 1
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
# Test delete_column
|
||||||
assert "records" in data
|
clear_mock_request_log()
|
||||||
assert len(data["records"]) == 2
|
result = await client.call_tool(
|
||||||
assert data["records"][0]["Name"] == "Alice"
|
"delete_column",
|
||||||
|
{
|
||||||
|
"document": "test-doc",
|
||||||
|
"table": "People",
|
||||||
|
"column_id": "Email",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "deleted" in data
|
||||||
|
log = get_mock_request_log()
|
||||||
|
delete_cols = [e for e in log if e["method"] == "DELETE" and "/columns/" in e["path"]]
|
||||||
|
assert len(delete_cols) >= 1
|
||||||
|
|
||||||
# Verify API call
|
# ===== AUTHORIZATION =====
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log)
|
|
||||||
|
|
||||||
|
# Test unauthorized document fails
|
||||||
@pytest.mark.asyncio
|
result = await client.call_tool(
|
||||||
async def test_sql_query(mcp_client, mock_grist_client):
|
"list_tables",
|
||||||
"""Test sql_query executes SQL and returns results."""
|
{"document": "unauthorized-doc"}
|
||||||
result = await mcp_client.call_tool(
|
)
|
||||||
"sql_query",
|
assert "error" in result.content[0].text.lower() or "authorization" in result.content[0].text.lower()
|
||||||
{"document": "test-doc", "query": "SELECT Name, Age FROM People"}
|
|
||||||
)
|
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
|
||||||
assert "records" in data
|
|
||||||
assert len(data["records"]) >= 1
|
|
||||||
|
|
||||||
# Verify API call
|
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
assert any("/sql" in entry["path"] for entry in log)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_records(mcp_client, mock_grist_client):
|
|
||||||
"""Test add_records sends correct payload to Grist."""
|
|
||||||
new_records = [
|
|
||||||
{"Name": "Charlie", "Age": 35, "Email": "charlie@example.com"}
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await mcp_client.call_tool(
|
|
||||||
"add_records",
|
|
||||||
{"document": "test-doc", "table": "People", "records": new_records}
|
|
||||||
)
|
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
|
||||||
assert "record_ids" in data
|
|
||||||
assert len(data["record_ids"]) == 1
|
|
||||||
|
|
||||||
# Verify API call body
|
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
post_requests = [e for e in log if e["method"] == "POST" and "/records" in e["path"]]
|
|
||||||
assert len(post_requests) >= 1
|
|
||||||
assert post_requests[-1]["body"]["records"][0]["fields"]["Name"] == "Charlie"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_records(mcp_client, mock_grist_client):
|
|
||||||
"""Test update_records sends correct payload to Grist."""
|
|
||||||
updates = [
|
|
||||||
{"id": 1, "fields": {"Age": 31}}
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await mcp_client.call_tool(
|
|
||||||
"update_records",
|
|
||||||
{"document": "test-doc", "table": "People", "records": updates}
|
|
||||||
)
|
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
|
||||||
assert "updated" in data
|
|
||||||
|
|
||||||
# Verify API call
|
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
patch_requests = [e for e in log if e["method"] == "PATCH" and "/records" in e["path"]]
|
|
||||||
assert len(patch_requests) >= 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_records(mcp_client, mock_grist_client):
|
|
||||||
"""Test delete_records sends correct IDs to Grist."""
|
|
||||||
result = await mcp_client.call_tool(
|
|
||||||
"delete_records",
|
|
||||||
{"document": "test-doc", "table": "People", "record_ids": [1, 2]}
|
|
||||||
)
|
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
|
||||||
assert "deleted" in data
|
|
||||||
|
|
||||||
# Verify API call
|
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
delete_requests = [e for e in log if "/data/delete" in e["path"]]
|
|
||||||
assert len(delete_requests) >= 1
|
|
||||||
assert delete_requests[-1]["body"] == [1, 2]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_table(mcp_client, mock_grist_client):
|
|
||||||
"""Test create_table sends correct schema to Grist."""
|
|
||||||
columns = [
|
|
||||||
{"id": "Title", "type": "Text"},
|
|
||||||
{"id": "Count", "type": "Int"},
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await mcp_client.call_tool(
|
|
||||||
"create_table",
|
|
||||||
{"document": "test-doc", "table_id": "NewTable", "columns": columns}
|
|
||||||
)
|
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
|
||||||
assert "table_id" in data
|
|
||||||
|
|
||||||
# Verify API call
|
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
post_tables = [e for e in log if e["method"] == "POST" and e["path"].endswith("/tables")]
|
|
||||||
assert len(post_tables) >= 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_column(mcp_client, mock_grist_client):
|
|
||||||
"""Test add_column sends correct column definition."""
|
|
||||||
result = await mcp_client.call_tool(
|
|
||||||
"add_column",
|
|
||||||
{
|
|
||||||
"document": "test-doc",
|
|
||||||
"table": "People",
|
|
||||||
"column_id": "Phone",
|
|
||||||
"column_type": "Text",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
|
||||||
assert "column_id" in data
|
|
||||||
|
|
||||||
# Verify API call
|
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
post_cols = [e for e in log if e["method"] == "POST" and "/columns" in e["path"]]
|
|
||||||
assert len(post_cols) >= 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_modify_column(mcp_client, mock_grist_client):
|
|
||||||
"""Test modify_column sends correct update."""
|
|
||||||
result = await mcp_client.call_tool(
|
|
||||||
"modify_column",
|
|
||||||
{
|
|
||||||
"document": "test-doc",
|
|
||||||
"table": "People",
|
|
||||||
"column_id": "Age",
|
|
||||||
"type": "Numeric",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
|
||||||
assert "modified" in data
|
|
||||||
|
|
||||||
# Verify API call
|
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]]
|
|
||||||
assert len(patch_cols) >= 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_column(mcp_client, mock_grist_client):
|
|
||||||
"""Test delete_column calls correct endpoint."""
|
|
||||||
result = await mcp_client.call_tool(
|
|
||||||
"delete_column",
|
|
||||||
{
|
|
||||||
"document": "test-doc",
|
|
||||||
"table": "People",
|
|
||||||
"column_id": "Email",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
data = json.loads(result.content[0].text)
|
|
||||||
assert "deleted" in data
|
|
||||||
|
|
||||||
# Verify API call
|
|
||||||
log = mock_grist_client.get("/_test/requests").json()
|
|
||||||
delete_cols = [e for e in log if e["method"] == "DELETE" and "/columns/" in e["path"]]
|
|
||||||
assert len(delete_cols) >= 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unauthorized_document_fails(mcp_client):
|
|
||||||
"""Test that accessing unauthorized document returns error."""
|
|
||||||
result = await mcp_client.call_tool(
|
|
||||||
"list_tables",
|
|
||||||
{"document": "unauthorized-doc"}
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "error" in result.content[0].text.lower() or "authorization" in result.content[0].text.lower()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user