Add GET /api/v1/attachments/{id} endpoint for downloading attachments
through the MCP proxy. This complements the existing upload endpoint and
enables complete attachment workflows via the proxy API.
236 lines
8.0 KiB
Python
236 lines
8.0 KiB
Python
"""Grist API client."""
|
|
|
|
import json
import re
from urllib.parse import quote, unquote

import httpx

from grist_mcp.config import Document
|
|
|
|
# Per-request HTTP timeout, in seconds, that GristClient falls back to when
# no explicit ``timeout`` is passed to its constructor.
DEFAULT_TIMEOUT: float = 30.0
|
|
|
|
|
|
class GristClient:
    """Async client for Grist API operations.

    Each instance is bound to one document: every request path is resolved
    relative to ``{document.url}/api/docs/{document.doc_id}``.
    """

    # RFC 5987 extended parameter: ``filename*=charset'lang'pct-encoded``.
    # RFC 6266 tells recipients to prefer it over the plain ``filename``.
    _FILENAME_EXT_RE = re.compile(
        r"\bfilename\*\s*=\s*([^';\s]*)'[^']*'([^;\s]+)",
        re.IGNORECASE,
    )
    # Plain ``filename=`` parameter: either a quoted-string (which may contain
    # ``;`` and backslash-escaped characters) or a bare token.
    _FILENAME_RE = re.compile(
        r'\bfilename\s*=\s*(?:"((?:[^"\\]|\\.)*)"|([^;\s]+))',
        re.IGNORECASE,
    )

    def __init__(self, document: Document, timeout: float = DEFAULT_TIMEOUT):
        """Bind the client to *document*.

        Args:
            document: Connection settings; ``url``, ``doc_id``, ``api_key``
                and ``host_header`` are read here.
            timeout: Per-request HTTP timeout in seconds.
        """
        self._doc = document
        self._base_url = f"{document.url.rstrip('/')}/api/docs/{document.doc_id}"
        self._headers = {"Authorization": f"Bearer {document.api_key}"}
        if document.host_header:
            # Optional Host override (presumably for proxied/internal-address
            # deployments -- confirm against config docs).
            self._headers["Host"] = document.host_header
        self._timeout = timeout

    # Internal helpers

    @staticmethod
    def _segment(value: object) -> str:
        """Percent-encode *value* for use as exactly one URL path segment.

        Encoding ``/``, ``?`` and ``#`` stops caller-supplied identifiers
        (table, column and attachment IDs) from escaping their slot in the
        request path. Ordinary identifiers such as ``Table1`` are unchanged.
        """
        return quote(str(value), safe="")

    async def _send(self, method: str, path: str, **kwargs) -> httpx.Response:
        """Send an authenticated request and return the raw response.

        Args:
            method: HTTP method, e.g. ``"GET"``.
            path: Path relative to the document's API base URL.
            **kwargs: Forwarded to ``httpx.AsyncClient.request`` (``params``,
                ``json``, ``files``, ...).

        Raises:
            httpx.HTTPStatusError: If the response has a 4xx/5xx status.
        """
        async with httpx.AsyncClient(timeout=self._timeout) as client:
            response = await client.request(
                method,
                f"{self._base_url}{path}",
                headers=self._headers,
                **kwargs,
            )
            response.raise_for_status()
            return response

    async def _request(self, method: str, path: str, **kwargs) -> dict:
        """Make an authenticated request to Grist API and decode the JSON body.

        Returns:
            The decoded JSON body, or ``{}`` when the response body is empty.

        Raises:
            httpx.HTTPStatusError: If the response has a 4xx/5xx status.
        """
        response = await self._send(method, path, **kwargs)
        return response.json() if response.content else {}

    @classmethod
    def _filename_from_content_disposition(cls, header: str) -> str | None:
        """Extract the filename from a ``Content-Disposition`` header value.

        Prefers the RFC 5987 ``filename*`` form, percent-decoded with its
        declared charset. Falls back to ``filename``, unescaping a
        quoted-string. Returns None when no non-empty filename is present.
        """
        ext = cls._FILENAME_EXT_RE.search(header)
        if ext is not None:
            charset, encoded = ext.groups()
            try:
                name = unquote(encoded, encoding=charset or "utf-8", errors="replace")
            except LookupError:
                name = None  # unknown charset: use the plain parameter instead
            if name:
                return name

        plain = cls._FILENAME_RE.search(header)
        if plain is None:
            return None
        quoted, token = plain.groups()
        if quoted is not None:
            # Undo quoted-string escaping (e.g. \" -> ").
            return re.sub(r"\\(.)", r"\1", quoted) or None
        return token

    # Read operations

    async def list_tables(self) -> list[str]:
        """List all tables in the document.

        Returns:
            Table IDs, in the order the API returns them.
        """
        data = await self._request("GET", "/tables")
        return [t["id"] for t in data.get("tables", [])]

    async def describe_table(self, table: str) -> list[dict]:
        """Get column information for a table.

        Returns:
            One dict per column with ``id``, ``type`` (defaulting to
            ``"Any"``) and ``formula`` (defaulting to ``""``).
        """
        data = await self._request("GET", f"/tables/{self._segment(table)}/columns")
        return [
            {
                "id": col["id"],
                "type": col["fields"].get("type", "Any"),
                "formula": col["fields"].get("formula", ""),
            }
            for col in data.get("columns", [])
        ]

    async def get_records(
        self,
        table: str,
        filter: dict | None = None,
        sort: str | None = None,
        limit: int | None = None,
    ) -> list[dict]:
        """Fetch records from a table.

        Args:
            table: Table ID.
            filter: Sent JSON-encoded as the ``filter`` query parameter;
                omitted when empty or None.
            sort: Sent verbatim as the ``sort`` query parameter; omitted when
                empty or None.
            limit: Sent as the ``limit`` query parameter; omitted when None
                or 0.

        Returns:
            Records flattened to ``{"id": <row id>, **fields}``.
        """
        params = {}
        if filter:
            params["filter"] = json.dumps(filter)
        if sort:
            params["sort"] = sort
        if limit:
            params["limit"] = limit

        data = await self._request(
            "GET", f"/tables/{self._segment(table)}/records", params=params
        )
        return [{"id": r["id"], **r["fields"]} for r in data.get("records", [])]

    async def sql_query(self, sql: str) -> list[dict]:
        """Run a read-only SQL query.

        Returns:
            The ``fields`` dict of every returned row.

        Raises:
            ValueError: If query is not a SELECT statement or contains multiple statements.
        """
        self._validate_sql_query(sql)
        data = await self._request("GET", "/sql", params={"q": sql})
        return [r["fields"] for r in data.get("records", [])]

    @staticmethod
    def _validate_sql_query(sql: str) -> None:
        """Validate SQL query for safety.

        Only allows SELECT statements and rejects multiple statements. The
        check is deliberately coarse: any ``;`` other than a single trailing
        one is rejected, even inside a string literal.

        Raises:
            ValueError: If the query fails either check.
        """
        sql_stripped = sql.strip()
        if not sql_stripped.upper().startswith("SELECT"):
            raise ValueError("Only SELECT queries are allowed")
        if ";" in sql_stripped[:-1]:  # Allow trailing semicolon
            raise ValueError("Multiple statements not allowed")

    # Write operations

    async def add_records(self, table: str, records: list[dict]) -> list[int]:
        """Add records to a table. Returns list of new record IDs.

        Args:
            table: Table ID.
            records: One field dict per new record; each is wrapped as
                ``{"fields": record}`` in the request payload.
        """
        payload = {"records": [{"fields": r} for r in records]}
        data = await self._request(
            "POST", f"/tables/{self._segment(table)}/records", json=payload
        )
        return [r["id"] for r in data.get("records", [])]

    async def update_records(self, table: str, records: list[dict]) -> None:
        """Update records. Each record must have 'id' and 'fields' keys."""
        payload = {"records": records}
        await self._request(
            "PATCH", f"/tables/{self._segment(table)}/records", json=payload
        )

    async def delete_records(self, table: str, record_ids: list[int]) -> None:
        """Delete records by ID."""
        await self._request(
            "POST", f"/tables/{self._segment(table)}/data/delete", json=record_ids
        )

    async def upload_attachment(
        self,
        filename: str,
        content: bytes,
        content_type: str = "application/octet-stream",
    ) -> dict:
        """Upload a file attachment. Returns attachment metadata.

        Args:
            filename: Name for the uploaded file.
            content: File content as bytes.
            content_type: MIME type of the file.

        Returns:
            Dict with attachment_id, filename, and size_bytes.

        Raises:
            httpx.HTTPStatusError: If the upload is rejected.
        """
        files = {"upload": (filename, content, content_type)}
        response = await self._send("POST", "/attachments", files=files)
        # Grist returns list of attachment IDs; one file was sent.
        attachment_ids = response.json()
        return {
            "attachment_id": attachment_ids[0],
            "filename": filename,
            "size_bytes": len(content),
        }

    async def download_attachment(self, attachment_id: int) -> dict:
        """Download an attachment by ID.

        Args:
            attachment_id: The ID of the attachment to download.

        Returns:
            Dict with content (bytes), content_type, and filename.
            ``content_type`` defaults to ``application/octet-stream`` when
            the header is absent. ``filename`` is parsed from
            Content-Disposition (``filename*`` preferred) and is None when
            the header carries none.

        Raises:
            httpx.HTTPStatusError: If the attachment cannot be fetched.
        """
        response = await self._send(
            "GET", f"/attachments/{self._segment(attachment_id)}/download"
        )
        headers = response.headers
        return {
            "content": response.content,
            "content_type": headers.get("content-type", "application/octet-stream"),
            "filename": self._filename_from_content_disposition(
                headers.get("content-disposition", "")
            ),
        }

    # Schema operations

    async def create_table(self, table_id: str, columns: list[dict]) -> str:
        """Create a new table with columns. Returns table ID.

        Args:
            table_id: Requested table ID.
            columns: Dicts with ``id`` and ``type`` keys.

        Returns:
            The table ID reported back by the API.
        """
        payload = {
            "tables": [
                {
                    "id": table_id,
                    "columns": [
                        {"id": c["id"], "fields": {"type": c["type"]}}
                        for c in columns
                    ],
                }
            ]
        }
        data = await self._request("POST", "/tables", json=payload)
        return data["tables"][0]["id"]

    async def add_column(
        self,
        table: str,
        column_id: str,
        column_type: str,
        formula: str | None = None,
    ) -> str:
        """Add a column to a table. Returns column ID.

        Args:
            table: Table ID.
            column_id: Requested column ID.
            column_type: Column type sent as ``fields.type``.
            formula: Sent as ``fields.formula`` when non-empty.
        """
        fields = {"type": column_type}
        if formula:
            fields["formula"] = formula

        payload = {"columns": [{"id": column_id, "fields": fields}]}
        data = await self._request(
            "POST", f"/tables/{self._segment(table)}/columns", json=payload
        )
        return data["columns"][0]["id"]

    async def modify_column(
        self,
        table: str,
        column_id: str,
        type: str | None = None,
        formula: str | None = None,
    ) -> None:
        """Modify a column's type or formula.

        Only arguments that are not None are sent, so ``formula=""`` clears
        the formula. With neither given, an empty ``fields`` dict is sent.
        """
        fields = {}
        if type is not None:
            fields["type"] = type
        if formula is not None:
            fields["formula"] = formula

        payload = {"columns": [{"id": column_id, "fields": fields}]}
        await self._request(
            "PATCH", f"/tables/{self._segment(table)}/columns", json=payload
        )

    async def delete_column(self, table: str, column_id: str) -> None:
        """Delete a column from a table."""
        await self._request(
            "DELETE",
            f"/tables/{self._segment(table)}/columns/{self._segment(column_id)}",
        )
|