From c4ddc3b1b02227f36ade996296e902b404eadcdc Mon Sep 17 00:00:00 2001 From: Bill Date: Wed, 3 Dec 2025 14:34:37 -0500 Subject: [PATCH] feat: add Grist API client --- src/grist_mcp/grist_client.py | 144 +++++++++++++++++++++++++++ tests/test_grist_client.py | 181 ++++++++++++++++++++++++++++++++++ 2 files changed, 325 insertions(+) create mode 100644 src/grist_mcp/grist_client.py create mode 100644 tests/test_grist_client.py diff --git a/src/grist_mcp/grist_client.py b/src/grist_mcp/grist_client.py new file mode 100644 index 0000000..1ae28ee --- /dev/null +++ b/src/grist_mcp/grist_client.py @@ -0,0 +1,144 @@ +"""Grist API client.""" + +import httpx + +from grist_mcp.config import Document + + +class GristClient: + """Async client for Grist API operations.""" + + def __init__(self, document: Document): + self._doc = document + self._base_url = f"{document.url.rstrip('/')}/api/docs/{document.doc_id}" + self._headers = {"Authorization": f"Bearer {document.api_key}"} + + async def _request(self, method: str, path: str, **kwargs) -> dict: + """Make an authenticated request to Grist API.""" + async with httpx.AsyncClient() as client: + response = await client.request( + method, + f"{self._base_url}{path}", + headers=self._headers, + **kwargs, + ) + response.raise_for_status() + return response.json() if response.content else {} + + # Read operations + + async def list_tables(self) -> list[str]: + """List all tables in the document.""" + data = await self._request("GET", "/tables") + return [t["id"] for t in data.get("tables", [])] + + async def describe_table(self, table: str) -> list[dict]: + """Get column information for a table.""" + data = await self._request("GET", f"/tables/{table}/columns") + return [ + { + "id": col["id"], + "type": col["fields"].get("type", "Any"), + "formula": col["fields"].get("formula", ""), + } + for col in data.get("columns", []) + ] + + async def get_records( + self, + table: str, + filter: dict | None = None, + sort: str | None = None, + limit: int | None = None, + ) -> list[dict]: + """Fetch records from a table.""" + params = {} + if filter: + params["filter"] = filter + if sort: + params["sort"] = sort + if limit: + params["limit"] = limit + + data = await self._request("GET", f"/tables/{table}/records", params=params) + + return [ + {"id": r["id"], **r["fields"]} + for r in data.get("records", []) + ] + + async def sql_query(self, sql: str) -> list[dict]: + """Run a read-only SQL query.""" + data = await self._request("GET", "/sql", params={"q": sql}) + return [r["fields"] for r in data.get("records", [])] + + # Write operations + + async def add_records(self, table: str, records: list[dict]) -> list[int]: + """Add records to a table. Returns list of new record IDs.""" + payload = { + "records": [{"fields": r} for r in records] + } + data = await self._request("POST", f"/tables/{table}/records", json=payload) + return [r["id"] for r in data.get("records", [])] + + async def update_records(self, table: str, records: list[dict]) -> None: + """Update records. Each record must have 'id' and 'fields' keys.""" + payload = {"records": records} + await self._request("PATCH", f"/tables/{table}/records", json=payload) + + async def delete_records(self, table: str, record_ids: list[int]) -> None: + """Delete records by ID.""" + await self._request("POST", f"/tables/{table}/data/delete", json=record_ids) + + # Schema operations + + async def create_table(self, table_id: str, columns: list[dict]) -> str: + """Create a new table with columns. Returns table ID.""" + payload = { + "tables": [{ + "id": table_id, + "columns": [ + {"id": c["id"], "fields": {"type": c["type"]}} + for c in columns + ], + }] + } + data = await self._request("POST", "/tables", json=payload) + return data["tables"][0]["id"] + + async def add_column( + self, + table: str, + column_id: str, + column_type: str, + formula: str | None = None, + ) -> str: + """Add a column to a table. Returns column ID.""" + fields = {"type": column_type} + if formula: + fields["formula"] = formula + + payload = {"columns": [{"id": column_id, "fields": fields}]} + data = await self._request("POST", f"/tables/{table}/columns", json=payload) + return data["columns"][0]["id"] + + async def modify_column( + self, + table: str, + column_id: str, + type: str | None = None, + formula: str | None = None, + ) -> None: + """Modify a column's type or formula.""" + fields = {} + if type is not None: + fields["type"] = type + if formula is not None: + fields["formula"] = formula + + await self._request("PATCH", f"/tables/{table}/columns/{column_id}", json={"fields": fields}) + + async def delete_column(self, table: str, column_id: str) -> None: + """Delete a column from a table.""" + await self._request("DELETE", f"/tables/{table}/columns/{column_id}") diff --git a/tests/test_grist_client.py b/tests/test_grist_client.py new file mode 100644 index 0000000..1f3b2c2 --- /dev/null +++ b/tests/test_grist_client.py @@ -0,0 +1,181 @@ +import pytest +from pytest_httpx import HTTPXMock + +from grist_mcp.grist_client import GristClient +from grist_mcp.config import Document + + +@pytest.fixture +def doc(): + return Document( + url="https://grist.example.com", + doc_id="abc123", + api_key="test-api-key", + ) + + +@pytest.fixture +def client(doc): + return GristClient(doc) + + +@pytest.mark.asyncio +async def test_list_tables(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables", + json={"tables": [{"id": "Table1"}, {"id": "Table2"}]}, + ) + + tables = await client.list_tables() + + assert tables == ["Table1", "Table2"] + + +@pytest.mark.asyncio +async def test_describe_table(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables/Table1/columns", + json={ + "columns": [ + {"id": "Name", "fields": {"type": "Text", "formula": ""}}, + {"id": "Amount", "fields": {"type": "Numeric", "formula": "$Price * $Qty"}}, + ] + }, + ) + + columns = await client.describe_table("Table1") + + assert len(columns) == 2 + assert columns[0] == {"id": "Name", "type": "Text", "formula": ""} + assert columns[1] == {"id": "Amount", "type": "Numeric", "formula": "$Price * $Qty"} + + +@pytest.mark.asyncio +async def test_get_records(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables/Table1/records", + json={ + "records": [ + {"id": 1, "fields": {"Name": "Alice", "Amount": 100}}, + {"id": 2, "fields": {"Name": "Bob", "Amount": 200}}, + ] + }, + ) + + records = await client.get_records("Table1") + + assert len(records) == 2 + assert records[0] == {"id": 1, "Name": "Alice", "Amount": 100} + + +@pytest.mark.asyncio +async def test_add_records(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables/Table1/records", + method="POST", + json={"records": [{"id": 3}, {"id": 4}]}, + ) + + ids = await client.add_records("Table1", [ + {"Name": "Charlie", "Amount": 300}, + {"Name": "Diana", "Amount": 400}, + ]) + + assert ids == [3, 4] + + +@pytest.mark.asyncio +async def test_update_records(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables/Table1/records", + method="PATCH", + json={}, + ) + + # Should not raise + await client.update_records("Table1", [ + {"id": 1, "fields": {"Amount": 150}}, + ]) + + +@pytest.mark.asyncio +async def test_delete_records(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables/Table1/data/delete", + method="POST", + json={}, + ) + + # Should not raise + await client.delete_records("Table1", [1, 2]) + + +@pytest.mark.asyncio +async def test_sql_query(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/sql?q=SELECT+*+FROM+Table1", + method="GET", + json={ + "statement": "SELECT * FROM Table1", + "records": [ + {"fields": {"Name": "Alice", "Amount": 100}}, + ], + }, + ) + + result = await client.sql_query("SELECT * FROM Table1") + + assert result == [{"Name": "Alice", "Amount": 100}] + + +@pytest.mark.asyncio +async def test_create_table(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables", + method="POST", + json={"tables": [{"id": "NewTable"}]}, + ) + + table_id = await client.create_table("NewTable", [ + {"id": "Col1", "type": "Text"}, + {"id": "Col2", "type": "Numeric"}, + ]) + + assert table_id == "NewTable" + + +@pytest.mark.asyncio +async def test_add_column(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables/Table1/columns", + method="POST", + json={"columns": [{"id": "NewCol"}]}, + ) + + col_id = await client.add_column("Table1", "NewCol", "Text", formula=None) + + assert col_id == "NewCol" + + +@pytest.mark.asyncio +async def test_modify_column(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables/Table1/columns/Amount", + method="PATCH", + json={}, + ) + + # Should not raise + await client.modify_column("Table1", "Amount", type="Int", formula="$Price * $Qty") + + +@pytest.mark.asyncio +async def test_delete_column(client, httpx_mock: HTTPXMock): + httpx_mock.add_response( + url="https://grist.example.com/api/docs/abc123/tables/Table1/columns/OldCol", + method="DELETE", + json={}, + ) + + # Should not raise + await client.delete_column("Table1", "OldCol")