feat(proxy): add method dispatch
This commit is contained in:
@@ -3,6 +3,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from grist_mcp.auth import Authenticator
|
||||
from grist_mcp.grist_client import GristClient
|
||||
from grist_mcp.session import SessionToken
|
||||
|
||||
|
||||
class ProxyError(Exception):
|
||||
"""Error during proxy request processing."""
|
||||
@@ -64,3 +68,125 @@ def parse_proxy_request(body: dict[str, Any]) -> ProxyRequest:
|
||||
formula=body.get("formula"),
|
||||
type=body.get("type"),
|
||||
)
|
||||
|
||||
|
||||
# Map methods to required permissions
|
||||
METHOD_PERMISSIONS = {
|
||||
"list_tables": "read",
|
||||
"describe_table": "read",
|
||||
"get_records": "read",
|
||||
"sql_query": "read",
|
||||
"add_records": "write",
|
||||
"update_records": "write",
|
||||
"delete_records": "write",
|
||||
"create_table": "schema",
|
||||
"add_column": "schema",
|
||||
"modify_column": "schema",
|
||||
"delete_column": "schema",
|
||||
}
|
||||
|
||||
|
||||
async def dispatch_proxy_request(
|
||||
request: ProxyRequest,
|
||||
session: SessionToken,
|
||||
auth: Authenticator,
|
||||
client: GristClient | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Dispatch a proxy request to the appropriate handler."""
|
||||
# Check permission
|
||||
required_perm = METHOD_PERMISSIONS.get(request.method)
|
||||
if required_perm is None:
|
||||
raise ProxyError(f"Unknown method: {request.method}", "INVALID_REQUEST")
|
||||
|
||||
if required_perm not in session.permissions:
|
||||
raise ProxyError(
|
||||
f"Permission '{required_perm}' required for {request.method}",
|
||||
"UNAUTHORIZED",
|
||||
)
|
||||
|
||||
# Create client if not provided
|
||||
if client is None:
|
||||
doc = auth.get_document(session.document)
|
||||
client = GristClient(doc)
|
||||
|
||||
# Dispatch to appropriate method
|
||||
try:
|
||||
if request.method == "list_tables":
|
||||
data = await client.list_tables()
|
||||
return {"success": True, "data": {"tables": data}}
|
||||
|
||||
elif request.method == "describe_table":
|
||||
data = await client.describe_table(request.table)
|
||||
return {"success": True, "data": {"table": request.table, "columns": data}}
|
||||
|
||||
elif request.method == "get_records":
|
||||
data = await client.get_records(
|
||||
request.table,
|
||||
filter=request.filter,
|
||||
sort=request.sort,
|
||||
limit=request.limit,
|
||||
)
|
||||
return {"success": True, "data": {"records": data}}
|
||||
|
||||
elif request.method == "sql_query":
|
||||
if request.query is None:
|
||||
raise ProxyError("Missing required field: query", "INVALID_REQUEST")
|
||||
data = await client.sql_query(request.query)
|
||||
return {"success": True, "data": {"records": data}}
|
||||
|
||||
elif request.method == "add_records":
|
||||
if request.records is None:
|
||||
raise ProxyError("Missing required field: records", "INVALID_REQUEST")
|
||||
data = await client.add_records(request.table, request.records)
|
||||
return {"success": True, "data": {"record_ids": data}}
|
||||
|
||||
elif request.method == "update_records":
|
||||
if request.records is None:
|
||||
raise ProxyError("Missing required field: records", "INVALID_REQUEST")
|
||||
await client.update_records(request.table, request.records)
|
||||
return {"success": True, "data": {"updated": len(request.records)}}
|
||||
|
||||
elif request.method == "delete_records":
|
||||
if request.record_ids is None:
|
||||
raise ProxyError("Missing required field: record_ids", "INVALID_REQUEST")
|
||||
await client.delete_records(request.table, request.record_ids)
|
||||
return {"success": True, "data": {"deleted": len(request.record_ids)}}
|
||||
|
||||
elif request.method == "create_table":
|
||||
if request.table_id is None or request.columns is None:
|
||||
raise ProxyError("Missing required fields: table_id, columns", "INVALID_REQUEST")
|
||||
data = await client.create_table(request.table_id, request.columns)
|
||||
return {"success": True, "data": {"table_id": data}}
|
||||
|
||||
elif request.method == "add_column":
|
||||
if request.column_id is None or request.column_type is None:
|
||||
raise ProxyError("Missing required fields: column_id, column_type", "INVALID_REQUEST")
|
||||
await client.add_column(
|
||||
request.table, request.column_id, request.column_type,
|
||||
formula=request.formula,
|
||||
)
|
||||
return {"success": True, "data": {"column_id": request.column_id}}
|
||||
|
||||
elif request.method == "modify_column":
|
||||
if request.column_id is None:
|
||||
raise ProxyError("Missing required field: column_id", "INVALID_REQUEST")
|
||||
await client.modify_column(
|
||||
request.table, request.column_id,
|
||||
type=request.type,
|
||||
formula=request.formula,
|
||||
)
|
||||
return {"success": True, "data": {"column_id": request.column_id}}
|
||||
|
||||
elif request.method == "delete_column":
|
||||
if request.column_id is None:
|
||||
raise ProxyError("Missing required field: column_id", "INVALID_REQUEST")
|
||||
await client.delete_column(request.table, request.column_id)
|
||||
return {"success": True, "data": {"deleted": request.column_id}}
|
||||
|
||||
else:
|
||||
raise ProxyError(f"Unknown method: {request.method}", "INVALID_REQUEST")
|
||||
|
||||
except ProxyError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise ProxyError(str(e), "GRIST_ERROR")
|
||||
|
||||
@@ -1,5 +1,33 @@
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from grist_mcp.proxy import parse_proxy_request, ProxyRequest, ProxyError
|
||||
|
||||
from grist_mcp.proxy import parse_proxy_request, ProxyRequest, ProxyError, dispatch_proxy_request
|
||||
from grist_mcp.session import SessionToken
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session():
|
||||
return SessionToken(
|
||||
token="sess_test",
|
||||
document="sales",
|
||||
permissions=["read", "write"],
|
||||
agent_name="test-agent",
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth():
|
||||
auth = MagicMock()
|
||||
doc = MagicMock()
|
||||
doc.url = "https://grist.example.com"
|
||||
doc.doc_id = "abc123"
|
||||
doc.api_key = "key"
|
||||
auth.get_document.return_value = doc
|
||||
return auth
|
||||
|
||||
|
||||
def test_parse_proxy_request_valid_add_records():
|
||||
@@ -24,3 +52,23 @@ def test_parse_proxy_request_missing_method():
|
||||
|
||||
assert exc_info.value.code == "INVALID_REQUEST"
|
||||
assert "method" in str(exc_info.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dispatch_add_records(mock_session, mock_auth):
|
||||
request = ProxyRequest(
|
||||
method="add_records",
|
||||
table="Orders",
|
||||
records=[{"item": "Widget"}],
|
||||
)
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.add_records.return_value = [1, 2, 3]
|
||||
|
||||
result = await dispatch_proxy_request(
|
||||
request, mock_session, mock_auth, client=mock_client
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["data"]["record_ids"] == [1, 2, 3]
|
||||
mock_client.add_records.assert_called_once_with("Orders", [{"item": "Widget"}])
|
||||
|
||||
Reference in New Issue
Block a user