refactor: per-connection auth via Authorization header
Replace startup token authentication with per-SSE-connection auth. Each client now passes Bearer token in Authorization header when connecting. Server validates against config.yaml tokens and creates isolated Server instance per connection. - server.py: accept (auth, agent) instead of (config_path, token) - main.py: extract Bearer token, authenticate, create server per connection - Remove GRIST_MCP_TOKEN from docker-compose environments
This commit is contained in:
@@ -1,3 +1 @@
|
|||||||
PORT=3000
|
PORT=3010
|
||||||
GRIST_MCP_TOKEN=your-token-here
|
|
||||||
CONFIG_PATH=/app/config.yaml
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ services:
|
|||||||
- ../../src:/app/src:ro
|
- ../../src:/app/src:ro
|
||||||
- ../../config.yaml:/app/config.yaml:ro
|
- ../../config.yaml:/app/config.yaml:ro
|
||||||
environment:
|
environment:
|
||||||
- GRIST_MCP_TOKEN=${GRIST_MCP_TOKEN}
|
|
||||||
- CONFIG_PATH=/app/config.yaml
|
- CONFIG_PATH=/app/config.yaml
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:3000/health')"]
|
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:3000/health')"]
|
||||||
|
|||||||
@@ -1,3 +1 @@
|
|||||||
PORT=3000
|
PORT=3000
|
||||||
GRIST_MCP_TOKEN=your-production-token
|
|
||||||
CONFIG_PATH=/app/config.yaml
|
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- ./config.yaml:/app/config.yaml:ro
|
- ./config.yaml:/app/config.yaml:ro
|
||||||
environment:
|
environment:
|
||||||
- GRIST_MCP_TOKEN=${GRIST_MCP_TOKEN}
|
|
||||||
- CONFIG_PATH=/app/config.yaml
|
- CONFIG_PATH=/app/config.yaml
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
deploy:
|
deploy:
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ services:
|
|||||||
- "3000" # Dynamic port
|
- "3000" # Dynamic port
|
||||||
environment:
|
environment:
|
||||||
- CONFIG_PATH=/app/config.yaml
|
- CONFIG_PATH=/app/config.yaml
|
||||||
- GRIST_MCP_TOKEN=test-token
|
|
||||||
- PORT=3000
|
|
||||||
volumes:
|
volumes:
|
||||||
- ../../tests/integration/config.test.yaml:/app/config.yaml:ro
|
- ../../tests/integration/config.test.yaml:/app/config.yaml:ro
|
||||||
depends_on:
|
depends_on:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Main entry point for the MCP server with SSE transport."""
|
"""Main entry point for the MCP server with SSE transport."""
|
||||||
|
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -8,7 +9,8 @@ import uvicorn
|
|||||||
from mcp.server.sse import SseServerTransport
|
from mcp.server.sse import SseServerTransport
|
||||||
|
|
||||||
from grist_mcp.server import create_server
|
from grist_mcp.server import create_server
|
||||||
from grist_mcp.auth import AuthError
|
from grist_mcp.config import load_config
|
||||||
|
from grist_mcp.auth import Authenticator, AuthError
|
||||||
|
|
||||||
|
|
||||||
Scope = dict[str, Any]
|
Scope = dict[str, Any]
|
||||||
@@ -16,6 +18,29 @@ Receive = Any
|
|||||||
Send = Any
|
Send = Any
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bearer_token(scope: Scope) -> str | None:
|
||||||
|
"""Extract Bearer token from Authorization header."""
|
||||||
|
headers = dict(scope.get("headers", []))
|
||||||
|
auth_header = headers.get(b"authorization", b"").decode()
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
return auth_header[7:]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def send_error(send: Send, status: int, message: str) -> None:
|
||||||
|
"""Send an HTTP error response."""
|
||||||
|
body = json.dumps({"error": message}).encode()
|
||||||
|
await send({
|
||||||
|
"type": "http.response.start",
|
||||||
|
"status": status,
|
||||||
|
"headers": [[b"content-type", b"application/json"]],
|
||||||
|
})
|
||||||
|
await send({
|
||||||
|
"type": "http.response.body",
|
||||||
|
"body": body,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
"""Create the ASGI application."""
|
"""Create the ASGI application."""
|
||||||
config_path = os.environ.get("CONFIG_PATH", "/app/config.yaml")
|
config_path = os.environ.get("CONFIG_PATH", "/app/config.yaml")
|
||||||
@@ -24,15 +49,27 @@ def create_app():
|
|||||||
print(f"Error: Config file not found at {config_path}", file=sys.stderr)
|
print(f"Error: Config file not found at {config_path}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
try:
|
config = load_config(config_path)
|
||||||
server = create_server(config_path)
|
auth = Authenticator(config)
|
||||||
except AuthError as e:
|
|
||||||
print(f"Authentication error: {e}", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
sse = SseServerTransport("/messages")
|
sse = SseServerTransport("/messages")
|
||||||
|
|
||||||
async def handle_sse(scope: Scope, receive: Receive, send: Send) -> None:
|
async def handle_sse(scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
|
# Extract and validate token from Authorization header
|
||||||
|
token = _get_bearer_token(scope)
|
||||||
|
if not token:
|
||||||
|
await send_error(send, 401, "Missing Authorization header")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = auth.authenticate(token)
|
||||||
|
except AuthError as e:
|
||||||
|
await send_error(send, 401, str(e))
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create a server instance for this authenticated connection
|
||||||
|
server = create_server(auth, agent)
|
||||||
|
|
||||||
async with sse.connect_sse(scope, receive, send) as streams:
|
async with sse.connect_sse(scope, receive, send) as streams:
|
||||||
await server.run(
|
await server.run(
|
||||||
streams[0], streams[1], server.create_initialization_options()
|
streams[0], streams[1], server.create_initialization_options()
|
||||||
|
|||||||
@@ -1,13 +1,11 @@
|
|||||||
"""MCP server setup and tool registration."""
|
"""MCP server setup and tool registration."""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
|
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.types import Tool, TextContent
|
from mcp.types import Tool, TextContent
|
||||||
|
|
||||||
from grist_mcp.config import load_config
|
from grist_mcp.auth import Authenticator, Agent
|
||||||
from grist_mcp.auth import Authenticator, AuthError, Agent
|
|
||||||
|
|
||||||
from grist_mcp.tools.discovery import list_documents as _list_documents
|
from grist_mcp.tools.discovery import list_documents as _list_documents
|
||||||
from grist_mcp.tools.read import list_tables as _list_tables
|
from grist_mcp.tools.read import list_tables as _list_tables
|
||||||
@@ -23,27 +21,18 @@ from grist_mcp.tools.schema import modify_column as _modify_column
|
|||||||
from grist_mcp.tools.schema import delete_column as _delete_column
|
from grist_mcp.tools.schema import delete_column as _delete_column
|
||||||
|
|
||||||
|
|
||||||
def create_server(config_path: str, token: str | None = None) -> Server:
|
def create_server(auth: Authenticator, agent: Agent) -> Server:
|
||||||
"""Create and configure the MCP server.
|
"""Create and configure the MCP server for an authenticated agent.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_path: Path to the configuration YAML file.
|
auth: Authenticator instance for permission checks.
|
||||||
token: Agent token for authentication. If not provided, reads from
|
agent: The authenticated agent for this server instance.
|
||||||
GRIST_MCP_TOKEN environment variable.
|
|
||||||
|
|
||||||
Raises:
|
Returns:
|
||||||
AuthError: If token is invalid or not provided.
|
Configured MCP Server instance.
|
||||||
"""
|
"""
|
||||||
config = load_config(config_path)
|
|
||||||
auth = Authenticator(config)
|
|
||||||
server = Server("grist-mcp")
|
server = Server("grist-mcp")
|
||||||
|
_current_agent = agent
|
||||||
# Authenticate agent from token (required for all tool calls)
|
|
||||||
auth_token = token or os.environ.get("GRIST_MCP_TOKEN")
|
|
||||||
if not auth_token:
|
|
||||||
raise AuthError("No token provided. Set GRIST_MCP_TOKEN environment variable.")
|
|
||||||
|
|
||||||
_current_agent: Agent = auth.authenticate(auth_token)
|
|
||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
async def list_tools() -> list[Tool]:
|
async def list_tools() -> list[Tool]:
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
import pytest
|
import pytest
|
||||||
from mcp.types import ListToolsRequest
|
from mcp.types import ListToolsRequest
|
||||||
from grist_mcp.server import create_server
|
from grist_mcp.server import create_server
|
||||||
|
from grist_mcp.config import load_config
|
||||||
|
from grist_mcp.auth import Authenticator
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -21,7 +23,10 @@ tokens:
|
|||||||
permissions: [read, write, schema]
|
permissions: [read, write, schema]
|
||||||
""")
|
""")
|
||||||
|
|
||||||
server = create_server(str(config_file), token="test-token")
|
config = load_config(str(config_file))
|
||||||
|
auth = Authenticator(config)
|
||||||
|
agent = auth.authenticate("test-token")
|
||||||
|
server = create_server(auth, agent)
|
||||||
|
|
||||||
# Server should have tools registered
|
# Server should have tools registered
|
||||||
assert server is not None
|
assert server is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user