Compare commits
40 Commits
v1.1.0
...
renovate/g
| Author | SHA1 | Date | |
|---|---|---|---|
| abfed5c06d | |||
| 540e57ec81 | |||
| d1e1043896 | |||
| 6521078b6a | |||
| 2f0a24aceb | |||
| 77bf95817d | |||
| 29a72ab005 | |||
| 33bb464102 | |||
| d4e793224b | |||
| bf8f301ded | |||
| a97930848b | |||
| c868e8a7fa | |||
| 734cc0a525 | |||
| a7c87128ef | |||
| 848cfd684f | |||
| ea175d55a2 | |||
| db12fca615 | |||
| d540105d09 | |||
| d40ae0b238 | |||
| 2a60de1bf1 | |||
| ba45de4582 | |||
| d176b03d56 | |||
| 50c5cfbab1 | |||
| 8484536aae | |||
| b3bfdf97c2 | |||
| eabddee737 | |||
| 3d1ac1fe60 | |||
| ed1d14a4d4 | |||
| 80e93ab3d9 | |||
| 7073182f9e | |||
| caa435d972 | |||
| ba88ba01f3 | |||
| fb6d4af973 | |||
| a7bb11d765 | |||
| c65ec0489c | |||
| 681cb0f67c | |||
| 3c97ad407c | |||
| b310ee10a9 | |||
| 4923d3110c | |||
| f79ae5546f |
10
.github/workflows/build.yaml
vendored
10
.github/workflows/build.yaml
vendored
@@ -18,10 +18,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6
|
||||
|
||||
- name: Log in to Container Registry
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
with:
|
||||
registry: ${{ env.REGISTRY }}
|
||||
username: ${{ github.actor }}
|
||||
@@ -29,7 +29,7 @@ jobs:
|
||||
|
||||
- name: Extract metadata for Docker
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
uses: docker/metadata-action@c299e40c65443455700f0fdfc63efafe5b349051 # v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||
tags: |
|
||||
@@ -38,10 +38,10 @@ jobs:
|
||||
type=raw,value=latest,enable=${{ !contains(github.ref, '-alpha') && !contains(github.ref, '-beta') && !contains(github.ref, '-rc') }}
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: docker/build-push-action@v6
|
||||
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
|
||||
131
CHANGELOG.md
131
CHANGELOG.md
@@ -5,6 +5,137 @@ All notable changes to this project will be documented in this file.
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [1.5.0] - 2026-01-26
|
||||
|
||||
### Added
|
||||
|
||||
#### Column Label Support
|
||||
- **`add_column`**: New optional `label` parameter for setting display name
|
||||
- **`modify_column`**: New optional `label` parameter for updating display name
|
||||
|
||||
Labels are human-readable names shown in Grist column headers, separate from the `column_id` used in formulas and API calls. If not provided, Grist defaults the label to the column ID.
|
||||
|
||||
#### Usage
|
||||
```python
|
||||
# Create column with display label
|
||||
add_column(document="crm", table="Contacts", column_id="first_name", column_type="Text", label="First Name")
|
||||
|
||||
# Update existing column's label
|
||||
modify_column(document="crm", table="Contacts", column_id="first_name", label="Given Name")
|
||||
```
|
||||
|
||||
## [1.4.1] - 2026-01-14
|
||||
|
||||
### Added
|
||||
|
||||
#### Reference Column Filter Support
|
||||
- **Filter normalization**: `get_records` now automatically normalizes filter values to array format
|
||||
- Fixes 400 errors when filtering on `Ref:*` (reference/foreign key) columns
|
||||
- Single values are wrapped in arrays before sending to Grist API
|
||||
|
||||
#### Usage
|
||||
```python
|
||||
# Before: Failed with 400 Bad Request
|
||||
get_records(document="accounting", table="TransactionLines", filter={"Transaction": 44})
|
||||
|
||||
# After: Works - filter normalized to {"Transaction": [44]}
|
||||
get_records(document="accounting", table="TransactionLines", filter={"Transaction": 44})
|
||||
|
||||
# Multiple values also supported
|
||||
get_records(document="accounting", table="TransactionLines", filter={"Transaction": [44, 45, 46]})
|
||||
```
|
||||
|
||||
### Fixed
|
||||
- Shell script shebangs updated to `#!/usr/bin/env bash` for portability across environments
|
||||
|
||||
## [1.4.0] - 2026-01-12
|
||||
|
||||
### Added
|
||||
|
||||
#### Attachment Download via Proxy
|
||||
- **`GET /api/v1/attachments/{id}`**: New HTTP endpoint for downloading attachments
|
||||
- Returns binary content with appropriate `Content-Type` and `Content-Disposition` headers
|
||||
- Requires read permission in session token
|
||||
- Complements the existing upload endpoint for complete attachment workflows
|
||||
|
||||
#### Usage
|
||||
```bash
|
||||
# Get session token with read permission
|
||||
TOKEN=$(curl -s ... | jq -r '.token')
|
||||
|
||||
# Download attachment
|
||||
curl -H "Authorization: Bearer $TOKEN" \
|
||||
https://example.com/api/v1/attachments/42 \
|
||||
-o downloaded.pdf
|
||||
```
|
||||
|
||||
```python
|
||||
# Python example
|
||||
import requests
|
||||
|
||||
response = requests.get(
|
||||
f'{base_url}/api/v1/attachments/42',
|
||||
headers={'Authorization': f'Bearer {token}'}
|
||||
)
|
||||
with open('downloaded.pdf', 'wb') as f:
|
||||
f.write(response.content)
|
||||
```
|
||||
|
||||
## [1.3.0] - 2026-01-03
|
||||
|
||||
### Added
|
||||
|
||||
#### Attachment Upload via Proxy
|
||||
- **`POST /api/v1/attachments`**: New HTTP endpoint for file uploads
|
||||
- Uses `multipart/form-data` for efficient binary transfer (no base64 overhead)
|
||||
- Automatic MIME type detection from filename
|
||||
- Returns attachment ID for linking to records via `update_records`
|
||||
- Requires write permission in session token
|
||||
|
||||
#### Usage
|
||||
```bash
|
||||
# Get session token with write permission
|
||||
TOKEN=$(curl -s ... | jq -r '.token')
|
||||
|
||||
# Upload file
|
||||
curl -X POST \
|
||||
-H "Authorization: Bearer $TOKEN" \
|
||||
-F "file=@invoice.pdf" \
|
||||
https://example.com/api/v1/attachments
|
||||
|
||||
# Returns: {"success": true, "data": {"attachment_id": 42, "filename": "invoice.pdf", "size_bytes": 31395}}
|
||||
```
|
||||
|
||||
```python
|
||||
# Python example
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
f'{proxy_url.replace("/proxy", "/attachments")}',
|
||||
headers={'Authorization': f'Bearer {token}'},
|
||||
files={'file': open('invoice.pdf', 'rb')}
|
||||
)
|
||||
attachment_id = response.json()['data']['attachment_id']
|
||||
|
||||
# Link to record via proxy
|
||||
requests.post(proxy_url, headers={'Authorization': f'Bearer {token}'}, json={
|
||||
'method': 'update_records',
|
||||
'table': 'Bills',
|
||||
'records': [{'id': 1, 'fields': {'Attachment': [attachment_id]}}]
|
||||
})
|
||||
```
|
||||
|
||||
## [1.2.0] - 2026-01-02
|
||||
|
||||
### Added
|
||||
|
||||
#### Session Token Proxy
|
||||
- **Session token proxy**: Agents can request short-lived tokens for bulk operations
|
||||
- `get_proxy_documentation` MCP tool: returns complete proxy API spec
|
||||
- `request_session_token` MCP tool: creates scoped session tokens with TTL (max 1 hour)
|
||||
- `POST /api/v1/proxy` HTTP endpoint: accepts session tokens for direct API access
|
||||
- Supports all 11 Grist operations (read, write, schema) via HTTP
|
||||
|
||||
## [1.1.0] - 2026-01-02
|
||||
|
||||
### Added
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Stage 1: Builder
|
||||
FROM python:3.14-slim AS builder
|
||||
FROM python:3.14-slim@sha256:fb83750094b46fd6b8adaa80f66e2302ecbe45d513f6cece637a841e1025b4ca AS builder
|
||||
|
||||
# Install uv
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /usr/local/bin/uv
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest@sha256:90bbb3c16635e9627f49eec6539f956d70746c409209041800a0280b93152823 /uv /usr/local/bin/uv
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
@@ -20,7 +20,7 @@ RUN uv sync --frozen --no-dev
|
||||
|
||||
|
||||
# Stage 2: Runtime
|
||||
FROM python:3.14-slim
|
||||
FROM python:3.14-slim@sha256:fb83750094b46fd6b8adaa80f66e2302ecbe45d513f6cece637a841e1025b4ca
|
||||
|
||||
# Create non-root user
|
||||
RUN useradd --create-home --shell /bin/bash appuser
|
||||
|
||||
@@ -150,6 +150,7 @@ Add to your MCP client configuration (e.g., Claude Desktop):
|
||||
| `GRIST_MCP_TOKEN` | Agent authentication token (required) | - |
|
||||
| `CONFIG_PATH` | Path to config file inside container | `/app/config.yaml` |
|
||||
| `LOG_LEVEL` | Logging verbosity (`DEBUG`, `INFO`, `WARNING`, `ERROR`) | `INFO` |
|
||||
| `GRIST_MCP_URL` | Public URL of this server (for session proxy tokens) | - |
|
||||
|
||||
### config.yaml Structure
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Production environment
|
||||
services:
|
||||
grist-mcp:
|
||||
image: ghcr.io/xe138/grist-mcp-server:latest
|
||||
image: ghcr.io/xe138/grist-mcp-server:latest@sha256:2ef22bfac6cfbcbbfc513f61eaea3414b3a531d79e9d1d39bf6757cc9e27ea9a
|
||||
ports:
|
||||
- "${PORT:-3000}:3000"
|
||||
volumes:
|
||||
|
||||
1471
docs/plans/2026-01-02-session-proxy-impl.md
Normal file
1471
docs/plans/2026-01-02-session-proxy-impl.md
Normal file
File diff suppressed because it is too large
Load Diff
187
docs/plans/2026-01-03-attachment-upload-design.md
Normal file
187
docs/plans/2026-01-03-attachment-upload-design.md
Normal file
@@ -0,0 +1,187 @@
|
||||
# Attachment Upload Feature Design
|
||||
|
||||
**Date:** 2026-01-03
|
||||
**Status:** Approved
|
||||
|
||||
## Summary
|
||||
|
||||
Add an `upload_attachment` MCP tool to upload files to Grist documents and receive an attachment ID for linking to records.
|
||||
|
||||
## Design Decisions
|
||||
|
||||
| Decision | Choice | Rationale |
|
||||
|----------|--------|-----------|
|
||||
| Content encoding | Base64 string | MCP tools use JSON; binary must be encoded |
|
||||
| Batch support | Single file only | YAGNI; caller can loop if needed |
|
||||
| Linking behavior | Upload only, return ID | Single responsibility; use existing `update_records` to link |
|
||||
| Download support | Not included | YAGNI; can add later if needed |
|
||||
| Permission level | Write | Attachments are data, not schema |
|
||||
| Proxy support | MCP tool only | Reduces scope; scripts can use Grist API directly |
|
||||
|
||||
## Tool Interface
|
||||
|
||||
### Input Schema
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"document": {
|
||||
"type": "string",
|
||||
"description": "Document name"
|
||||
},
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "Filename with extension (e.g., 'invoice.pdf')"
|
||||
},
|
||||
"content_base64": {
|
||||
"type": "string",
|
||||
"description": "File content as base64-encoded string"
|
||||
},
|
||||
"content_type": {
|
||||
"type": "string",
|
||||
"description": "MIME type (optional, auto-detected from filename if omitted)"
|
||||
}
|
||||
},
|
||||
"required": ["document", "filename", "content_base64"]
|
||||
}
|
||||
```
|
||||
|
||||
### Response
|
||||
|
||||
```json
|
||||
{
|
||||
"attachment_id": 42,
|
||||
"filename": "invoice.pdf",
|
||||
"size_bytes": 30720
|
||||
}
|
||||
```
|
||||
|
||||
### Usage Example
|
||||
|
||||
```python
|
||||
# 1. Upload attachment
|
||||
result = upload_attachment(
|
||||
document="accounting",
|
||||
filename="Invoice-001.pdf",
|
||||
content_base64="JVBERi0xLjQK..."
|
||||
)
|
||||
|
||||
# 2. Link to record via existing update_records tool
|
||||
update_records("Bills", [{
|
||||
"id": 1,
|
||||
"fields": {"Attachment": [result["attachment_id"]]}
|
||||
}])
|
||||
```
|
||||
|
||||
## Implementation
|
||||
|
||||
### Files to Modify
|
||||
|
||||
1. **`src/grist_mcp/grist_client.py`** - Add `upload_attachment()` method
|
||||
2. **`src/grist_mcp/tools/write.py`** - Add tool function
|
||||
3. **`src/grist_mcp/server.py`** - Register tool
|
||||
|
||||
### GristClient Method
|
||||
|
||||
```python
|
||||
async def upload_attachment(
|
||||
self,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
content_type: str | None = None
|
||||
) -> dict:
|
||||
"""Upload a file attachment. Returns attachment metadata."""
|
||||
if content_type is None:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
files = {"upload": (filename, content, content_type)}
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self._base_url}/attachments",
|
||||
headers=self._headers,
|
||||
files=files,
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Grist returns list of attachment IDs
|
||||
attachment_ids = response.json()
|
||||
return {
|
||||
"attachment_id": attachment_ids[0],
|
||||
"filename": filename,
|
||||
"size_bytes": len(content),
|
||||
}
|
||||
```
|
||||
|
||||
### Tool Function
|
||||
|
||||
```python
|
||||
import base64
|
||||
import mimetypes
|
||||
|
||||
async def upload_attachment(
|
||||
agent: Agent,
|
||||
auth: Authenticator,
|
||||
document: str,
|
||||
filename: str,
|
||||
content_base64: str,
|
||||
content_type: str | None = None,
|
||||
client: GristClient | None = None,
|
||||
) -> dict:
|
||||
"""Upload a file attachment to a document."""
|
||||
auth.authorize(agent, document, Permission.WRITE)
|
||||
|
||||
# Decode base64
|
||||
try:
|
||||
content = base64.b64decode(content_base64)
|
||||
except Exception:
|
||||
raise ValueError("Invalid base64 encoding")
|
||||
|
||||
# Auto-detect MIME type if not provided
|
||||
if content_type is None:
|
||||
content_type, _ = mimetypes.guess_type(filename)
|
||||
if content_type is None:
|
||||
content_type = "application/octet-stream"
|
||||
|
||||
if client is None:
|
||||
doc = auth.get_document(document)
|
||||
client = GristClient(doc)
|
||||
|
||||
return await client.upload_attachment(filename, content, content_type)
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
| Error | Cause | Response |
|
||||
|-------|-------|----------|
|
||||
| Invalid base64 | Malformed content_base64 | `ValueError: Invalid base64 encoding` |
|
||||
| Authorization | Agent lacks write permission | `AuthError` (existing pattern) |
|
||||
| Grist API error | Upload fails | `httpx.HTTPStatusError` (existing pattern) |
|
||||
|
||||
## Testing
|
||||
|
||||
### Unit Tests
|
||||
|
||||
**`tests/unit/test_tools_write.py`:**
|
||||
- `test_upload_attachment_success` - Valid base64, returns attachment_id
|
||||
- `test_upload_attachment_invalid_base64` - Raises ValueError
|
||||
- `test_upload_attachment_auth_required` - Verifies write permission check
|
||||
- `test_upload_attachment_mime_detection` - Auto-detects type from filename
|
||||
|
||||
**`tests/unit/test_grist_client.py`:**
|
||||
- `test_upload_attachment_api_call` - Correct multipart request format
|
||||
- `test_upload_attachment_with_explicit_content_type` - Passes through MIME type
|
||||
|
||||
### Mock Approach
|
||||
|
||||
Mock `httpx.AsyncClient` responses; no Grist server needed for unit tests.
|
||||
|
||||
## Future Considerations
|
||||
|
||||
Not included in this implementation (YAGNI):
|
||||
- Batch upload (multiple files)
|
||||
- Download attachment
|
||||
- Proxy API support
|
||||
- Size limit validation (rely on Grist's limits)
|
||||
|
||||
These can be added if real use cases emerge.
|
||||
@@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "grist-mcp"
|
||||
version = "1.1.0"
|
||||
version = "1.4.1"
|
||||
description = "MCP server for AI agents to interact with Grist documents"
|
||||
requires-python = ">=3.14"
|
||||
dependencies = [
|
||||
@@ -28,3 +28,6 @@ build-backend = "hatchling.build"
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests/unit", "tests/integration"]
|
||||
markers = [
|
||||
"integration: marks tests as integration tests (require Docker containers)",
|
||||
]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
#!/usr/bin/env bash
|
||||
# scripts/get-test-instance-id.sh
|
||||
# Generate a unique instance ID from git branch for parallel test isolation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/bin/bash
|
||||
#!/usr/bin/env bash
|
||||
# scripts/run-integration-tests.sh
|
||||
# Run integration tests with branch isolation and dynamic port discovery
|
||||
set -e
|
||||
|
||||
@@ -116,6 +116,71 @@ class GristClient:
|
||||
"""Delete records by ID."""
|
||||
await self._request("POST", f"/tables/{table}/data/delete", json=record_ids)
|
||||
|
||||
async def upload_attachment(
|
||||
self,
|
||||
filename: str,
|
||||
content: bytes,
|
||||
content_type: str = "application/octet-stream",
|
||||
) -> dict:
|
||||
"""Upload a file attachment. Returns attachment metadata.
|
||||
|
||||
Args:
|
||||
filename: Name for the uploaded file.
|
||||
content: File content as bytes.
|
||||
content_type: MIME type of the file.
|
||||
|
||||
Returns:
|
||||
Dict with attachment_id, filename, and size_bytes.
|
||||
"""
|
||||
files = {"upload": (filename, content, content_type)}
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self._base_url}/attachments",
|
||||
headers=self._headers,
|
||||
files=files,
|
||||
)
|
||||
response.raise_for_status()
|
||||
# Grist returns list of attachment IDs
|
||||
attachment_ids = response.json()
|
||||
return {
|
||||
"attachment_id": attachment_ids[0],
|
||||
"filename": filename,
|
||||
"size_bytes": len(content),
|
||||
}
|
||||
|
||||
async def download_attachment(self, attachment_id: int) -> dict:
|
||||
"""Download an attachment by ID.
|
||||
|
||||
Args:
|
||||
attachment_id: The ID of the attachment to download.
|
||||
|
||||
Returns:
|
||||
Dict with content (bytes), content_type, and filename.
|
||||
"""
|
||||
import re
|
||||
|
||||
async with httpx.AsyncClient(timeout=self._timeout) as client:
|
||||
response = await client.get(
|
||||
f"{self._base_url}/attachments/{attachment_id}/download",
|
||||
headers=self._headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Extract filename from Content-Disposition header
|
||||
content_disp = response.headers.get("content-disposition", "")
|
||||
filename = None
|
||||
if "filename=" in content_disp:
|
||||
match = re.search(r'filename="?([^";]+)"?', content_disp)
|
||||
if match:
|
||||
filename = match.group(1)
|
||||
|
||||
return {
|
||||
"content": response.content,
|
||||
"content_type": response.headers.get("content-type", "application/octet-stream"),
|
||||
"filename": filename,
|
||||
}
|
||||
|
||||
# Schema operations
|
||||
|
||||
async def create_table(self, table_id: str, columns: list[dict]) -> str:
|
||||
@@ -138,11 +203,14 @@ class GristClient:
|
||||
column_id: str,
|
||||
column_type: str,
|
||||
formula: str | None = None,
|
||||
label: str | None = None,
|
||||
) -> str:
|
||||
"""Add a column to a table. Returns column ID."""
|
||||
fields = {"type": column_type}
|
||||
if formula:
|
||||
fields["formula"] = formula
|
||||
if label:
|
||||
fields["label"] = label
|
||||
|
||||
payload = {"columns": [{"id": column_id, "fields": fields}]}
|
||||
data = await self._request("POST", f"/tables/{table}/columns", json=payload)
|
||||
@@ -154,13 +222,16 @@ class GristClient:
|
||||
column_id: str,
|
||||
type: str | None = None,
|
||||
formula: str | None = None,
|
||||
label: str | None = None,
|
||||
) -> None:
|
||||
"""Modify a column's type or formula."""
|
||||
"""Modify a column's type, formula, or label."""
|
||||
fields = {}
|
||||
if type is not None:
|
||||
fields["type"] = type
|
||||
if formula is not None:
|
||||
fields["formula"] = formula
|
||||
if label is not None:
|
||||
fields["label"] = label
|
||||
|
||||
payload = {"columns": [{"id": column_id, "fields": fields}]}
|
||||
await self._request("PATCH", f"/tables/{table}/columns", json=payload)
|
||||
|
||||
@@ -12,6 +12,9 @@ from mcp.server.sse import SseServerTransport
|
||||
from grist_mcp.server import create_server
|
||||
from grist_mcp.config import Config, load_config
|
||||
from grist_mcp.auth import Authenticator, AuthError
|
||||
from grist_mcp.session import SessionTokenManager
|
||||
from grist_mcp.proxy import parse_proxy_request, dispatch_proxy_request, ProxyError
|
||||
from grist_mcp.grist_client import GristClient
|
||||
from grist_mcp.logging import setup_logging
|
||||
|
||||
|
||||
@@ -43,6 +46,76 @@ async def send_error(send: Send, status: int, message: str) -> None:
|
||||
})
|
||||
|
||||
|
||||
async def send_json_response(send: Send, status: int, data: dict) -> None:
|
||||
"""Send a JSON response."""
|
||||
body = json.dumps(data).encode()
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": status,
|
||||
"headers": [[b"content-type", b"application/json"]],
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": body,
|
||||
})
|
||||
|
||||
|
||||
def _parse_multipart(content_type: str, body: bytes) -> tuple[str | None, bytes | None]:
|
||||
"""Parse multipart/form-data to extract uploaded file.
|
||||
|
||||
Returns (filename, content) or (None, None) if parsing fails.
|
||||
"""
|
||||
import re
|
||||
|
||||
# Extract boundary from content-type
|
||||
match = re.search(r'boundary=([^\s;]+)', content_type)
|
||||
if not match:
|
||||
return None, None
|
||||
|
||||
boundary = match.group(1).encode()
|
||||
if boundary.startswith(b'"') and boundary.endswith(b'"'):
|
||||
boundary = boundary[1:-1]
|
||||
|
||||
# Split by boundary
|
||||
parts = body.split(b'--' + boundary)
|
||||
|
||||
for part in parts:
|
||||
if b'Content-Disposition' not in part:
|
||||
continue
|
||||
|
||||
# Split headers from content
|
||||
if b'\r\n\r\n' in part:
|
||||
header_section, content = part.split(b'\r\n\r\n', 1)
|
||||
elif b'\n\n' in part:
|
||||
header_section, content = part.split(b'\n\n', 1)
|
||||
else:
|
||||
continue
|
||||
|
||||
headers = header_section.decode('utf-8', errors='replace')
|
||||
|
||||
# Check if this is a file upload
|
||||
if 'filename=' not in headers:
|
||||
continue
|
||||
|
||||
# Extract filename
|
||||
filename_match = re.search(r'filename="([^"]+)"', headers)
|
||||
if not filename_match:
|
||||
filename_match = re.search(r"filename=([^\s;]+)", headers)
|
||||
if not filename_match:
|
||||
continue
|
||||
|
||||
filename = filename_match.group(1)
|
||||
|
||||
# Remove trailing boundary marker and whitespace
|
||||
content = content.rstrip()
|
||||
if content.endswith(b'--'):
|
||||
content = content[:-2].rstrip()
|
||||
|
||||
return filename, content
|
||||
|
||||
return None, None
|
||||
|
||||
|
||||
CONFIG_TEMPLATE = """\
|
||||
# grist-mcp configuration
|
||||
#
|
||||
@@ -110,6 +183,8 @@ def _ensure_config(config_path: str) -> bool:
|
||||
def create_app(config: Config):
|
||||
"""Create the ASGI application."""
|
||||
auth = Authenticator(config)
|
||||
token_manager = SessionTokenManager()
|
||||
proxy_base_url = os.environ.get("GRIST_MCP_URL")
|
||||
|
||||
sse = SseServerTransport("/messages")
|
||||
|
||||
@@ -127,7 +202,7 @@ def create_app(config: Config):
|
||||
return
|
||||
|
||||
# Create a server instance for this authenticated connection
|
||||
server = create_server(auth, agent)
|
||||
server = create_server(auth, agent, token_manager, proxy_base_url)
|
||||
|
||||
async with sse.connect_sse(scope, receive, send) as streams:
|
||||
await server.run(
|
||||
@@ -159,6 +234,196 @@ def create_app(config: Config):
|
||||
"body": b'{"error":"Not found"}',
|
||||
})
|
||||
|
||||
async def handle_proxy(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
# Extract token
|
||||
token = _get_bearer_token(scope)
|
||||
if not token:
|
||||
await send_json_response(send, 401, {
|
||||
"success": False,
|
||||
"error": "Missing Authorization header",
|
||||
"code": "INVALID_TOKEN",
|
||||
})
|
||||
return
|
||||
|
||||
# Validate session token
|
||||
session = token_manager.validate_token(token)
|
||||
if session is None:
|
||||
await send_json_response(send, 401, {
|
||||
"success": False,
|
||||
"error": "Invalid or expired token",
|
||||
"code": "TOKEN_EXPIRED",
|
||||
})
|
||||
return
|
||||
|
||||
# Read request body
|
||||
body = b""
|
||||
while True:
|
||||
message = await receive()
|
||||
body += message.get("body", b"")
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
|
||||
try:
|
||||
request_data = json.loads(body)
|
||||
except json.JSONDecodeError:
|
||||
await send_json_response(send, 400, {
|
||||
"success": False,
|
||||
"error": "Invalid JSON",
|
||||
"code": "INVALID_REQUEST",
|
||||
})
|
||||
return
|
||||
|
||||
# Parse and dispatch
|
||||
try:
|
||||
request = parse_proxy_request(request_data)
|
||||
result = await dispatch_proxy_request(request, session, auth)
|
||||
await send_json_response(send, 200, result)
|
||||
except ProxyError as e:
|
||||
status = 403 if e.code == "UNAUTHORIZED" else 400
|
||||
await send_json_response(send, status, {
|
||||
"success": False,
|
||||
"error": e.message,
|
||||
"code": e.code,
|
||||
})
|
||||
|
||||
async def handle_attachments(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
"""Handle file attachment uploads via multipart/form-data."""
|
||||
# Extract token
|
||||
token = _get_bearer_token(scope)
|
||||
if not token:
|
||||
await send_json_response(send, 401, {
|
||||
"success": False,
|
||||
"error": "Missing Authorization header",
|
||||
"code": "INVALID_TOKEN",
|
||||
})
|
||||
return
|
||||
|
||||
# Validate session token
|
||||
session = token_manager.validate_token(token)
|
||||
if session is None:
|
||||
await send_json_response(send, 401, {
|
||||
"success": False,
|
||||
"error": "Invalid or expired token",
|
||||
"code": "TOKEN_EXPIRED",
|
||||
})
|
||||
return
|
||||
|
||||
# Check write permission
|
||||
if "write" not in session.permissions:
|
||||
await send_json_response(send, 403, {
|
||||
"success": False,
|
||||
"error": "Write permission required for attachment upload",
|
||||
"code": "UNAUTHORIZED",
|
||||
})
|
||||
return
|
||||
|
||||
# Get content-type header
|
||||
headers = dict(scope.get("headers", []))
|
||||
content_type = headers.get(b"content-type", b"").decode()
|
||||
|
||||
if not content_type.startswith("multipart/form-data"):
|
||||
await send_json_response(send, 400, {
|
||||
"success": False,
|
||||
"error": "Content-Type must be multipart/form-data",
|
||||
"code": "INVALID_REQUEST",
|
||||
})
|
||||
return
|
||||
|
||||
# Read request body
|
||||
body = b""
|
||||
while True:
|
||||
message = await receive()
|
||||
body += message.get("body", b"")
|
||||
if not message.get("more_body", False):
|
||||
break
|
||||
|
||||
# Parse multipart
|
||||
filename, content = _parse_multipart(content_type, body)
|
||||
if filename is None or content is None:
|
||||
await send_json_response(send, 400, {
|
||||
"success": False,
|
||||
"error": "No file found in request",
|
||||
"code": "INVALID_REQUEST",
|
||||
})
|
||||
return
|
||||
|
||||
# Upload to Grist
|
||||
try:
|
||||
doc = auth.get_document(session.document)
|
||||
client = GristClient(doc)
|
||||
result = await client.upload_attachment(filename, content)
|
||||
await send_json_response(send, 200, {
|
||||
"success": True,
|
||||
"data": result,
|
||||
})
|
||||
except Exception as e:
|
||||
await send_json_response(send, 500, {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": "GRIST_ERROR",
|
||||
})
|
||||
|
||||
async def handle_attachment_download(
|
||||
scope: Scope, receive: Receive, send: Send, attachment_id: int
|
||||
) -> None:
|
||||
"""Handle attachment download by ID."""
|
||||
# Extract token
|
||||
token = _get_bearer_token(scope)
|
||||
if not token:
|
||||
await send_json_response(send, 401, {
|
||||
"success": False,
|
||||
"error": "Missing Authorization header",
|
||||
"code": "INVALID_TOKEN",
|
||||
})
|
||||
return
|
||||
|
||||
# Validate session token
|
||||
session = token_manager.validate_token(token)
|
||||
if session is None:
|
||||
await send_json_response(send, 401, {
|
||||
"success": False,
|
||||
"error": "Invalid or expired token",
|
||||
"code": "TOKEN_EXPIRED",
|
||||
})
|
||||
return
|
||||
|
||||
# Check read permission
|
||||
if "read" not in session.permissions:
|
||||
await send_json_response(send, 403, {
|
||||
"success": False,
|
||||
"error": "Read permission required for attachment download",
|
||||
"code": "UNAUTHORIZED",
|
||||
})
|
||||
return
|
||||
|
||||
# Download from Grist
|
||||
try:
|
||||
doc = auth.get_document(session.document)
|
||||
client = GristClient(doc)
|
||||
result = await client.download_attachment(attachment_id)
|
||||
|
||||
# Build response headers
|
||||
headers = [[b"content-type", result["content_type"].encode()]]
|
||||
if result["filename"]:
|
||||
disposition = f'attachment; filename="{result["filename"]}"'
|
||||
headers.append([b"content-disposition", disposition.encode()])
|
||||
|
||||
await send({
|
||||
"type": "http.response.start",
|
||||
"status": 200,
|
||||
"headers": headers,
|
||||
})
|
||||
await send({
|
||||
"type": "http.response.body",
|
||||
"body": result["content"],
|
||||
})
|
||||
except Exception as e:
|
||||
await send_json_response(send, 500, {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"code": "GRIST_ERROR",
|
||||
})
|
||||
|
||||
async def app(scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if scope["type"] != "http":
|
||||
return
|
||||
@@ -172,6 +437,21 @@ def create_app(config: Config):
|
||||
await handle_sse(scope, receive, send)
|
||||
elif path == "/messages" and method == "POST":
|
||||
await handle_messages(scope, receive, send)
|
||||
elif path == "/api/v1/proxy" and method == "POST":
|
||||
await handle_proxy(scope, receive, send)
|
||||
elif path == "/api/v1/attachments" and method == "POST":
|
||||
await handle_attachments(scope, receive, send)
|
||||
elif path.startswith("/api/v1/attachments/") and method == "GET":
|
||||
# Parse attachment ID from path: /api/v1/attachments/{id}
|
||||
try:
|
||||
attachment_id = int(path.split("/")[-1])
|
||||
await handle_attachment_download(scope, receive, send, attachment_id)
|
||||
except ValueError:
|
||||
await send_json_response(send, 400, {
|
||||
"success": False,
|
||||
"error": "Invalid attachment ID",
|
||||
"code": "INVALID_REQUEST",
|
||||
})
|
||||
else:
|
||||
await handle_not_found(scope, receive, send)
|
||||
|
||||
@@ -180,11 +460,18 @@ def create_app(config: Config):
|
||||
|
||||
def _print_mcp_config(external_port: int, tokens: list) -> None:
|
||||
"""Print Claude Code MCP configuration."""
|
||||
# Use GRIST_MCP_URL if set, otherwise fall back to localhost
|
||||
base_url = os.environ.get("GRIST_MCP_URL")
|
||||
if base_url:
|
||||
sse_url = f"{base_url.rstrip('/')}/sse"
|
||||
else:
|
||||
sse_url = f"http://localhost:{external_port}/sse"
|
||||
|
||||
print()
|
||||
print("Claude Code MCP configuration (copy-paste to add):")
|
||||
for t in tokens:
|
||||
config = (
|
||||
f'{{"type": "sse", "url": "http://localhost:{external_port}/sse", '
|
||||
f'{{"type": "sse", "url": "{sse_url}", '
|
||||
f'"headers": {{"Authorization": "Bearer {t.token}"}}}}'
|
||||
)
|
||||
print(f" claude mcp add-json grist-{t.name} '{config}'")
|
||||
|
||||
192
src/grist_mcp/proxy.py
Normal file
192
src/grist_mcp/proxy.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""HTTP proxy handler for session token access."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from grist_mcp.auth import Authenticator
|
||||
from grist_mcp.grist_client import GristClient
|
||||
from grist_mcp.session import SessionToken
|
||||
|
||||
|
||||
class ProxyError(Exception):
|
||||
"""Error during proxy request processing."""
|
||||
|
||||
def __init__(self, message: str, code: str):
|
||||
self.message = message
|
||||
self.code = code
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProxyRequest:
|
||||
"""Parsed proxy request."""
|
||||
method: str
|
||||
table: str | None = None
|
||||
records: list[dict] | None = None
|
||||
record_ids: list[int] | None = None
|
||||
filter: dict | None = None
|
||||
sort: str | None = None
|
||||
limit: int | None = None
|
||||
query: str | None = None
|
||||
table_id: str | None = None
|
||||
columns: list[dict] | None = None
|
||||
column_id: str | None = None
|
||||
column_type: str | None = None
|
||||
formula: str | None = None
|
||||
type: str | None = None
|
||||
|
||||
|
||||
METHODS_REQUIRING_TABLE = {
|
||||
"get_records", "describe_table", "add_records", "update_records",
|
||||
"delete_records", "add_column", "modify_column", "delete_column",
|
||||
}
|
||||
|
||||
|
||||
def parse_proxy_request(body: dict[str, Any]) -> ProxyRequest:
|
||||
"""Parse and validate a proxy request body."""
|
||||
if "method" not in body:
|
||||
raise ProxyError("Missing required field: method", "INVALID_REQUEST")
|
||||
|
||||
method = body["method"]
|
||||
|
||||
if method in METHODS_REQUIRING_TABLE and "table" not in body:
|
||||
raise ProxyError(f"Missing required field 'table' for method '{method}'", "INVALID_REQUEST")
|
||||
|
||||
return ProxyRequest(
|
||||
method=method,
|
||||
table=body.get("table"),
|
||||
records=body.get("records"),
|
||||
record_ids=body.get("record_ids"),
|
||||
filter=body.get("filter"),
|
||||
sort=body.get("sort"),
|
||||
limit=body.get("limit"),
|
||||
query=body.get("query"),
|
||||
table_id=body.get("table_id"),
|
||||
columns=body.get("columns"),
|
||||
column_id=body.get("column_id"),
|
||||
column_type=body.get("column_type"),
|
||||
formula=body.get("formula"),
|
||||
type=body.get("type"),
|
||||
)
|
||||
|
||||
|
||||
# Map methods to required permissions
|
||||
METHOD_PERMISSIONS = {
|
||||
"list_tables": "read",
|
||||
"describe_table": "read",
|
||||
"get_records": "read",
|
||||
"sql_query": "read",
|
||||
"add_records": "write",
|
||||
"update_records": "write",
|
||||
"delete_records": "write",
|
||||
"create_table": "schema",
|
||||
"add_column": "schema",
|
||||
"modify_column": "schema",
|
||||
"delete_column": "schema",
|
||||
}
|
||||
|
||||
|
||||
async def dispatch_proxy_request(
|
||||
request: ProxyRequest,
|
||||
session: SessionToken,
|
||||
auth: Authenticator,
|
||||
client: GristClient | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Dispatch a proxy request to the appropriate handler."""
|
||||
# Check permission
|
||||
required_perm = METHOD_PERMISSIONS.get(request.method)
|
||||
if required_perm is None:
|
||||
raise ProxyError(f"Unknown method: {request.method}", "INVALID_REQUEST")
|
||||
|
||||
if required_perm not in session.permissions:
|
||||
raise ProxyError(
|
||||
f"Permission '{required_perm}' required for {request.method}",
|
||||
"UNAUTHORIZED",
|
||||
)
|
||||
|
||||
# Create client if not provided
|
||||
if client is None:
|
||||
doc = auth.get_document(session.document)
|
||||
client = GristClient(doc)
|
||||
|
||||
# Dispatch to appropriate method
|
||||
try:
|
||||
if request.method == "list_tables":
|
||||
data = await client.list_tables()
|
||||
return {"success": True, "data": {"tables": data}}
|
||||
|
||||
elif request.method == "describe_table":
|
||||
data = await client.describe_table(request.table)
|
||||
return {"success": True, "data": {"table": request.table, "columns": data}}
|
||||
|
||||
elif request.method == "get_records":
|
||||
data = await client.get_records(
|
||||
request.table,
|
||||
filter=request.filter,
|
||||
sort=request.sort,
|
||||
limit=request.limit,
|
||||
)
|
||||
return {"success": True, "data": {"records": data}}
|
||||
|
||||
elif request.method == "sql_query":
|
||||
if request.query is None:
|
||||
raise ProxyError("Missing required field: query", "INVALID_REQUEST")
|
||||
data = await client.sql_query(request.query)
|
||||
return {"success": True, "data": {"records": data}}
|
||||
|
||||
elif request.method == "add_records":
|
||||
if request.records is None:
|
||||
raise ProxyError("Missing required field: records", "INVALID_REQUEST")
|
||||
data = await client.add_records(request.table, request.records)
|
||||
return {"success": True, "data": {"record_ids": data}}
|
||||
|
||||
elif request.method == "update_records":
|
||||
if request.records is None:
|
||||
raise ProxyError("Missing required field: records", "INVALID_REQUEST")
|
||||
await client.update_records(request.table, request.records)
|
||||
return {"success": True, "data": {"updated": len(request.records)}}
|
||||
|
||||
elif request.method == "delete_records":
|
||||
if request.record_ids is None:
|
||||
raise ProxyError("Missing required field: record_ids", "INVALID_REQUEST")
|
||||
await client.delete_records(request.table, request.record_ids)
|
||||
return {"success": True, "data": {"deleted": len(request.record_ids)}}
|
||||
|
||||
elif request.method == "create_table":
|
||||
if request.table_id is None or request.columns is None:
|
||||
raise ProxyError("Missing required fields: table_id, columns", "INVALID_REQUEST")
|
||||
data = await client.create_table(request.table_id, request.columns)
|
||||
return {"success": True, "data": {"table_id": data}}
|
||||
|
||||
elif request.method == "add_column":
|
||||
if request.column_id is None or request.column_type is None:
|
||||
raise ProxyError("Missing required fields: column_id, column_type", "INVALID_REQUEST")
|
||||
await client.add_column(
|
||||
request.table, request.column_id, request.column_type,
|
||||
formula=request.formula,
|
||||
)
|
||||
return {"success": True, "data": {"column_id": request.column_id}}
|
||||
|
||||
elif request.method == "modify_column":
|
||||
if request.column_id is None:
|
||||
raise ProxyError("Missing required field: column_id", "INVALID_REQUEST")
|
||||
await client.modify_column(
|
||||
request.table, request.column_id,
|
||||
type=request.type,
|
||||
formula=request.formula,
|
||||
)
|
||||
return {"success": True, "data": {"column_id": request.column_id}}
|
||||
|
||||
elif request.method == "delete_column":
|
||||
if request.column_id is None:
|
||||
raise ProxyError("Missing required field: column_id", "INVALID_REQUEST")
|
||||
await client.delete_column(request.table, request.column_id)
|
||||
return {"success": True, "data": {"deleted": request.column_id}}
|
||||
|
||||
else:
|
||||
raise ProxyError(f"Unknown method: {request.method}", "INVALID_REQUEST")
|
||||
|
||||
except ProxyError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ProxyError(str(e), "GRIST_ERROR")
|
||||
@@ -7,6 +7,9 @@ from mcp.server import Server
|
||||
from mcp.types import Tool, TextContent
|
||||
|
||||
from grist_mcp.auth import Authenticator, Agent, AuthError
|
||||
from grist_mcp.session import SessionTokenManager
|
||||
from grist_mcp.tools.session import get_proxy_documentation as _get_proxy_documentation
|
||||
from grist_mcp.tools.session import request_session_token as _request_session_token
|
||||
from grist_mcp.logging import get_logger, extract_stats, format_tool_log
|
||||
|
||||
logger = get_logger("server")
|
||||
@@ -25,18 +28,26 @@ from grist_mcp.tools.schema import modify_column as _modify_column
|
||||
from grist_mcp.tools.schema import delete_column as _delete_column
|
||||
|
||||
|
||||
def create_server(auth: Authenticator, agent: Agent) -> Server:
|
||||
def create_server(
|
||||
auth: Authenticator,
|
||||
agent: Agent,
|
||||
token_manager: SessionTokenManager | None = None,
|
||||
proxy_base_url: str | None = None,
|
||||
) -> Server:
|
||||
"""Create and configure the MCP server for an authenticated agent.
|
||||
|
||||
Args:
|
||||
auth: Authenticator instance for permission checks.
|
||||
agent: The authenticated agent for this server instance.
|
||||
token_manager: Optional session token manager for HTTP proxy access.
|
||||
proxy_base_url: Base URL for the proxy endpoint (e.g., "https://example.com").
|
||||
|
||||
Returns:
|
||||
Configured MCP Server instance.
|
||||
"""
|
||||
server = Server("grist-mcp")
|
||||
_current_agent = agent
|
||||
_proxy_base_url = proxy_base_url
|
||||
|
||||
@server.list_tools()
|
||||
async def list_tools() -> list[Tool]:
|
||||
@@ -175,13 +186,14 @@ def create_server(auth: Authenticator, agent: Agent) -> Server:
|
||||
"column_id": {"type": "string"},
|
||||
"column_type": {"type": "string"},
|
||||
"formula": {"type": "string"},
|
||||
"label": {"type": "string", "description": "Display label for the column"},
|
||||
},
|
||||
"required": ["document", "table", "column_id", "column_type"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="modify_column",
|
||||
description="Modify a column's type or formula",
|
||||
description="Modify a column's type, formula, or label",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -190,6 +202,7 @@ def create_server(auth: Authenticator, agent: Agent) -> Server:
|
||||
"column_id": {"type": "string"},
|
||||
"type": {"type": "string"},
|
||||
"formula": {"type": "string"},
|
||||
"label": {"type": "string", "description": "Display label for the column"},
|
||||
},
|
||||
"required": ["document", "table", "column_id"],
|
||||
},
|
||||
@@ -207,6 +220,34 @@ def create_server(auth: Authenticator, agent: Agent) -> Server:
|
||||
"required": ["document", "table", "column_id"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="get_proxy_documentation",
|
||||
description="Get complete documentation for the HTTP proxy API",
|
||||
inputSchema={"type": "object", "properties": {}, "required": []},
|
||||
),
|
||||
Tool(
|
||||
name="request_session_token",
|
||||
description="Request a short-lived token for direct HTTP API access. Use this to delegate bulk data operations to scripts.",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"document": {
|
||||
"type": "string",
|
||||
"description": "Document name to grant access to",
|
||||
},
|
||||
"permissions": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "enum": ["read", "write", "schema"]},
|
||||
"description": "Permission levels to grant",
|
||||
},
|
||||
"ttl_seconds": {
|
||||
"type": "integer",
|
||||
"description": "Token lifetime in seconds (max 3600, default 300)",
|
||||
},
|
||||
},
|
||||
"required": ["document", "permissions"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
@server.call_tool()
|
||||
@@ -272,6 +313,7 @@ def create_server(auth: Authenticator, agent: Agent) -> Server:
|
||||
_current_agent, auth, arguments["document"], arguments["table"],
|
||||
arguments["column_id"], arguments["column_type"],
|
||||
formula=arguments.get("formula"),
|
||||
label=arguments.get("label"),
|
||||
)
|
||||
elif name == "modify_column":
|
||||
result = await _modify_column(
|
||||
@@ -279,12 +321,25 @@ def create_server(auth: Authenticator, agent: Agent) -> Server:
|
||||
arguments["column_id"],
|
||||
type=arguments.get("type"),
|
||||
formula=arguments.get("formula"),
|
||||
label=arguments.get("label"),
|
||||
)
|
||||
elif name == "delete_column":
|
||||
result = await _delete_column(
|
||||
_current_agent, auth, arguments["document"], arguments["table"],
|
||||
arguments["column_id"],
|
||||
)
|
||||
elif name == "get_proxy_documentation":
|
||||
result = await _get_proxy_documentation()
|
||||
elif name == "request_session_token":
|
||||
if token_manager is None:
|
||||
return [TextContent(type="text", text="Session tokens not enabled")]
|
||||
result = await _request_session_token(
|
||||
_current_agent, auth, token_manager,
|
||||
arguments["document"],
|
||||
arguments["permissions"],
|
||||
ttl_seconds=arguments.get("ttl_seconds", 300),
|
||||
proxy_base_url=_proxy_base_url,
|
||||
)
|
||||
else:
|
||||
return [TextContent(type="text", text=f"Unknown tool: {name}")]
|
||||
|
||||
|
||||
73
src/grist_mcp/session.py
Normal file
73
src/grist_mcp/session.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Session token management for HTTP proxy access."""
|
||||
|
||||
import secrets
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
MAX_TTL_SECONDS = 3600 # 1 hour
|
||||
DEFAULT_TTL_SECONDS = 300 # 5 minutes
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionToken:
|
||||
"""A short-lived session token for proxy access."""
|
||||
token: str
|
||||
document: str
|
||||
permissions: list[str]
|
||||
agent_name: str
|
||||
created_at: datetime
|
||||
expires_at: datetime
|
||||
|
||||
|
||||
class SessionTokenManager:
|
||||
"""Manages creation and validation of session tokens."""
|
||||
|
||||
def __init__(self):
|
||||
self._tokens: dict[str, SessionToken] = {}
|
||||
|
||||
def create_token(
|
||||
self,
|
||||
agent_name: str,
|
||||
document: str,
|
||||
permissions: list[str],
|
||||
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||
) -> SessionToken:
|
||||
"""Create a new session token.
|
||||
|
||||
TTL is capped at MAX_TTL_SECONDS (1 hour).
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
token_str = f"sess_{secrets.token_urlsafe(32)}"
|
||||
|
||||
# Cap TTL at maximum
|
||||
effective_ttl = min(ttl_seconds, MAX_TTL_SECONDS)
|
||||
|
||||
session = SessionToken(
|
||||
token=token_str,
|
||||
document=document,
|
||||
permissions=permissions,
|
||||
agent_name=agent_name,
|
||||
created_at=now,
|
||||
expires_at=now + timedelta(seconds=effective_ttl),
|
||||
)
|
||||
|
||||
self._tokens[token_str] = session
|
||||
return session
|
||||
|
||||
def validate_token(self, token: str) -> SessionToken | None:
|
||||
"""Validate a session token.
|
||||
|
||||
Returns the SessionToken if valid and not expired, None otherwise.
|
||||
Also removes expired tokens lazily.
|
||||
"""
|
||||
session = self._tokens.get(token)
|
||||
if session is None:
|
||||
return None
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if session.expires_at < now:
|
||||
# Token expired, remove it
|
||||
del self._tokens[token]
|
||||
return None
|
||||
|
||||
return session
|
||||
37
src/grist_mcp/tools/filters.py
Normal file
37
src/grist_mcp/tools/filters.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Filter normalization for Grist API queries."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
def normalize_filter_value(value: Any) -> list:
|
||||
"""Ensure a filter value is a list.
|
||||
|
||||
Grist API expects filter values to be arrays.
|
||||
|
||||
Args:
|
||||
value: Single value or list of values.
|
||||
|
||||
Returns:
|
||||
Value wrapped in list, or original list if already a list.
|
||||
"""
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
return [value]
|
||||
|
||||
|
||||
def normalize_filter(filter: dict | None) -> dict | None:
|
||||
"""Normalize filter values to array format for Grist API.
|
||||
|
||||
Grist expects all filter values to be arrays. This function
|
||||
wraps single values in lists.
|
||||
|
||||
Args:
|
||||
filter: Filter dict with column names as keys.
|
||||
|
||||
Returns:
|
||||
Normalized filter dict, or None if input was None.
|
||||
"""
|
||||
if not filter:
|
||||
return filter
|
||||
|
||||
return {key: normalize_filter_value(value) for key, value in filter.items()}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from grist_mcp.auth import Agent, Authenticator, Permission
|
||||
from grist_mcp.grist_client import GristClient
|
||||
from grist_mcp.tools.filters import normalize_filter
|
||||
|
||||
|
||||
async def list_tables(
|
||||
@@ -56,7 +57,10 @@ async def get_records(
|
||||
doc = auth.get_document(document)
|
||||
client = GristClient(doc)
|
||||
|
||||
records = await client.get_records(table, filter=filter, sort=sort, limit=limit)
|
||||
# Normalize filter values to array format for Grist API
|
||||
normalized_filter = normalize_filter(filter)
|
||||
|
||||
records = await client.get_records(table, filter=normalized_filter, sort=sort, limit=limit)
|
||||
return {"records": records}
|
||||
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ async def add_column(
|
||||
column_id: str,
|
||||
column_type: str,
|
||||
formula: str | None = None,
|
||||
label: str | None = None,
|
||||
client: GristClient | None = None,
|
||||
) -> dict:
|
||||
"""Add a column to a table."""
|
||||
@@ -40,7 +41,9 @@ async def add_column(
|
||||
doc = auth.get_document(document)
|
||||
client = GristClient(doc)
|
||||
|
||||
created_id = await client.add_column(table, column_id, column_type, formula=formula)
|
||||
created_id = await client.add_column(
|
||||
table, column_id, column_type, formula=formula, label=label
|
||||
)
|
||||
return {"column_id": created_id}
|
||||
|
||||
|
||||
@@ -52,16 +55,17 @@ async def modify_column(
|
||||
column_id: str,
|
||||
type: str | None = None,
|
||||
formula: str | None = None,
|
||||
label: str | None = None,
|
||||
client: GristClient | None = None,
|
||||
) -> dict:
|
||||
"""Modify a column's type or formula."""
|
||||
"""Modify a column's type, formula, or label."""
|
||||
auth.authorize(agent, document, Permission.SCHEMA)
|
||||
|
||||
if client is None:
|
||||
doc = auth.get_document(document)
|
||||
client = GristClient(doc)
|
||||
|
||||
await client.modify_column(table, column_id, type=type, formula=formula)
|
||||
await client.modify_column(table, column_id, type=type, formula=formula, label=label)
|
||||
return {"modified": True}
|
||||
|
||||
|
||||
|
||||
192
src/grist_mcp/tools/session.py
Normal file
192
src/grist_mcp/tools/session.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Session token tools for HTTP proxy access."""
|
||||
|
||||
from grist_mcp.auth import Agent, Authenticator, AuthError, Permission
|
||||
from grist_mcp.session import SessionTokenManager
|
||||
|
||||
|
||||
PROXY_DOCUMENTATION = {
|
||||
"description": "HTTP proxy API for bulk data operations. Use request_session_token to get a short-lived token, then call the proxy endpoint directly from scripts.",
|
||||
"endpoints": {
|
||||
"proxy": "POST /api/v1/proxy - JSON operations (CRUD, schema)",
|
||||
"attachments_upload": "POST /api/v1/attachments - File uploads (multipart/form-data)",
|
||||
"attachments_download": "GET /api/v1/attachments/{id} - File downloads (binary response)",
|
||||
},
|
||||
"endpoint_note": "The full URL is returned in the 'proxy_url' field of request_session_token response. Replace /proxy with /attachments for file operations.",
|
||||
"authentication": "Bearer token in Authorization header",
|
||||
"attachment_upload": {
|
||||
"endpoint": "POST /api/v1/attachments",
|
||||
"content_type": "multipart/form-data",
|
||||
"permission": "write",
|
||||
"description": "Upload file attachments to the document. Returns attachment_id for linking to records via update_records.",
|
||||
"response": {"success": True, "data": {"attachment_id": 42, "filename": "invoice.pdf", "size_bytes": 31395}},
|
||||
"example_curl": "curl -X POST -H 'Authorization: Bearer TOKEN' -F 'file=@invoice.pdf' URL/api/v1/attachments",
|
||||
"example_python": """import requests
|
||||
response = requests.post(
|
||||
f'{proxy_url.replace("/proxy", "/attachments")}',
|
||||
headers={'Authorization': f'Bearer {token}'},
|
||||
files={'file': open('invoice.pdf', 'rb')}
|
||||
)
|
||||
attachment_id = response.json()['data']['attachment_id']
|
||||
# Link to record: update_records with {'Attachment': [attachment_id]}""",
|
||||
},
|
||||
"attachment_download": {
|
||||
"endpoint": "GET /api/v1/attachments/{attachment_id}",
|
||||
"permission": "read",
|
||||
"description": "Download attachment by ID. Returns binary content with appropriate Content-Type and Content-Disposition headers.",
|
||||
"response_headers": ["Content-Type", "Content-Disposition"],
|
||||
"example_curl": "curl -H 'Authorization: Bearer TOKEN' URL/api/v1/attachments/42 -o file.pdf",
|
||||
"example_python": """import requests
|
||||
response = requests.get(
|
||||
f'{base_url}/api/v1/attachments/42',
|
||||
headers={'Authorization': f'Bearer {token}'}
|
||||
)
|
||||
with open('downloaded.pdf', 'wb') as f:
|
||||
f.write(response.content)""",
|
||||
},
|
||||
"request_format": {
|
||||
"method": "Operation name (required)",
|
||||
"table": "Table name (required for most operations)",
|
||||
},
|
||||
"methods": {
|
||||
"get_records": {
|
||||
"description": "Fetch records from a table",
|
||||
"fields": {
|
||||
"table": "string",
|
||||
"filter": "object (optional)",
|
||||
"sort": "string (optional)",
|
||||
"limit": "integer (optional)",
|
||||
},
|
||||
},
|
||||
"sql_query": {
|
||||
"description": "Run a read-only SQL query",
|
||||
"fields": {"query": "string"},
|
||||
},
|
||||
"list_tables": {
|
||||
"description": "List all tables in the document",
|
||||
"fields": {},
|
||||
},
|
||||
"describe_table": {
|
||||
"description": "Get column information for a table",
|
||||
"fields": {"table": "string"},
|
||||
},
|
||||
"add_records": {
|
||||
"description": "Add records to a table",
|
||||
"fields": {"table": "string", "records": "array of objects"},
|
||||
},
|
||||
"update_records": {
|
||||
"description": "Update existing records",
|
||||
"fields": {"table": "string", "records": "array of {id, fields}"},
|
||||
},
|
||||
"delete_records": {
|
||||
"description": "Delete records by ID",
|
||||
"fields": {"table": "string", "record_ids": "array of integers"},
|
||||
},
|
||||
"create_table": {
|
||||
"description": "Create a new table",
|
||||
"fields": {"table_id": "string", "columns": "array of {id, type}"},
|
||||
},
|
||||
"add_column": {
|
||||
"description": "Add a column to a table",
|
||||
"fields": {
|
||||
"table": "string",
|
||||
"column_id": "string",
|
||||
"column_type": "string",
|
||||
"formula": "string (optional)",
|
||||
},
|
||||
},
|
||||
"modify_column": {
|
||||
"description": "Modify a column's type or formula",
|
||||
"fields": {
|
||||
"table": "string",
|
||||
"column_id": "string",
|
||||
"type": "string (optional)",
|
||||
"formula": "string (optional)",
|
||||
},
|
||||
},
|
||||
"delete_column": {
|
||||
"description": "Delete a column",
|
||||
"fields": {"table": "string", "column_id": "string"},
|
||||
},
|
||||
},
|
||||
"response_format": {
|
||||
"success": {"success": True, "data": "..."},
|
||||
"error": {"success": False, "error": "message", "code": "ERROR_CODE"},
|
||||
},
|
||||
"error_codes": [
|
||||
"UNAUTHORIZED",
|
||||
"INVALID_TOKEN",
|
||||
"TOKEN_EXPIRED",
|
||||
"INVALID_REQUEST",
|
||||
"GRIST_ERROR",
|
||||
],
|
||||
"example_script": """#!/usr/bin/env python3
|
||||
import requests
|
||||
import sys
|
||||
|
||||
# Use token and proxy_url from request_session_token response
|
||||
token = sys.argv[1]
|
||||
proxy_url = sys.argv[2]
|
||||
|
||||
response = requests.post(
|
||||
proxy_url,
|
||||
headers={'Authorization': f'Bearer {token}'},
|
||||
json={
|
||||
'method': 'add_records',
|
||||
'table': 'Orders',
|
||||
'records': [{'item': 'Widget', 'qty': 100}]
|
||||
}
|
||||
)
|
||||
print(response.json())
|
||||
""",
|
||||
}
|
||||
|
||||
|
||||
async def get_proxy_documentation() -> dict:
|
||||
"""Return complete documentation for the HTTP proxy API."""
|
||||
return PROXY_DOCUMENTATION
|
||||
|
||||
|
||||
async def request_session_token(
|
||||
agent: Agent,
|
||||
auth: Authenticator,
|
||||
token_manager: SessionTokenManager,
|
||||
document: str,
|
||||
permissions: list[str],
|
||||
ttl_seconds: int = 300,
|
||||
proxy_base_url: str | None = None,
|
||||
) -> dict:
|
||||
"""Request a short-lived session token for HTTP proxy access.
|
||||
|
||||
The token can only grant permissions the agent already has.
|
||||
"""
|
||||
# Verify agent has access to the document
|
||||
# Check each requested permission
|
||||
for perm_str in permissions:
|
||||
try:
|
||||
perm = Permission(perm_str)
|
||||
except ValueError:
|
||||
raise AuthError(f"Invalid permission: {perm_str}")
|
||||
auth.authorize(agent, document, perm)
|
||||
|
||||
# Create the session token
|
||||
session = token_manager.create_token(
|
||||
agent_name=agent.name,
|
||||
document=document,
|
||||
permissions=permissions,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
|
||||
# Build proxy URL - use base URL if provided, otherwise just path
|
||||
proxy_path = "/api/v1/proxy"
|
||||
if proxy_base_url:
|
||||
proxy_url = f"{proxy_base_url.rstrip('/')}{proxy_path}"
|
||||
else:
|
||||
proxy_url = proxy_path
|
||||
|
||||
return {
|
||||
"token": session.token,
|
||||
"document": session.document,
|
||||
"permissions": session.permissions,
|
||||
"expires_at": session.expires_at.isoformat(),
|
||||
"proxy_url": proxy_url,
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
FROM python:3.14-slim
|
||||
FROM python:3.14-slim@sha256:fb83750094b46fd6b8adaa80f66e2302ecbe45d513f6cece637a841e1025b4ca
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -35,6 +35,18 @@ MOCK_TABLES = {
|
||||
{"id": 2, "fields": {"Title": "Deploy", "Done": False}},
|
||||
],
|
||||
},
|
||||
"Orders": {
|
||||
"columns": [
|
||||
{"id": "OrderNum", "fields": {"type": "Int"}},
|
||||
{"id": "Customer", "fields": {"type": "Ref:People"}},
|
||||
{"id": "Amount", "fields": {"type": "Numeric"}},
|
||||
],
|
||||
"records": [
|
||||
{"id": 1, "fields": {"OrderNum": 1001, "Customer": 1, "Amount": 100.0}},
|
||||
{"id": 2, "fields": {"OrderNum": 1002, "Customer": 2, "Amount": 200.0}},
|
||||
{"id": 3, "fields": {"OrderNum": 1003, "Customer": 1, "Amount": 150.0}},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
# Track requests for test assertions
|
||||
@@ -93,12 +105,40 @@ async def get_records(request):
|
||||
"""GET /api/docs/{doc_id}/tables/{table_id}/records"""
|
||||
doc_id = request.path_params["doc_id"]
|
||||
table_id = request.path_params["table_id"]
|
||||
log_request("GET", f"/api/docs/{doc_id}/tables/{table_id}/records")
|
||||
filter_param = request.query_params.get("filter")
|
||||
log_request("GET", f"/api/docs/{doc_id}/tables/{table_id}/records?filter={filter_param}")
|
||||
|
||||
if table_id not in MOCK_TABLES:
|
||||
return JSONResponse({"error": "Table not found"}, status_code=404)
|
||||
|
||||
return JSONResponse({"records": MOCK_TABLES[table_id]["records"]})
|
||||
records = MOCK_TABLES[table_id]["records"]
|
||||
|
||||
# Apply filtering if provided
|
||||
if filter_param:
|
||||
try:
|
||||
filters = json.loads(filter_param)
|
||||
# Validate filter format: all values must be arrays (Grist API requirement)
|
||||
for key, values in filters.items():
|
||||
if not isinstance(values, list):
|
||||
return JSONResponse(
|
||||
{"error": f"Filter values must be arrays, got {type(values).__name__} for '{key}'"},
|
||||
status_code=400
|
||||
)
|
||||
# Apply filters: record matches if field value is in the filter list
|
||||
filtered_records = []
|
||||
for record in records:
|
||||
match = True
|
||||
for key, allowed_values in filters.items():
|
||||
if record["fields"].get(key) not in allowed_values:
|
||||
match = False
|
||||
break
|
||||
if match:
|
||||
filtered_records.append(record)
|
||||
records = filtered_records
|
||||
except json.JSONDecodeError:
|
||||
return JSONResponse({"error": "Invalid filter JSON"}, status_code=400)
|
||||
|
||||
return JSONResponse({"records": records})
|
||||
|
||||
|
||||
async def add_records(request):
|
||||
@@ -178,6 +218,15 @@ async def modify_column(request):
|
||||
return JSONResponse({})
|
||||
|
||||
|
||||
async def modify_columns(request):
|
||||
"""PATCH /api/docs/{doc_id}/tables/{table_id}/columns - batch modify columns"""
|
||||
doc_id = request.path_params["doc_id"]
|
||||
table_id = request.path_params["table_id"]
|
||||
body = await request.json()
|
||||
log_request("PATCH", f"/api/docs/{doc_id}/tables/{table_id}/columns", body)
|
||||
return JSONResponse({})
|
||||
|
||||
|
||||
async def delete_column(request):
|
||||
"""DELETE /api/docs/{doc_id}/tables/{table_id}/columns/{col_id}"""
|
||||
doc_id = request.path_params["doc_id"]
|
||||
@@ -199,6 +248,7 @@ app = Starlette(
|
||||
Route("/api/docs/{doc_id}/tables", endpoint=create_tables, methods=["POST"]),
|
||||
Route("/api/docs/{doc_id}/tables/{table_id}/columns", endpoint=get_table_columns),
|
||||
Route("/api/docs/{doc_id}/tables/{table_id}/columns", endpoint=add_column, methods=["POST"]),
|
||||
Route("/api/docs/{doc_id}/tables/{table_id}/columns", endpoint=modify_columns, methods=["PATCH"]),
|
||||
Route("/api/docs/{doc_id}/tables/{table_id}/columns/{col_id}", endpoint=modify_column, methods=["PATCH"]),
|
||||
Route("/api/docs/{doc_id}/tables/{table_id}/columns/{col_id}", endpoint=delete_column, methods=["DELETE"]),
|
||||
Route("/api/docs/{doc_id}/tables/{table_id}/records", endpoint=get_records),
|
||||
|
||||
@@ -9,12 +9,14 @@ from mcp.client.sse import sse_client
|
||||
|
||||
|
||||
GRIST_MCP_URL = os.environ.get("GRIST_MCP_URL", "http://localhost:3000")
|
||||
GRIST_MCP_TOKEN = os.environ.get("GRIST_MCP_TOKEN", "test-token")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_mcp_session():
|
||||
"""Create and yield an MCP session."""
|
||||
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
|
||||
headers = {"Authorization": f"Bearer {GRIST_MCP_TOKEN}"}
|
||||
async with sse_client(f"{GRIST_MCP_URL}/sse", headers=headers) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
@@ -44,12 +46,14 @@ async def test_mcp_protocol_compliance(services_ready):
|
||||
"add_column",
|
||||
"modify_column",
|
||||
"delete_column",
|
||||
"get_proxy_documentation",
|
||||
"request_session_token",
|
||||
]
|
||||
|
||||
for expected in expected_tools:
|
||||
assert expected in tool_names, f"Missing tool: {expected}"
|
||||
|
||||
assert len(result.tools) == 12, f"Expected 12 tools, got {len(result.tools)}"
|
||||
assert len(result.tools) == 14, f"Expected 14 tools, got {len(result.tools)}"
|
||||
|
||||
# Test 3: All tools have descriptions
|
||||
for tool in result.tools:
|
||||
|
||||
52
tests/integration/test_session_proxy.py
Normal file
52
tests/integration/test_session_proxy.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Integration tests for session token proxy."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
|
||||
GRIST_MCP_URL = os.environ.get("GRIST_MCP_URL", "http://localhost:3000")
|
||||
GRIST_MCP_TOKEN = os.environ.get("GRIST_MCP_TOKEN")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_client():
|
||||
"""Client for MCP SSE endpoint."""
|
||||
return httpx.Client(
|
||||
base_url=GRIST_MCP_URL,
|
||||
headers={"Authorization": f"Bearer {GRIST_MCP_TOKEN}"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def proxy_client():
|
||||
"""Client for proxy endpoint (session token set per-test)."""
|
||||
return httpx.Client(base_url=GRIST_MCP_URL)
|
||||
|
||||
|
||||
@pytest.mark.integration
|
||||
def test_full_session_proxy_flow(mcp_client, proxy_client):
|
||||
"""Test: request token via MCP, use token to call proxy."""
|
||||
# This test requires a running grist-mcp server with proper config
|
||||
# Skip if not configured
|
||||
if not GRIST_MCP_TOKEN:
|
||||
pytest.skip("GRIST_MCP_TOKEN not set")
|
||||
|
||||
# Step 1: Request session token (would be via MCP in real usage)
|
||||
# For integration test, we test the proxy endpoint directly
|
||||
# This is a placeholder - full MCP integration would use SSE
|
||||
|
||||
# Step 2: Use proxy endpoint
|
||||
# Note: Need a valid session token to test this fully
|
||||
# For now, verify endpoint exists and rejects bad tokens
|
||||
|
||||
response = proxy_client.post(
|
||||
"/api/v1/proxy",
|
||||
headers={"Authorization": "Bearer invalid_token"},
|
||||
json={"method": "list_tables"},
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
data = response.json()
|
||||
assert data["success"] is False
|
||||
assert data["code"] in ["INVALID_TOKEN", "TOKEN_EXPIRED"]
|
||||
@@ -12,12 +12,14 @@ from mcp.client.sse import sse_client
|
||||
|
||||
GRIST_MCP_URL = os.environ.get("GRIST_MCP_URL", "http://localhost:3000")
|
||||
MOCK_GRIST_URL = os.environ.get("MOCK_GRIST_URL", "http://localhost:8484")
|
||||
GRIST_MCP_TOKEN = os.environ.get("GRIST_MCP_TOKEN", "test-token")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_mcp_session():
|
||||
"""Create and yield an MCP session."""
|
||||
async with sse_client(f"{GRIST_MCP_URL}/sse") as (read_stream, write_stream):
|
||||
headers = {"Authorization": f"Bearer {GRIST_MCP_TOKEN}"}
|
||||
async with sse_client(f"{GRIST_MCP_URL}/sse", headers=headers) as (read_stream, write_stream):
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
await session.initialize()
|
||||
yield session
|
||||
@@ -88,6 +90,36 @@ async def test_all_tools(services_ready):
|
||||
log = get_mock_request_log()
|
||||
assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log)
|
||||
|
||||
# Test get_records with Ref column filter
|
||||
# This tests that single values are normalized to arrays for the Grist API
|
||||
clear_mock_request_log()
|
||||
result = await client.call_tool(
|
||||
"get_records",
|
||||
{"document": "test-doc", "table": "Orders", "filter": {"Customer": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "records" in data
|
||||
# Should return only orders for Customer 1 (orders 1 and 3)
|
||||
assert len(data["records"]) == 2
|
||||
for record in data["records"]:
|
||||
assert record["Customer"] == 1
|
||||
log = get_mock_request_log()
|
||||
# Verify the filter was sent as array format
|
||||
filter_requests = [e for e in log if "/records" in e["path"] and "filter=" in e["path"]]
|
||||
assert len(filter_requests) >= 1
|
||||
# The filter value should be [1] not 1
|
||||
assert "[1]" in filter_requests[0]["path"]
|
||||
|
||||
# Test get_records with multiple filter values
|
||||
clear_mock_request_log()
|
||||
result = await client.call_tool(
|
||||
"get_records",
|
||||
{"document": "test-doc", "table": "Orders", "filter": {"Customer": [1, 2]}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "records" in data
|
||||
assert len(data["records"]) == 3 # All 3 orders (customers 1 and 2)
|
||||
|
||||
# Test sql_query
|
||||
clear_mock_request_log()
|
||||
result = await client.call_tool(
|
||||
@@ -194,7 +226,7 @@ async def test_all_tools(services_ready):
|
||||
data = json.loads(result.content[0].text)
|
||||
assert "modified" in data
|
||||
log = get_mock_request_log()
|
||||
patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns/" in e["path"]]
|
||||
patch_cols = [e for e in log if e["method"] == "PATCH" and "/columns" in e["path"]]
|
||||
assert len(patch_cols) >= 1
|
||||
|
||||
# Test delete_column
|
||||
|
||||
89
tests/unit/test_filters.py
Normal file
89
tests/unit/test_filters.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Unit tests for filter normalization."""
|
||||
|
||||
import pytest
|
||||
|
||||
from grist_mcp.tools.filters import normalize_filter, normalize_filter_value
|
||||
|
||||
|
||||
class TestNormalizeFilterValue:
|
||||
"""Tests for normalize_filter_value function."""
|
||||
|
||||
def test_int_becomes_list(self):
|
||||
assert normalize_filter_value(5) == [5]
|
||||
|
||||
def test_string_becomes_list(self):
|
||||
assert normalize_filter_value("foo") == ["foo"]
|
||||
|
||||
def test_float_becomes_list(self):
|
||||
assert normalize_filter_value(3.14) == [3.14]
|
||||
|
||||
def test_list_unchanged(self):
|
||||
assert normalize_filter_value([1, 2, 3]) == [1, 2, 3]
|
||||
|
||||
def test_empty_list_unchanged(self):
|
||||
assert normalize_filter_value([]) == []
|
||||
|
||||
def test_single_item_list_unchanged(self):
|
||||
assert normalize_filter_value([42]) == [42]
|
||||
|
||||
def test_mixed_type_list_unchanged(self):
|
||||
assert normalize_filter_value([1, "foo", 3.14]) == [1, "foo", 3.14]
|
||||
|
||||
|
||||
class TestNormalizeFilter:
|
||||
"""Tests for normalize_filter function."""
|
||||
|
||||
def test_none_returns_none(self):
|
||||
assert normalize_filter(None) is None
|
||||
|
||||
def test_empty_dict_returns_empty_dict(self):
|
||||
assert normalize_filter({}) == {}
|
||||
|
||||
def test_single_int_value_wrapped(self):
|
||||
result = normalize_filter({"Transaction": 44})
|
||||
assert result == {"Transaction": [44]}
|
||||
|
||||
def test_single_string_value_wrapped(self):
|
||||
result = normalize_filter({"Status": "active"})
|
||||
assert result == {"Status": ["active"]}
|
||||
|
||||
def test_list_value_unchanged(self):
|
||||
result = normalize_filter({"Transaction": [44, 45, 46]})
|
||||
assert result == {"Transaction": [44, 45, 46]}
|
||||
|
||||
def test_mixed_columns_all_normalized(self):
|
||||
"""Both ref and non-ref columns are normalized to arrays."""
|
||||
result = normalize_filter({
|
||||
"Transaction": 44, # Ref column (int)
|
||||
"Debit": 500, # Non-ref column (int)
|
||||
"Memo": "test", # Non-ref column (str)
|
||||
})
|
||||
assert result == {
|
||||
"Transaction": [44],
|
||||
"Debit": [500],
|
||||
"Memo": ["test"],
|
||||
}
|
||||
|
||||
def test_multiple_values_list_unchanged(self):
|
||||
"""Filter with multiple values passes through."""
|
||||
result = normalize_filter({
|
||||
"Status": ["pending", "active"],
|
||||
"Priority": [1, 2, 3],
|
||||
})
|
||||
assert result == {
|
||||
"Status": ["pending", "active"],
|
||||
"Priority": [1, 2, 3],
|
||||
}
|
||||
|
||||
def test_mixed_single_and_list_values(self):
|
||||
"""Mix of single values and lists."""
|
||||
result = normalize_filter({
|
||||
"Transaction": 44, # Single int
|
||||
"Status": ["open", "closed"], # List
|
||||
"Amount": 100.50, # Single float
|
||||
})
|
||||
assert result == {
|
||||
"Transaction": [44],
|
||||
"Status": ["open", "closed"],
|
||||
"Amount": [100.50],
|
||||
}
|
||||
@@ -155,6 +155,27 @@ async def test_add_column(client, httpx_mock: HTTPXMock):
|
||||
col_id = await client.add_column("Table1", "NewCol", "Text", formula=None)
|
||||
|
||||
assert col_id == "NewCol"
|
||||
request = httpx_mock.get_request()
|
||||
import json
|
||||
payload = json.loads(request.content)
|
||||
assert payload == {"columns": [{"id": "NewCol", "fields": {"type": "Text"}}]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_column_with_label(client, httpx_mock: HTTPXMock):
|
||||
httpx_mock.add_response(
|
||||
url="https://grist.example.com/api/docs/abc123/tables/Table1/columns",
|
||||
method="POST",
|
||||
json={"columns": [{"id": "first_name"}]},
|
||||
)
|
||||
|
||||
col_id = await client.add_column("Table1", "first_name", "Text", label="First Name")
|
||||
|
||||
assert col_id == "first_name"
|
||||
request = httpx_mock.get_request()
|
||||
import json
|
||||
payload = json.loads(request.content)
|
||||
assert payload == {"columns": [{"id": "first_name", "fields": {"type": "Text", "label": "First Name"}}]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -169,6 +190,22 @@ async def test_modify_column(client, httpx_mock: HTTPXMock):
|
||||
await client.modify_column("Table1", "Amount", type="Int", formula="$Price * $Qty")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_column_with_label(client, httpx_mock: HTTPXMock):
|
||||
httpx_mock.add_response(
|
||||
url="https://grist.example.com/api/docs/abc123/tables/Table1/columns",
|
||||
method="PATCH",
|
||||
json={},
|
||||
)
|
||||
|
||||
await client.modify_column("Table1", "Col1", label="Column One")
|
||||
|
||||
request = httpx_mock.get_request()
|
||||
import json
|
||||
payload = json.loads(request.content)
|
||||
assert payload == {"columns": [{"id": "Col1", "fields": {"label": "Column One"}}]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_column(client, httpx_mock: HTTPXMock):
|
||||
httpx_mock.add_response(
|
||||
@@ -196,3 +233,99 @@ def test_sql_validation_rejects_multiple_statements(client):
|
||||
def test_sql_validation_allows_trailing_semicolon(client):
|
||||
# Should not raise
|
||||
client._validate_sql_query("SELECT * FROM users;")
|
||||
|
||||
|
||||
# Attachment tests
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_attachment(client, httpx_mock: HTTPXMock):
|
||||
httpx_mock.add_response(
|
||||
url="https://grist.example.com/api/docs/abc123/attachments",
|
||||
method="POST",
|
||||
json=[42],
|
||||
)
|
||||
|
||||
result = await client.upload_attachment(
|
||||
filename="invoice.pdf",
|
||||
content=b"PDF content here",
|
||||
content_type="application/pdf",
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"attachment_id": 42,
|
||||
"filename": "invoice.pdf",
|
||||
"size_bytes": 16,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_attachment_default_content_type(client, httpx_mock: HTTPXMock):
|
||||
httpx_mock.add_response(
|
||||
url="https://grist.example.com/api/docs/abc123/attachments",
|
||||
method="POST",
|
||||
json=[99],
|
||||
)
|
||||
|
||||
result = await client.upload_attachment(
|
||||
filename="data.bin",
|
||||
content=b"\x00\x01\x02",
|
||||
)
|
||||
|
||||
assert result["attachment_id"] == 99
|
||||
assert result["size_bytes"] == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_attachment(client, httpx_mock: HTTPXMock):
|
||||
httpx_mock.add_response(
|
||||
url="https://grist.example.com/api/docs/abc123/attachments/42/download",
|
||||
method="GET",
|
||||
content=b"PDF content here",
|
||||
headers={
|
||||
"content-type": "application/pdf",
|
||||
"content-disposition": 'attachment; filename="invoice.pdf"',
|
||||
},
|
||||
)
|
||||
|
||||
result = await client.download_attachment(42)
|
||||
|
||||
assert result["content"] == b"PDF content here"
|
||||
assert result["content_type"] == "application/pdf"
|
||||
assert result["filename"] == "invoice.pdf"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_attachment_no_filename(client, httpx_mock: HTTPXMock):
|
||||
httpx_mock.add_response(
|
||||
url="https://grist.example.com/api/docs/abc123/attachments/99/download",
|
||||
method="GET",
|
||||
content=b"binary data",
|
||||
headers={
|
||||
"content-type": "application/octet-stream",
|
||||
},
|
||||
)
|
||||
|
||||
result = await client.download_attachment(99)
|
||||
|
||||
assert result["content"] == b"binary data"
|
||||
assert result["content_type"] == "application/octet-stream"
|
||||
assert result["filename"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_attachment_unquoted_filename(client, httpx_mock: HTTPXMock):
|
||||
httpx_mock.add_response(
|
||||
url="https://grist.example.com/api/docs/abc123/attachments/55/download",
|
||||
method="GET",
|
||||
content=b"image data",
|
||||
headers={
|
||||
"content-type": "image/png",
|
||||
"content-disposition": "attachment; filename=photo.png",
|
||||
},
|
||||
)
|
||||
|
||||
result = await client.download_attachment(55)
|
||||
|
||||
assert result["content"] == b"image data"
|
||||
assert result["content_type"] == "image/png"
|
||||
assert result["filename"] == "photo.png"
|
||||
|
||||
98
tests/unit/test_proxy.py
Normal file
98
tests/unit/test_proxy.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from grist_mcp.proxy import parse_proxy_request, ProxyRequest, ProxyError, dispatch_proxy_request
|
||||
from grist_mcp.session import SessionToken
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
return SessionToken(
|
||||
token="sess_test",
|
||||
document="sales",
|
||||
permissions=["read", "write"],
|
||||
agent_name="test-agent",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth():
|
||||
auth = MagicMock()
|
||||
doc = MagicMock()
|
||||
doc.url = "https://grist.example.com"
|
||||
doc.doc_id = "abc123"
|
||||
doc.api_key = "key"
|
||||
auth.get_document.return_value = doc
|
||||
return auth
|
||||
|
||||
|
||||
def test_parse_proxy_request_valid_add_records():
|
||||
body = {
|
||||
"method": "add_records",
|
||||
"table": "Orders",
|
||||
"records": [{"item": "Widget", "qty": 10}],
|
||||
}
|
||||
|
||||
request = parse_proxy_request(body)
|
||||
|
||||
assert request.method == "add_records"
|
||||
assert request.table == "Orders"
|
||||
assert request.records == [{"item": "Widget", "qty": 10}]
|
||||
|
||||
|
||||
def test_parse_proxy_request_missing_method():
|
||||
body = {"table": "Orders"}
|
||||
|
||||
with pytest.raises(ProxyError) as exc_info:
|
||||
parse_proxy_request(body)
|
||||
|
||||
assert exc_info.value.code == "INVALID_REQUEST"
|
||||
assert "method" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_add_records(mock_session, mock_auth):
|
||||
request = ProxyRequest(
|
||||
method="add_records",
|
||||
table="Orders",
|
||||
records=[{"item": "Widget"}],
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.add_records.return_value = [1, 2, 3]
|
||||
|
||||
result = await dispatch_proxy_request(
|
||||
request, mock_session, mock_auth, client=mock_client
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["data"]["record_ids"] == [1, 2, 3]
|
||||
mock_client.add_records.assert_called_once_with("Orders", [{"item": "Widget"}])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_denies_without_permission(mock_auth):
|
||||
# Session only has read permission
|
||||
session = SessionToken(
|
||||
token="sess_test",
|
||||
document="sales",
|
||||
permissions=["read"], # No write
|
||||
agent_name="test-agent",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
request = ProxyRequest(
|
||||
method="add_records", # Requires write
|
||||
table="Orders",
|
||||
records=[{"item": "Widget"}],
|
||||
)
|
||||
|
||||
with pytest.raises(ProxyError) as exc_info:
|
||||
await dispatch_proxy_request(request, session, mock_auth)
|
||||
|
||||
assert exc_info.value.code == "UNAUTHORIZED"
|
||||
@@ -53,5 +53,48 @@ tokens:
|
||||
assert "modify_column" in tool_names
|
||||
assert "delete_column" in tool_names
|
||||
|
||||
# Should have all 12 tools
|
||||
assert len(result.root.tools) == 12
|
||||
# Session tools (always registered)
|
||||
assert "get_proxy_documentation" in tool_names
|
||||
assert "request_session_token" in tool_names
|
||||
|
||||
# Should have all 14 tools
|
||||
assert len(result.root.tools) == 14
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_server_registers_session_tools(tmp_path):
|
||||
from grist_mcp.session import SessionTokenManager
|
||||
|
||||
config_file = tmp_path / "config.yaml"
|
||||
config_file.write_text("""
|
||||
documents:
|
||||
test-doc:
|
||||
url: https://grist.example.com
|
||||
doc_id: abc123
|
||||
api_key: test-key
|
||||
|
||||
tokens:
|
||||
- token: valid-token
|
||||
name: test-agent
|
||||
scope:
|
||||
- document: test-doc
|
||||
permissions: [read, write, schema]
|
||||
""")
|
||||
|
||||
config = load_config(str(config_file))
|
||||
auth = Authenticator(config)
|
||||
agent = auth.authenticate("valid-token")
|
||||
token_manager = SessionTokenManager()
|
||||
server = create_server(auth, agent, token_manager)
|
||||
|
||||
# Get the list_tools handler and call it
|
||||
handler = server.request_handlers.get(ListToolsRequest)
|
||||
assert handler is not None
|
||||
|
||||
req = ListToolsRequest(method="tools/list")
|
||||
result = await handler(req)
|
||||
|
||||
tool_names = [t.name for t in result.root.tools]
|
||||
|
||||
assert "get_proxy_documentation" in tool_names
|
||||
assert "request_session_token" in tool_names
|
||||
|
||||
81
tests/unit/test_session.py
Normal file
81
tests/unit/test_session.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import pytest
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from grist_mcp.session import SessionTokenManager, SessionToken
|
||||
|
||||
|
||||
def test_create_token_returns_valid_session_token():
|
||||
manager = SessionTokenManager()
|
||||
|
||||
token = manager.create_token(
|
||||
agent_name="test-agent",
|
||||
document="sales",
|
||||
permissions=["read", "write"],
|
||||
ttl_seconds=300,
|
||||
)
|
||||
|
||||
assert token.token.startswith("sess_")
|
||||
assert len(token.token) > 20
|
||||
assert token.document == "sales"
|
||||
assert token.permissions == ["read", "write"]
|
||||
assert token.agent_name == "test-agent"
|
||||
assert token.expires_at > datetime.now(timezone.utc)
|
||||
assert token.expires_at < datetime.now(timezone.utc) + timedelta(seconds=310)
|
||||
|
||||
|
||||
def test_create_token_caps_ttl_at_maximum():
|
||||
manager = SessionTokenManager()
|
||||
|
||||
# Request 2 hours, should be capped at 1 hour
|
||||
token = manager.create_token(
|
||||
agent_name="test-agent",
|
||||
document="sales",
|
||||
permissions=["read"],
|
||||
ttl_seconds=7200,
|
||||
)
|
||||
|
||||
# Should be capped at 3600 seconds (1 hour)
|
||||
max_expires = datetime.now(timezone.utc) + timedelta(seconds=3610)
|
||||
assert token.expires_at < max_expires
|
||||
|
||||
|
||||
def test_validate_token_returns_session_for_valid_token():
|
||||
manager = SessionTokenManager()
|
||||
created = manager.create_token(
|
||||
agent_name="test-agent",
|
||||
document="sales",
|
||||
permissions=["read"],
|
||||
ttl_seconds=300,
|
||||
)
|
||||
|
||||
session = manager.validate_token(created.token)
|
||||
|
||||
assert session is not None
|
||||
assert session.document == "sales"
|
||||
assert session.agent_name == "test-agent"
|
||||
|
||||
|
||||
def test_validate_token_returns_none_for_unknown_token():
|
||||
manager = SessionTokenManager()
|
||||
|
||||
session = manager.validate_token("sess_unknown_token")
|
||||
|
||||
assert session is None
|
||||
|
||||
|
||||
def test_validate_token_returns_none_for_expired_token():
|
||||
manager = SessionTokenManager()
|
||||
created = manager.create_token(
|
||||
agent_name="test-agent",
|
||||
document="sales",
|
||||
permissions=["read"],
|
||||
ttl_seconds=1,
|
||||
)
|
||||
|
||||
# Wait for expiry
|
||||
import time
|
||||
time.sleep(1.5)
|
||||
|
||||
session = manager.validate_token(created.token)
|
||||
|
||||
assert session is None
|
||||
@@ -75,6 +75,45 @@ async def test_get_records(agent, auth, mock_client):
|
||||
assert result == {"records": [{"id": 1, "Name": "Alice"}]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_records_normalizes_filter(agent, auth, mock_client):
|
||||
"""Test that filter values are normalized to array format for Grist API."""
|
||||
mock_client.get_records.return_value = [{"id": 1, "Customer": 5}]
|
||||
|
||||
await get_records(
|
||||
agent, auth, "budget", "Orders",
|
||||
filter={"Customer": 5, "Status": "active"},
|
||||
client=mock_client,
|
||||
)
|
||||
|
||||
# Verify filter was normalized: single values wrapped in lists
|
||||
mock_client.get_records.assert_called_once_with(
|
||||
"Orders",
|
||||
filter={"Customer": [5], "Status": ["active"]},
|
||||
sort=None,
|
||||
limit=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_records_preserves_list_filter(agent, auth, mock_client):
|
||||
"""Test that filter values already in list format are preserved."""
|
||||
mock_client.get_records.return_value = []
|
||||
|
||||
await get_records(
|
||||
agent, auth, "budget", "Orders",
|
||||
filter={"Customer": [5, 6, 7]},
|
||||
client=mock_client,
|
||||
)
|
||||
|
||||
mock_client.get_records.assert_called_once_with(
|
||||
"Orders",
|
||||
filter={"Customer": [5, 6, 7]},
|
||||
sort=None,
|
||||
limit=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sql_query(agent, auth, mock_client):
|
||||
result = await sql_query(agent, auth, "budget", "SELECT * FROM Table1", client=mock_client)
|
||||
|
||||
@@ -81,6 +81,25 @@ async def test_add_column(auth, mock_client):
|
||||
)
|
||||
|
||||
assert result == {"column_id": "NewCol"}
|
||||
mock_client.add_column.assert_called_once_with(
|
||||
"Table1", "NewCol", "Text", formula=None, label=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_column_with_label(auth, mock_client):
|
||||
agent = auth.authenticate("schema-token")
|
||||
|
||||
result = await add_column(
|
||||
agent, auth, "budget", "Table1", "first_name", "Text",
|
||||
label="First Name",
|
||||
client=mock_client,
|
||||
)
|
||||
|
||||
assert result == {"column_id": "NewCol"}
|
||||
mock_client.add_column.assert_called_once_with(
|
||||
"Table1", "first_name", "Text", formula=None, label="First Name"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -95,6 +114,25 @@ async def test_modify_column(auth, mock_client):
|
||||
)
|
||||
|
||||
assert result == {"modified": True}
|
||||
mock_client.modify_column.assert_called_once_with(
|
||||
"Table1", "Col1", type="Int", formula="$A + $B", label=None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_modify_column_with_label(auth, mock_client):
|
||||
agent = auth.authenticate("schema-token")
|
||||
|
||||
result = await modify_column(
|
||||
agent, auth, "budget", "Table1", "Col1",
|
||||
label="Column One",
|
||||
client=mock_client,
|
||||
)
|
||||
|
||||
assert result == {"modified": True}
|
||||
mock_client.modify_column.assert_called_once_with(
|
||||
"Table1", "Col1", type=None, formula=None, label="Column One"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
126
tests/unit/test_tools_session.py
Normal file
126
tests/unit/test_tools_session.py
Normal file
@@ -0,0 +1,126 @@
|
||||
import pytest
|
||||
from grist_mcp.tools.session import get_proxy_documentation, request_session_token
|
||||
from grist_mcp.auth import Authenticator, Agent, AuthError
|
||||
from grist_mcp.config import Config, Document, Token, TokenScope
|
||||
from grist_mcp.session import SessionTokenManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_config():
|
||||
return Config(
|
||||
documents={
|
||||
"sales": Document(
|
||||
url="https://grist.example.com",
|
||||
doc_id="abc123",
|
||||
api_key="key",
|
||||
),
|
||||
},
|
||||
tokens=[
|
||||
Token(
|
||||
token="agent-token",
|
||||
name="test-agent",
|
||||
scope=[
|
||||
TokenScope(document="sales", permissions=["read", "write"]),
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def auth_and_agent(sample_config):
|
||||
auth = Authenticator(sample_config)
|
||||
agent = auth.authenticate("agent-token")
|
||||
return auth, agent
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_proxy_documentation_returns_complete_spec():
|
||||
result = await get_proxy_documentation()
|
||||
|
||||
assert "description" in result
|
||||
assert "endpoints" in result
|
||||
assert "proxy" in result["endpoints"]
|
||||
assert "attachments_upload" in result["endpoints"]
|
||||
assert "attachments_download" in result["endpoints"]
|
||||
assert "authentication" in result
|
||||
assert "methods" in result
|
||||
assert "add_records" in result["methods"]
|
||||
assert "get_records" in result["methods"]
|
||||
assert "attachment_upload" in result
|
||||
assert "attachment_download" in result
|
||||
assert "example_script" in result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_session_token_creates_valid_token(auth_and_agent):
|
||||
auth, agent = auth_and_agent
|
||||
manager = SessionTokenManager()
|
||||
|
||||
result = await request_session_token(
|
||||
agent=agent,
|
||||
auth=auth,
|
||||
token_manager=manager,
|
||||
document="sales",
|
||||
permissions=["read", "write"],
|
||||
ttl_seconds=300,
|
||||
)
|
||||
|
||||
assert "token" in result
|
||||
assert result["token"].startswith("sess_")
|
||||
assert result["document"] == "sales"
|
||||
assert result["permissions"] == ["read", "write"]
|
||||
assert "expires_at" in result
|
||||
assert result["proxy_url"] == "/api/v1/proxy"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_session_token_rejects_unauthorized_document(sample_config):
|
||||
auth = Authenticator(sample_config)
|
||||
agent = auth.authenticate("agent-token")
|
||||
manager = SessionTokenManager()
|
||||
|
||||
with pytest.raises(AuthError, match="Document not in scope"):
|
||||
await request_session_token(
|
||||
agent=agent,
|
||||
auth=auth,
|
||||
token_manager=manager,
|
||||
document="unauthorized_doc",
|
||||
permissions=["read"],
|
||||
ttl_seconds=300,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_session_token_rejects_unauthorized_permission(sample_config):
|
||||
auth = Authenticator(sample_config)
|
||||
agent = auth.authenticate("agent-token")
|
||||
manager = SessionTokenManager()
|
||||
|
||||
# Agent has read/write on sales, but not schema
|
||||
with pytest.raises(AuthError, match="Permission denied"):
|
||||
await request_session_token(
|
||||
agent=agent,
|
||||
auth=auth,
|
||||
token_manager=manager,
|
||||
document="sales",
|
||||
permissions=["read", "schema"], # schema not granted
|
||||
ttl_seconds=300,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_session_token_rejects_invalid_permission(sample_config):
|
||||
auth = Authenticator(sample_config)
|
||||
agent = auth.authenticate("agent-token")
|
||||
manager = SessionTokenManager()
|
||||
|
||||
with pytest.raises(AuthError, match="Invalid permission"):
|
||||
await request_session_token(
|
||||
agent=agent,
|
||||
auth=auth,
|
||||
token_manager=manager,
|
||||
document="sales",
|
||||
permissions=["read", "invalid_perm"],
|
||||
ttl_seconds=300,
|
||||
)
|
||||
Reference in New Issue
Block a user