fix: implement token-based authentication at server startup
- Server now authenticates from GRIST_MCP_TOKEN env var or token parameter - Removed unused code (_set_agent, nonlocal check) - Added AuthError handling in main.py - Updated test to pass token explicitly
This commit is contained in:
@@ -7,6 +7,7 @@ import sys
|
|||||||
from mcp.server.stdio import stdio_server
|
from mcp.server.stdio import stdio_server
|
||||||
|
|
||||||
from grist_mcp.server import create_server
|
from grist_mcp.server import create_server
|
||||||
|
from grist_mcp.auth import AuthError
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
@@ -16,7 +17,11 @@ async def main():
|
|||||||
print(f"Error: Config file not found at {config_path}", file=sys.stderr)
|
print(f"Error: Config file not found at {config_path}", file=sys.stderr)
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
try:
|
||||||
server = create_server(config_path)
|
server = create_server(config_path)
|
||||||
|
except AuthError as e:
|
||||||
|
print(f"Authentication error: {e}", file=sys.stderr)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
async with stdio_server() as (read_stream, write_stream):
|
async with stdio_server() as (read_stream, write_stream):
|
||||||
await server.run(read_stream, write_stream, server.create_initialization_options())
|
await server.run(read_stream, write_stream, server.create_initialization_options())
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
"""MCP server setup and tool registration."""
|
"""MCP server setup and tool registration."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.server.stdio import stdio_server
|
|
||||||
from mcp.types import Tool, TextContent
|
from mcp.types import Tool, TextContent
|
||||||
|
|
||||||
from grist_mcp.config import load_config
|
from grist_mcp.config import load_config
|
||||||
@@ -21,14 +23,27 @@ from grist_mcp.tools.schema import modify_column as _modify_column
|
|||||||
from grist_mcp.tools.schema import delete_column as _delete_column
|
from grist_mcp.tools.schema import delete_column as _delete_column
|
||||||
|
|
||||||
|
|
||||||
def create_server(config_path: str) -> Server:
|
def create_server(config_path: str, token: str | None = None) -> Server:
|
||||||
"""Create and configure the MCP server."""
|
"""Create and configure the MCP server.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to the configuration YAML file.
|
||||||
|
token: Agent token for authentication. If not provided, reads from
|
||||||
|
GRIST_MCP_TOKEN environment variable.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthError: If token is invalid or not provided.
|
||||||
|
"""
|
||||||
config = load_config(config_path)
|
config = load_config(config_path)
|
||||||
auth = Authenticator(config)
|
auth = Authenticator(config)
|
||||||
server = Server("grist-mcp")
|
server = Server("grist-mcp")
|
||||||
|
|
||||||
# Current agent context (set during authentication)
|
# Authenticate agent from token (required for all tool calls)
|
||||||
_current_agent: Agent | None = None
|
auth_token = token or os.environ.get("GRIST_MCP_TOKEN")
|
||||||
|
if not auth_token:
|
||||||
|
raise AuthError("No token provided. Set GRIST_MCP_TOKEN environment variable.")
|
||||||
|
|
||||||
|
_current_agent: Agent = auth.authenticate(auth_token)
|
||||||
|
|
||||||
@server.list_tools()
|
@server.list_tools()
|
||||||
async def list_tools() -> list[Tool]:
|
async def list_tools() -> list[Tool]:
|
||||||
@@ -203,11 +218,6 @@ def create_server(config_path: str) -> Server:
|
|||||||
|
|
||||||
@server.call_tool()
|
@server.call_tool()
|
||||||
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
||||||
nonlocal _current_agent
|
|
||||||
|
|
||||||
if _current_agent is None:
|
|
||||||
return [TextContent(type="text", text="Error: Not authenticated")]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if name == "list_documents":
|
if name == "list_documents":
|
||||||
result = await _list_documents(_current_agent)
|
result = await _list_documents(_current_agent)
|
||||||
@@ -269,7 +279,6 @@ def create_server(config_path: str) -> Server:
|
|||||||
else:
|
else:
|
||||||
return [TextContent(type="text", text=f"Unknown tool: {name}")]
|
return [TextContent(type="text", text=f"Unknown tool: {name}")]
|
||||||
|
|
||||||
import json
|
|
||||||
return [TextContent(type="text", text=json.dumps(result))]
|
return [TextContent(type="text", text=json.dumps(result))]
|
||||||
|
|
||||||
except AuthError as e:
|
except AuthError as e:
|
||||||
@@ -277,8 +286,4 @@ def create_server(config_path: str) -> Server:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return [TextContent(type="text", text=f"Error: {e}")]
|
return [TextContent(type="text", text=f"Error: {e}")]
|
||||||
|
|
||||||
# Store auth for external access
|
|
||||||
server._auth = auth
|
|
||||||
server._set_agent = lambda agent: setattr(server, '_current_agent', agent) or setattr(type(server), '_current_agent', agent)
|
|
||||||
|
|
||||||
return server
|
return server
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ tokens:
|
|||||||
permissions: [read, write, schema]
|
permissions: [read, write, schema]
|
||||||
""")
|
""")
|
||||||
|
|
||||||
server = create_server(str(config_file))
|
server = create_server(str(config_file), token="test-token")
|
||||||
|
|
||||||
# Server should have tools registered
|
# Server should have tools registered
|
||||||
assert server is not None
|
assert server is not None
|
||||||
|
|||||||
Reference in New Issue
Block a user