diff --git a/src/grist_mcp/proxy.py b/src/grist_mcp/proxy.py index ee96bea..3def85c 100644 --- a/src/grist_mcp/proxy.py +++ b/src/grist_mcp/proxy.py @@ -3,6 +3,10 @@ from dataclasses import dataclass from typing import Any +from grist_mcp.auth import Authenticator +from grist_mcp.grist_client import GristClient +from grist_mcp.session import SessionToken + class ProxyError(Exception): """Error during proxy request processing.""" @@ -64,3 +68,125 @@ def parse_proxy_request(body: dict[str, Any]) -> ProxyRequest: formula=body.get("formula"), type=body.get("type"), ) + + +# Map methods to required permissions +METHOD_PERMISSIONS = { + "list_tables": "read", + "describe_table": "read", + "get_records": "read", + "sql_query": "read", + "add_records": "write", + "update_records": "write", + "delete_records": "write", + "create_table": "schema", + "add_column": "schema", + "modify_column": "schema", + "delete_column": "schema", +} + + +async def dispatch_proxy_request( + request: ProxyRequest, + session: SessionToken, + auth: Authenticator, + client: GristClient | None = None, +) -> dict[str, Any]: + """Dispatch a proxy request to the appropriate handler.""" + # Check permission + required_perm = METHOD_PERMISSIONS.get(request.method) + if required_perm is None: + raise ProxyError(f"Unknown method: {request.method}", "INVALID_REQUEST") + + if required_perm not in session.permissions: + raise ProxyError( + f"Permission '{required_perm}' required for {request.method}", + "UNAUTHORIZED", + ) + + # Create client if not provided + if client is None: + doc = auth.get_document(session.document) + client = GristClient(doc) + + # Dispatch to appropriate method + try: + if request.method == "list_tables": + data = await client.list_tables() + return {"success": True, "data": {"tables": data}} + + elif request.method == "describe_table": + data = await client.describe_table(request.table) + return {"success": True, "data": {"table": request.table, "columns": data}} + + elif request.method == "get_records": + data = await client.get_records( + request.table, + filter=request.filter, + sort=request.sort, + limit=request.limit, + ) + return {"success": True, "data": {"records": data}} + + elif request.method == "sql_query": + if request.query is None: + raise ProxyError("Missing required field: query", "INVALID_REQUEST") + data = await client.sql_query(request.query) + return {"success": True, "data": {"records": data}} + + elif request.method == "add_records": + if request.records is None: + raise ProxyError("Missing required field: records", "INVALID_REQUEST") + data = await client.add_records(request.table, request.records) + return {"success": True, "data": {"record_ids": data}} + + elif request.method == "update_records": + if request.records is None: + raise ProxyError("Missing required field: records", "INVALID_REQUEST") + await client.update_records(request.table, request.records) + return {"success": True, "data": {"updated": len(request.records)}} + + elif request.method == "delete_records": + if request.record_ids is None: + raise ProxyError("Missing required field: record_ids", "INVALID_REQUEST") + await client.delete_records(request.table, request.record_ids) + return {"success": True, "data": {"deleted": len(request.record_ids)}} + + elif request.method == "create_table": + if request.table_id is None or request.columns is None: + raise ProxyError("Missing required fields: table_id, columns", "INVALID_REQUEST") + data = await client.create_table(request.table_id, request.columns) + return {"success": True, "data": {"table_id": data}} + + elif request.method == "add_column": + if request.column_id is None or request.column_type is None: + raise ProxyError("Missing required fields: column_id, column_type", "INVALID_REQUEST") + await client.add_column( + request.table, request.column_id, request.column_type, + formula=request.formula, + ) + return {"success": True, "data": {"column_id": request.column_id}} + + elif request.method == "modify_column": + if request.column_id is None: + raise ProxyError("Missing required field: column_id", "INVALID_REQUEST") + await client.modify_column( + request.table, request.column_id, + type=request.type, + formula=request.formula, + ) + return {"success": True, "data": {"column_id": request.column_id}} + + elif request.method == "delete_column": + if request.column_id is None: + raise ProxyError("Missing required field: column_id", "INVALID_REQUEST") + await client.delete_column(request.table, request.column_id) + return {"success": True, "data": {"deleted": request.column_id}} + + else: + raise ProxyError(f"Unknown method: {request.method}", "INVALID_REQUEST") + + except ProxyError: + raise + except Exception as e: + raise ProxyError(str(e), "GRIST_ERROR") diff --git a/tests/unit/test_proxy.py b/tests/unit/test_proxy.py index b59000d..8eb8664 100644 --- a/tests/unit/test_proxy.py +++ b/tests/unit/test_proxy.py @@ -1,5 +1,33 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock + import pytest -from grist_mcp.proxy import parse_proxy_request, ProxyRequest, ProxyError + +from grist_mcp.proxy import parse_proxy_request, ProxyRequest, ProxyError, dispatch_proxy_request +from grist_mcp.session import SessionToken + + +@pytest.fixture +def mock_session(): + return SessionToken( + token="sess_test", + document="sales", + permissions=["read", "write"], + agent_name="test-agent", + created_at=datetime.now(timezone.utc), + expires_at=datetime.now(timezone.utc), + ) + + +@pytest.fixture +def mock_auth(): + auth = MagicMock() + doc = MagicMock() + doc.url = "https://grist.example.com" + doc.doc_id = "abc123" + doc.api_key = "key" + auth.get_document.return_value = doc + return auth def test_parse_proxy_request_valid_add_records(): @@ -24,3 +52,23 @@ def test_parse_proxy_request_missing_method(): assert exc_info.value.code == "INVALID_REQUEST" assert "method" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_dispatch_add_records(mock_session, mock_auth): + request = ProxyRequest( + method="add_records", + table="Orders", + records=[{"item": "Widget"}], + ) + + mock_client = AsyncMock() + mock_client.add_records.return_value = [1, 2, 3] + + result = await dispatch_proxy_request( + request, mock_session, mock_auth, client=mock_client + ) + + assert result["success"] is True + assert result["data"]["record_ids"] == [1, 2, 3] + mock_client.add_records.assert_called_once_with("Orders", [{"item": "Widget"}]) diff --git a/uv.lock b/uv.lock index a09f2f5..d8b7372 100644 --- a/uv.lock +++ b/uv.lock @@ -153,7 +153,7 @@ wheels = [ [[package]] name = "grist-mcp" -version = "0.1.0" +version = "1.0.0" source = { editable = "." } dependencies = [ { name = "httpx" },