diff --git a/src/grist_mcp/main.py b/src/grist_mcp/main.py index 0357fff..7798088 100644 --- a/src/grist_mcp/main.py +++ b/src/grist_mcp/main.py @@ -11,6 +11,8 @@ from mcp.server.sse import SseServerTransport from grist_mcp.server import create_server from grist_mcp.config import Config, load_config from grist_mcp.auth import Authenticator, AuthError +from grist_mcp.session import SessionTokenManager +from grist_mcp.proxy import parse_proxy_request, dispatch_proxy_request, ProxyError Scope = dict[str, Any] @@ -41,6 +43,20 @@ async def send_error(send: Send, status: int, message: str) -> None: }) +async def send_json_response(send: Send, status: int, data: dict) -> None: + """Send a JSON response.""" + body = json.dumps(data).encode() + await send({ + "type": "http.response.start", + "status": status, + "headers": [[b"content-type", b"application/json"]], + }) + await send({ + "type": "http.response.body", + "body": body, + }) + + CONFIG_TEMPLATE = """\ # grist-mcp configuration # @@ -108,6 +124,7 @@ def _ensure_config(config_path: str) -> bool: def create_app(config: Config): """Create the ASGI application.""" auth = Authenticator(config) + token_manager = SessionTokenManager() sse = SseServerTransport("/messages") @@ -125,7 +142,7 @@ def create_app(config: Config): return # Create a server instance for this authenticated connection - server = create_server(auth, agent) + server = create_server(auth, agent, token_manager) async with sse.connect_sse(scope, receive, send) as streams: await server.run( @@ -157,6 +174,58 @@ def create_app(config: Config): "body": b'{"error":"Not found"}', }) + async def handle_proxy(scope: Scope, receive: Receive, send: Send) -> None: + # Extract token + token = _get_bearer_token(scope) + if not token: + await send_json_response(send, 401, { + "success": False, + "error": "Missing Authorization header", + "code": "INVALID_TOKEN", + }) + return + + # Validate session token + session = token_manager.validate_token(token) + if session is None: + await send_json_response(send, 401, { + "success": False, + "error": "Invalid or expired token", + "code": "TOKEN_EXPIRED", + }) + return + + # Read request body + body = b"" + while True: + message = await receive() + body += message.get("body", b"") + if not message.get("more_body", False): + break + + try: + request_data = json.loads(body) + except json.JSONDecodeError: + await send_json_response(send, 400, { + "success": False, + "error": "Invalid JSON", + "code": "INVALID_REQUEST", + }) + return + + # Parse and dispatch + try: + request = parse_proxy_request(request_data) + result = await dispatch_proxy_request(request, session, auth) + await send_json_response(send, 200, result) + except ProxyError as e: + status = 403 if e.code == "UNAUTHORIZED" else 400 + await send_json_response(send, status, { + "success": False, + "error": e.message, + "code": e.code, + }) + async def app(scope: Scope, receive: Receive, send: Send) -> None: if scope["type"] != "http": return @@ -170,6 +239,8 @@ def create_app(config: Config): await handle_sse(scope, receive, send) elif path == "/messages" and method == "POST": await handle_messages(scope, receive, send) + elif path == "/api/v1/proxy" and method == "POST": + await handle_proxy(scope, receive, send) else: await handle_not_found(scope, receive, send)