refactor: per-connection auth via Authorization header

Replace startup token authentication with per-SSE-connection auth.
Each client now passes Bearer token in Authorization header when
connecting. Server validates against config.yaml tokens and creates
isolated Server instance per connection.

- server.py: accept (auth, agent) instead of (config_path, token)
- main.py: extract Bearer token, authenticate, create server per connection
- Remove GRIST_MCP_TOKEN from docker-compose environments
This commit is contained in:
2026-01-01 08:49:58 -05:00
parent a2e8d76237
commit 8809095549
8 changed files with 58 additions and 35 deletions

View File

@@ -1,5 +1,6 @@
"""Main entry point for the MCP server with SSE transport."""
import json
import os
import sys
from typing import Any
@@ -8,7 +9,8 @@ import uvicorn
from mcp.server.sse import SseServerTransport
from grist_mcp.server import create_server
from grist_mcp.auth import AuthError
from grist_mcp.config import load_config
from grist_mcp.auth import Authenticator, AuthError
Scope = dict[str, Any]
@@ -16,6 +18,29 @@ Receive = Any
Send = Any
def _get_bearer_token(scope: Scope) -> str | None:
"""Extract Bearer token from Authorization header."""
headers = dict(scope.get("headers", []))
auth_header = headers.get(b"authorization", b"").decode()
if auth_header.startswith("Bearer "):
return auth_header[7:]
return None
async def send_error(send: Send, status: int, message: str) -> None:
"""Send an HTTP error response."""
body = json.dumps({"error": message}).encode()
await send({
"type": "http.response.start",
"status": status,
"headers": [[b"content-type", b"application/json"]],
})
await send({
"type": "http.response.body",
"body": body,
})
def create_app():
"""Create the ASGI application."""
config_path = os.environ.get("CONFIG_PATH", "/app/config.yaml")
@@ -24,15 +49,27 @@ def create_app():
print(f"Error: Config file not found at {config_path}", file=sys.stderr)
sys.exit(1)
try:
server = create_server(config_path)
except AuthError as e:
print(f"Authentication error: {e}", file=sys.stderr)
sys.exit(1)
config = load_config(config_path)
auth = Authenticator(config)
sse = SseServerTransport("/messages")
async def handle_sse(scope: Scope, receive: Receive, send: Send) -> None:
# Extract and validate token from Authorization header
token = _get_bearer_token(scope)
if not token:
await send_error(send, 401, "Missing Authorization header")
return
try:
agent = auth.authenticate(token)
except AuthError as e:
await send_error(send, 401, str(e))
return
# Create a server instance for this authenticated connection
server = create_server(auth, agent)
async with sse.connect_sse(scope, receive, send) as streams:
await server.run(
streams[0], streams[1], server.create_initialization_options()

View File

@@ -1,13 +1,11 @@
"""MCP server setup and tool registration."""
import json
import os
from mcp.server import Server
from mcp.types import Tool, TextContent
from grist_mcp.config import load_config
from grist_mcp.auth import Authenticator, AuthError, Agent
from grist_mcp.auth import Authenticator, Agent
from grist_mcp.tools.discovery import list_documents as _list_documents
from grist_mcp.tools.read import list_tables as _list_tables
@@ -23,27 +21,18 @@ from grist_mcp.tools.schema import modify_column as _modify_column
from grist_mcp.tools.schema import delete_column as _delete_column
def create_server(config_path: str, token: str | None = None) -> Server:
"""Create and configure the MCP server.
def create_server(auth: Authenticator, agent: Agent) -> Server:
"""Create and configure the MCP server for an authenticated agent.
Args:
config_path: Path to the configuration YAML file.
token: Agent token for authentication. If not provided, reads from
GRIST_MCP_TOKEN environment variable.
auth: Authenticator instance for permission checks.
agent: The authenticated agent for this server instance.
Raises:
AuthError: If token is invalid or not provided.
Returns:
Configured MCP Server instance.
"""
config = load_config(config_path)
auth = Authenticator(config)
server = Server("grist-mcp")
# Authenticate agent from token (required for all tool calls)
auth_token = token or os.environ.get("GRIST_MCP_TOKEN")
if not auth_token:
raise AuthError("No token provided. Set GRIST_MCP_TOKEN environment variable.")
_current_agent: Agent = auth.authenticate(auth_token)
_current_agent = agent
@server.list_tools()
async def list_tools() -> list[Tool]: