fix: implement token-based authentication at server startup

- Server now authenticates from GRIST_MCP_TOKEN env var or token parameter
- Removed unused code (_set_agent, nonlocal check)
- Added AuthError handling in main.py
- Updated test to pass token explicitly
This commit is contained in:
2025-12-03 15:07:06 -05:00
parent 1ed5554944
commit f716e5d37e
3 changed files with 27 additions and 17 deletions

View File

@@ -1,7 +1,9 @@
"""MCP server setup and tool registration."""
import json
import os
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent
from grist_mcp.config import load_config
@@ -21,14 +23,27 @@ from grist_mcp.tools.schema import modify_column as _modify_column
from grist_mcp.tools.schema import delete_column as _delete_column
def create_server(config_path: str) -> Server:
"""Create and configure the MCP server."""
def create_server(config_path: str, token: str | None = None) -> Server:
"""Create and configure the MCP server.
Args:
config_path: Path to the configuration YAML file.
token: Agent token for authentication. If not provided, reads from
GRIST_MCP_TOKEN environment variable.
Raises:
AuthError: If token is invalid or not provided.
"""
config = load_config(config_path)
auth = Authenticator(config)
server = Server("grist-mcp")
# Current agent context (set during authentication)
_current_agent: Agent | None = None
# Authenticate agent from token (required for all tool calls)
auth_token = token or os.environ.get("GRIST_MCP_TOKEN")
if not auth_token:
raise AuthError("No token provided. Set GRIST_MCP_TOKEN environment variable.")
_current_agent: Agent = auth.authenticate(auth_token)
@server.list_tools()
async def list_tools() -> list[Tool]:
@@ -203,11 +218,6 @@ def create_server(config_path: str) -> Server:
@server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
nonlocal _current_agent
if _current_agent is None:
return [TextContent(type="text", text="Error: Not authenticated")]
try:
if name == "list_documents":
result = await _list_documents(_current_agent)
@@ -269,7 +279,6 @@ def create_server(config_path: str) -> Server:
else:
return [TextContent(type="text", text=f"Unknown tool: {name}")]
import json
return [TextContent(type="text", text=json.dumps(result))]
except AuthError as e:
@@ -277,8 +286,4 @@ def create_server(config_path: str) -> Server:
except Exception as e:
return [TextContent(type="text", text=f"Error: {e}")]
# Store auth for external access
server._auth = auth
server._set_agent = lambda agent: setattr(server, '_current_agent', agent) or setattr(type(server), '_current_agent', agent)
return server