Files
grist-mcp-server/tests/integration/mock_grist/server.py
Bill Ballou a97930848b feat: normalize filter values to array format for Grist API
The Grist API requires all filter values to be arrays. This change adds
automatic normalization of filter values in get_records, wrapping single
values in lists before sending to the API.

This fixes 400 errors when filtering on Ref columns with single integer IDs.

Changes:
- Add filters.py module with normalize_filter function
- Update get_records to normalize filters before API call
- Add Orders table with Ref column to mock Grist server
- Add filter validation to mock server (rejects non-array values)
- Fix shell script shebangs for portability (#!/usr/bin/env bash)
2026-01-14 17:56:18 -05:00

268 lines
9.7 KiB
Python

"""Mock Grist API server for integration testing."""
import json
import logging
import os
from datetime import datetime, timezone
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
logging.basicConfig(level=logging.INFO, format="%(asctime)s [MOCK-GRIST] %(message)s")
logger = logging.getLogger(__name__)
# Mock data
MOCK_TABLES = {
"People": {
"columns": [
{"id": "Name", "fields": {"type": "Text"}},
{"id": "Age", "fields": {"type": "Int"}},
{"id": "Email", "fields": {"type": "Text"}},
],
"records": [
{"id": 1, "fields": {"Name": "Alice", "Age": 30, "Email": "alice@example.com"}},
{"id": 2, "fields": {"Name": "Bob", "Age": 25, "Email": "bob@example.com"}},
],
},
"Tasks": {
"columns": [
{"id": "Title", "fields": {"type": "Text"}},
{"id": "Done", "fields": {"type": "Bool"}},
],
"records": [
{"id": 1, "fields": {"Title": "Write tests", "Done": False}},
{"id": 2, "fields": {"Title": "Deploy", "Done": False}},
],
},
"Orders": {
"columns": [
{"id": "OrderNum", "fields": {"type": "Int"}},
{"id": "Customer", "fields": {"type": "Ref:People"}},
{"id": "Amount", "fields": {"type": "Numeric"}},
],
"records": [
{"id": 1, "fields": {"OrderNum": 1001, "Customer": 1, "Amount": 100.0}},
{"id": 2, "fields": {"OrderNum": 1002, "Customer": 2, "Amount": 200.0}},
{"id": 3, "fields": {"OrderNum": 1003, "Customer": 1, "Amount": 150.0}},
],
},
}
# Track requests for test assertions
request_log: list[dict] = []
def log_request(method: str, path: str, body: dict | None = None):
"""Log a request for later inspection."""
entry = {
"timestamp": datetime.now(timezone.utc).isoformat(),
"method": method,
"path": path,
"body": body,
}
request_log.append(entry)
logger.info(f"{method} {path}" + (f" body={json.dumps(body)}" if body else ""))
async def health(request):
"""Health check endpoint."""
return JSONResponse({"status": "ok"})
async def get_request_log(request):
"""Return the request log for test assertions."""
return JSONResponse(request_log)
async def clear_request_log(request):
"""Clear the request log."""
request_log.clear()
return JSONResponse({"status": "cleared"})
async def list_tables(request):
"""GET /api/docs/{doc_id}/tables"""
doc_id = request.path_params["doc_id"]
log_request("GET", f"/api/docs/{doc_id}/tables")
tables = [{"id": name} for name in MOCK_TABLES.keys()]
return JSONResponse({"tables": tables})
async def get_table_columns(request):
"""GET /api/docs/{doc_id}/tables/{table_id}/columns"""
doc_id = request.path_params["doc_id"]
table_id = request.path_params["table_id"]
log_request("GET", f"/api/docs/{doc_id}/tables/{table_id}/columns")
if table_id not in MOCK_TABLES:
return JSONResponse({"error": "Table not found"}, status_code=404)
return JSONResponse({"columns": MOCK_TABLES[table_id]["columns"]})
async def get_records(request):
"""GET /api/docs/{doc_id}/tables/{table_id}/records"""
doc_id = request.path_params["doc_id"]
table_id = request.path_params["table_id"]
filter_param = request.query_params.get("filter")
log_request("GET", f"/api/docs/{doc_id}/tables/{table_id}/records?filter={filter_param}")
if table_id not in MOCK_TABLES:
return JSONResponse({"error": "Table not found"}, status_code=404)
records = MOCK_TABLES[table_id]["records"]
# Apply filtering if provided
if filter_param:
try:
filters = json.loads(filter_param)
# Validate filter format: all values must be arrays (Grist API requirement)
for key, values in filters.items():
if not isinstance(values, list):
return JSONResponse(
{"error": f"Filter values must be arrays, got {type(values).__name__} for '{key}'"},
status_code=400
)
# Apply filters: record matches if field value is in the filter list
filtered_records = []
for record in records:
match = True
for key, allowed_values in filters.items():
if record["fields"].get(key) not in allowed_values:
match = False
break
if match:
filtered_records.append(record)
records = filtered_records
except json.JSONDecodeError:
return JSONResponse({"error": "Invalid filter JSON"}, status_code=400)
return JSONResponse({"records": records})
async def add_records(request):
"""POST /api/docs/{doc_id}/tables/{table_id}/records"""
doc_id = request.path_params["doc_id"]
table_id = request.path_params["table_id"]
body = await request.json()
log_request("POST", f"/api/docs/{doc_id}/tables/{table_id}/records", body)
# Return mock IDs for new records
new_ids = [{"id": 100 + i} for i in range(len(body.get("records", [])))]
return JSONResponse({"records": new_ids})
async def update_records(request):
"""PATCH /api/docs/{doc_id}/tables/{table_id}/records"""
doc_id = request.path_params["doc_id"]
table_id = request.path_params["table_id"]
body = await request.json()
log_request("PATCH", f"/api/docs/{doc_id}/tables/{table_id}/records", body)
return JSONResponse({})
async def delete_records(request):
"""POST /api/docs/{doc_id}/tables/{table_id}/data/delete"""
doc_id = request.path_params["doc_id"]
table_id = request.path_params["table_id"]
body = await request.json()
log_request("POST", f"/api/docs/{doc_id}/tables/{table_id}/data/delete", body)
return JSONResponse({})
async def sql_query(request):
"""GET /api/docs/{doc_id}/sql"""
doc_id = request.path_params["doc_id"]
query = request.query_params.get("q", "")
log_request("GET", f"/api/docs/{doc_id}/sql?q={query}")
# Return mock SQL results
return JSONResponse({
"records": [
{"fields": {"Name": "Alice", "Age": 30}},
{"fields": {"Name": "Bob", "Age": 25}},
]
})
async def create_tables(request):
"""POST /api/docs/{doc_id}/tables"""
doc_id = request.path_params["doc_id"]
body = await request.json()
log_request("POST", f"/api/docs/{doc_id}/tables", body)
# Return the created tables with their IDs
tables = [{"id": t["id"]} for t in body.get("tables", [])]
return JSONResponse({"tables": tables})
async def add_column(request):
"""POST /api/docs/{doc_id}/tables/{table_id}/columns"""
doc_id = request.path_params["doc_id"]
table_id = request.path_params["table_id"]
body = await request.json()
log_request("POST", f"/api/docs/{doc_id}/tables/{table_id}/columns", body)
columns = [{"id": c["id"]} for c in body.get("columns", [])]
return JSONResponse({"columns": columns})
async def modify_column(request):
"""PATCH /api/docs/{doc_id}/tables/{table_id}/columns/{col_id}"""
doc_id = request.path_params["doc_id"]
table_id = request.path_params["table_id"]
col_id = request.path_params["col_id"]
body = await request.json()
log_request("PATCH", f"/api/docs/{doc_id}/tables/{table_id}/columns/{col_id}", body)
return JSONResponse({})
async def modify_columns(request):
"""PATCH /api/docs/{doc_id}/tables/{table_id}/columns - batch modify columns"""
doc_id = request.path_params["doc_id"]
table_id = request.path_params["table_id"]
body = await request.json()
log_request("PATCH", f"/api/docs/{doc_id}/tables/{table_id}/columns", body)
return JSONResponse({})
async def delete_column(request):
"""DELETE /api/docs/{doc_id}/tables/{table_id}/columns/{col_id}"""
doc_id = request.path_params["doc_id"]
table_id = request.path_params["table_id"]
col_id = request.path_params["col_id"]
log_request("DELETE", f"/api/docs/{doc_id}/tables/{table_id}/columns/{col_id}")
return JSONResponse({})
app = Starlette(
routes=[
# Test control endpoints
Route("/health", endpoint=health),
Route("/_test/requests", endpoint=get_request_log),
Route("/_test/requests/clear", endpoint=clear_request_log, methods=["POST"]),
# Grist API endpoints
Route("/api/docs/{doc_id}/tables", endpoint=list_tables),
Route("/api/docs/{doc_id}/tables", endpoint=create_tables, methods=["POST"]),
Route("/api/docs/{doc_id}/tables/{table_id}/columns", endpoint=get_table_columns),
Route("/api/docs/{doc_id}/tables/{table_id}/columns", endpoint=add_column, methods=["POST"]),
Route("/api/docs/{doc_id}/tables/{table_id}/columns", endpoint=modify_columns, methods=["PATCH"]),
Route("/api/docs/{doc_id}/tables/{table_id}/columns/{col_id}", endpoint=modify_column, methods=["PATCH"]),
Route("/api/docs/{doc_id}/tables/{table_id}/columns/{col_id}", endpoint=delete_column, methods=["DELETE"]),
Route("/api/docs/{doc_id}/tables/{table_id}/records", endpoint=get_records),
Route("/api/docs/{doc_id}/tables/{table_id}/records", endpoint=add_records, methods=["POST"]),
Route("/api/docs/{doc_id}/tables/{table_id}/records", endpoint=update_records, methods=["PATCH"]),
Route("/api/docs/{doc_id}/tables/{table_id}/data/delete", endpoint=delete_records, methods=["POST"]),
Route("/api/docs/{doc_id}/sql", endpoint=sql_query),
]
)
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", "8484"))
logger.info(f"Starting mock Grist server on port {port}")
uvicorn.run(app, host="0.0.0.0", port=port)