Files
grist-mcp-server/src/grist_mcp/grist_client.py
Bill 204d00caf4
All checks were successful
Build and Push Docker Image / build (push) Successful in 14s
feat: add host_header config for Docker networking
When Grist validates the Host header (common with reverse proxy setups),
internal Docker networking fails because requests arrive with
Host: container-name instead of the external domain.

The new host_header config option allows overriding the Host header
sent to Grist while still connecting via internal Docker hostnames.
2026-01-01 14:06:31 -05:00

171 lines
5.8 KiB
Python

"""Grist API client."""
import json
from urllib.parse import quote

import httpx

from grist_mcp.config import Document
# Default timeout for HTTP requests (30 seconds)
DEFAULT_TIMEOUT = 30.0
class GristClient:
"""Async client for Grist API operations."""
def __init__(self, document: Document, timeout: float = DEFAULT_TIMEOUT):
self._doc = document
self._base_url = f"{document.url.rstrip('/')}/api/docs/{document.doc_id}"
self._headers = {"Authorization": f"Bearer {document.api_key}"}
if document.host_header:
self._headers["Host"] = document.host_header
self._timeout = timeout
async def _request(self, method: str, path: str, **kwargs) -> dict:
"""Make an authenticated request to Grist API."""
async with httpx.AsyncClient(timeout=self._timeout) as client:
response = await client.request(
method,
f"{self._base_url}{path}",
headers=self._headers,
**kwargs,
)
response.raise_for_status()
return response.json() if response.content else {}
# Read operations
async def list_tables(self) -> list[str]:
"""List all tables in the document."""
data = await self._request("GET", "/tables")
return [t["id"] for t in data.get("tables", [])]
async def describe_table(self, table: str) -> list[dict]:
"""Get column information for a table."""
data = await self._request("GET", f"/tables/{table}/columns")
return [
{
"id": col["id"],
"type": col["fields"].get("type", "Any"),
"formula": col["fields"].get("formula", ""),
}
for col in data.get("columns", [])
]
async def get_records(
self,
table: str,
filter: dict | None = None,
sort: str | None = None,
limit: int | None = None,
) -> list[dict]:
"""Fetch records from a table."""
params = {}
if filter:
params["filter"] = json.dumps(filter)
if sort:
params["sort"] = sort
if limit:
params["limit"] = limit
data = await self._request("GET", f"/tables/{table}/records", params=params)
return [
{"id": r["id"], **r["fields"]}
for r in data.get("records", [])
]
async def sql_query(self, sql: str) -> list[dict]:
"""Run a read-only SQL query.
Raises:
ValueError: If query is not a SELECT statement or contains multiple statements.
"""
self._validate_sql_query(sql)
data = await self._request("GET", "/sql", params={"q": sql})
return [r["fields"] for r in data.get("records", [])]
@staticmethod
def _validate_sql_query(sql: str) -> None:
"""Validate SQL query for safety.
Only allows SELECT statements and rejects multiple statements.
"""
sql_stripped = sql.strip()
if not sql_stripped.upper().startswith("SELECT"):
raise ValueError("Only SELECT queries are allowed")
if ";" in sql_stripped[:-1]: # Allow trailing semicolon
raise ValueError("Multiple statements not allowed")
# Write operations
async def add_records(self, table: str, records: list[dict]) -> list[int]:
"""Add records to a table. Returns list of new record IDs."""
payload = {
"records": [{"fields": r} for r in records]
}
data = await self._request("POST", f"/tables/{table}/records", json=payload)
return [r["id"] for r in data.get("records", [])]
async def update_records(self, table: str, records: list[dict]) -> None:
"""Update records. Each record must have 'id' and 'fields' keys."""
payload = {"records": records}
await self._request("PATCH", f"/tables/{table}/records", json=payload)
async def delete_records(self, table: str, record_ids: list[int]) -> None:
"""Delete records by ID."""
await self._request("POST", f"/tables/{table}/data/delete", json=record_ids)
# Schema operations
async def create_table(self, table_id: str, columns: list[dict]) -> str:
"""Create a new table with columns. Returns table ID."""
payload = {
"tables": [{
"id": table_id,
"columns": [
{"id": c["id"], "fields": {"type": c["type"]}}
for c in columns
],
}]
}
data = await self._request("POST", "/tables", json=payload)
return data["tables"][0]["id"]
async def add_column(
self,
table: str,
column_id: str,
column_type: str,
formula: str | None = None,
) -> str:
"""Add a column to a table. Returns column ID."""
fields = {"type": column_type}
if formula:
fields["formula"] = formula
payload = {"columns": [{"id": column_id, "fields": fields}]}
data = await self._request("POST", f"/tables/{table}/columns", json=payload)
return data["columns"][0]["id"]
async def modify_column(
self,
table: str,
column_id: str,
type: str | None = None,
formula: str | None = None,
) -> None:
"""Modify a column's type or formula."""
fields = {}
if type is not None:
fields["type"] = type
if formula is not None:
fields["formula"] = formula
payload = {"columns": [{"id": column_id, "fields": fields}]}
await self._request("PATCH", f"/tables/{table}/columns", json=payload)
async def delete_column(self, table: str, column_id: str) -> None:
"""Delete a column from a table."""
await self._request("DELETE", f"/tables/{table}/columns/{column_id}")