From 1ed555494459ef935eb74f470c157dd42c77c4a2 Mon Sep 17 00:00:00 2001 From: Bill Date: Wed, 3 Dec 2025 15:00:48 -0500 Subject: [PATCH] feat: add MCP server with all tools registered --- src/grist_mcp/main.py | 26 ++++ src/grist_mcp/server.py | 284 ++++++++++++++++++++++++++++++++++++++++ tests/test_server.py | 52 ++++++++ 3 files changed, 362 insertions(+) create mode 100644 src/grist_mcp/main.py create mode 100644 src/grist_mcp/server.py create mode 100644 tests/test_server.py diff --git a/src/grist_mcp/main.py b/src/grist_mcp/main.py new file mode 100644 index 0000000..a7584f1 --- /dev/null +++ b/src/grist_mcp/main.py @@ -0,0 +1,26 @@ +"""Main entry point for the MCP server.""" + +import asyncio +import os +import sys + +from mcp.server.stdio import stdio_server + +from grist_mcp.server import create_server + + +async def main(): + config_path = os.environ.get("CONFIG_PATH", "/app/config.yaml") + + if not os.path.exists(config_path): + print(f"Error: Config file not found at {config_path}", file=sys.stderr) + sys.exit(1) + + server = create_server(config_path) + + async with stdio_server() as (read_stream, write_stream): + await server.run(read_stream, write_stream, server.create_initialization_options()) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/grist_mcp/server.py b/src/grist_mcp/server.py new file mode 100644 index 0000000..263f707 --- /dev/null +++ b/src/grist_mcp/server.py @@ -0,0 +1,284 @@ +"""MCP server setup and tool registration.""" + +from mcp.server import Server +from mcp.server.stdio import stdio_server +from mcp.types import Tool, TextContent + +from grist_mcp.config import load_config +from grist_mcp.auth import Authenticator, AuthError, Agent + +from grist_mcp.tools.discovery import list_documents as _list_documents +from grist_mcp.tools.read import list_tables as _list_tables +from grist_mcp.tools.read import describe_table as _describe_table +from grist_mcp.tools.read import get_records as _get_records +from grist_mcp.tools.read import sql_query as _sql_query +from grist_mcp.tools.write import add_records as _add_records +from grist_mcp.tools.write import update_records as _update_records +from grist_mcp.tools.write import delete_records as _delete_records +from grist_mcp.tools.schema import create_table as _create_table +from grist_mcp.tools.schema import add_column as _add_column +from grist_mcp.tools.schema import modify_column as _modify_column +from grist_mcp.tools.schema import delete_column as _delete_column + + +def create_server(config_path: str) -> Server: + """Create and configure the MCP server.""" + config = load_config(config_path) + auth = Authenticator(config) + server = Server("grist-mcp") + + # Current agent context (set during authentication) + _current_agent: Agent | None = None + + @server.list_tools() + async def list_tools() -> list[Tool]: + return [ + Tool( + name="list_documents", + description="List documents this agent can access with their permissions", + inputSchema={"type": "object", "properties": {}, "required": []}, + ), + Tool( + name="list_tables", + description="List all tables in a document", + inputSchema={ + "type": "object", + "properties": {"document": {"type": "string", "description": "Document name"}}, + "required": ["document"], + }, + ), + Tool( + name="describe_table", + description="Get column information for a table", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "table": {"type": "string"}, + }, + "required": ["document", "table"], + }, + ), + Tool( + name="get_records", + description="Fetch records from a table", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "table": {"type": "string"}, + "filter": {"type": "object"}, + "sort": {"type": "string"}, + "limit": {"type": "integer"}, + }, + "required": ["document", "table"], + }, + ), + Tool( + name="sql_query", + description="Run a read-only SQL query against a document", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "query": {"type": "string"}, + }, + "required": ["document", "query"], + }, + ), + Tool( + name="add_records", + description="Add records to a table", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "table": {"type": "string"}, + "records": {"type": "array", "items": {"type": "object"}}, + }, + "required": ["document", "table", "records"], + }, + ), + Tool( + name="update_records", + description="Update existing records", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "table": {"type": "string"}, + "records": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "integer"}, + "fields": {"type": "object"}, + }, + }, + }, + }, + "required": ["document", "table", "records"], + }, + ), + Tool( + name="delete_records", + description="Delete records by ID", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "table": {"type": "string"}, + "record_ids": {"type": "array", "items": {"type": "integer"}}, + }, + "required": ["document", "table", "record_ids"], + }, + ), + Tool( + name="create_table", + description="Create a new table with columns", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "table_id": {"type": "string"}, + "columns": { + "type": "array", + "items": { + "type": "object", + "properties": { + "id": {"type": "string"}, + "type": {"type": "string"}, + }, + }, + }, + }, + "required": ["document", "table_id", "columns"], + }, + ), + Tool( + name="add_column", + description="Add a column to a table", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "table": {"type": "string"}, + "column_id": {"type": "string"}, + "column_type": {"type": "string"}, + "formula": {"type": "string"}, + }, + "required": ["document", "table", "column_id", "column_type"], + }, + ), + Tool( + name="modify_column", + description="Modify a column's type or formula", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "table": {"type": "string"}, + "column_id": {"type": "string"}, + "type": {"type": "string"}, + "formula": {"type": "string"}, + }, + "required": ["document", "table", "column_id"], + }, + ), + Tool( + name="delete_column", + description="Delete a column from a table", + inputSchema={ + "type": "object", + "properties": { + "document": {"type": "string"}, + "table": {"type": "string"}, + "column_id": {"type": "string"}, + }, + "required": ["document", "table", "column_id"], + }, + ), + ] + + @server.call_tool() + async def call_tool(name: str, arguments: dict) -> list[TextContent]: + nonlocal _current_agent + + if _current_agent is None: + return [TextContent(type="text", text="Error: Not authenticated")] + + try: + if name == "list_documents": + result = await _list_documents(_current_agent) + elif name == "list_tables": + result = await _list_tables(_current_agent, auth, arguments["document"]) + elif name == "describe_table": + result = await _describe_table( + _current_agent, auth, arguments["document"], arguments["table"] + ) + elif name == "get_records": + result = await _get_records( + _current_agent, auth, arguments["document"], arguments["table"], + filter=arguments.get("filter"), + sort=arguments.get("sort"), + limit=arguments.get("limit"), + ) + elif name == "sql_query": + result = await _sql_query( + _current_agent, auth, arguments["document"], arguments["query"] + ) + elif name == "add_records": + result = await _add_records( + _current_agent, auth, arguments["document"], arguments["table"], + arguments["records"], + ) + elif name == "update_records": + result = await _update_records( + _current_agent, auth, arguments["document"], arguments["table"], + arguments["records"], + ) + elif name == "delete_records": + result = await _delete_records( + _current_agent, auth, arguments["document"], arguments["table"], + arguments["record_ids"], + ) + elif name == "create_table": + result = await _create_table( + _current_agent, auth, arguments["document"], arguments["table_id"], + arguments["columns"], + ) + elif name == "add_column": + result = await _add_column( + _current_agent, auth, arguments["document"], arguments["table"], + arguments["column_id"], arguments["column_type"], + formula=arguments.get("formula"), + ) + elif name == "modify_column": + result = await _modify_column( + _current_agent, auth, arguments["document"], arguments["table"], + arguments["column_id"], + type=arguments.get("type"), + formula=arguments.get("formula"), + ) + elif name == "delete_column": + result = await _delete_column( + _current_agent, auth, arguments["document"], arguments["table"], + arguments["column_id"], + ) + else: + return [TextContent(type="text", text=f"Unknown tool: {name}")] + + import json + return [TextContent(type="text", text=json.dumps(result))] + + except AuthError as e: + return [TextContent(type="text", text=f"Authorization error: {e}")] + except Exception as e: + return [TextContent(type="text", text=f"Error: {e}")] + + # Store auth for external access + server._auth = auth + server._set_agent = lambda agent: setattr(server, '_current_agent', agent) or setattr(type(server), '_current_agent', agent) + + return server diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..e6e4145 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,52 @@ +import pytest +from mcp.types import ListToolsRequest +from grist_mcp.server import create_server + + +@pytest.mark.asyncio +async def test_create_server_registers_tools(tmp_path): + config_file = tmp_path / "config.yaml" + config_file.write_text(""" +documents: + test-doc: + url: https://grist.example.com + doc_id: abc123 + api_key: test-key + +tokens: + - token: test-token + name: test-agent + scope: + - document: test-doc + permissions: [read, write, schema] +""") + + server = create_server(str(config_file)) + + # Server should have tools registered + assert server is not None + + # Get the list_tools handler and call it + handler = server.request_handlers.get(ListToolsRequest) + assert handler is not None + + req = ListToolsRequest(method="tools/list") + result = await handler(req) + + # Check tool names are registered + tool_names = [t.name for t in result.root.tools] + assert "list_documents" in tool_names + assert "list_tables" in tool_names + assert "describe_table" in tool_names + assert "get_records" in tool_names + assert "sql_query" in tool_names + assert "add_records" in tool_names + assert "update_records" in tool_names + assert "delete_records" in tool_names + assert "create_table" in tool_names + assert "add_column" in tool_names + assert "modify_column" in tool_names + assert "delete_column" in tool_names + + # Should have all 12 tools + assert len(result.root.tools) == 12