From a97930848bb1f414728d75d2189ea9afb8daddbb Mon Sep 17 00:00:00 2001 From: Bill Ballou Date: Wed, 14 Jan 2026 17:56:18 -0500 Subject: [PATCH] feat: normalize filter values to array format for Grist API The Grist API requires all filter values to be arrays. This change adds automatic normalization of filter values in get_records, wrapping single values in lists before sending to the API. This fixes 400 errors when filtering on Ref columns with single integer IDs. Changes: - Add filters.py module with normalize_filter function - Update get_records to normalize filters before API call - Add Orders table with Ref column to mock Grist server - Add filter validation to mock server (rejects non-array values) - Fix shell script shebangs for portability (#!/usr/bin/env bash) --- scripts/get-test-instance-id.sh | 2 +- scripts/run-integration-tests.sh | 2 +- src/grist_mcp/tools/filters.py | 37 +++++++++ src/grist_mcp/tools/read.py | 6 +- tests/integration/mock_grist/server.py | 44 +++++++++- tests/integration/test_tools_integration.py | 30 +++++++ tests/unit/test_filters.py | 89 +++++++++++++++++++++ tests/unit/test_tools_read.py | 39 +++++++++ 8 files changed, 244 insertions(+), 5 deletions(-) create mode 100644 src/grist_mcp/tools/filters.py create mode 100644 tests/unit/test_filters.py diff --git a/scripts/get-test-instance-id.sh b/scripts/get-test-instance-id.sh index 911b8bc..0af15bd 100755 --- a/scripts/get-test-instance-id.sh +++ b/scripts/get-test-instance-id.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # scripts/get-test-instance-id.sh # Generate a unique instance ID from git branch for parallel test isolation diff --git a/scripts/run-integration-tests.sh b/scripts/run-integration-tests.sh index 82cc3b2..94d85af 100755 --- a/scripts/run-integration-tests.sh +++ b/scripts/run-integration-tests.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # scripts/run-integration-tests.sh # Run integration tests with branch isolation and dynamic port discovery set -e diff --git a/src/grist_mcp/tools/filters.py b/src/grist_mcp/tools/filters.py new file mode 100644 index 0000000..fcb69c7 --- /dev/null +++ b/src/grist_mcp/tools/filters.py @@ -0,0 +1,37 @@ +"""Filter normalization for Grist API queries.""" + +from typing import Any + + +def normalize_filter_value(value: Any) -> list: + """Ensure a filter value is a list. + + Grist API expects filter values to be arrays. + + Args: + value: Single value or list of values. + + Returns: + Value wrapped in list, or original list if already a list. + """ + if isinstance(value, list): + return value + return [value] + + +def normalize_filter(filter: dict | None) -> dict | None: + """Normalize filter values to array format for Grist API. + + Grist expects all filter values to be arrays. This function + wraps single values in lists. + + Args: + filter: Filter dict with column names as keys. + + Returns: + Normalized filter dict, or None if input was None. + """ + if not filter: + return filter + + return {key: normalize_filter_value(value) for key, value in filter.items()} diff --git a/src/grist_mcp/tools/read.py b/src/grist_mcp/tools/read.py index c330e91..ebbac17 100644 --- a/src/grist_mcp/tools/read.py +++ b/src/grist_mcp/tools/read.py @@ -2,6 +2,7 @@ from grist_mcp.auth import Agent, Authenticator, Permission from grist_mcp.grist_client import GristClient +from grist_mcp.tools.filters import normalize_filter async def list_tables( @@ -56,7 +57,10 @@ async def get_records( doc = auth.get_document(document) client = GristClient(doc) - records = await client.get_records(table, filter=filter, sort=sort, limit=limit) + # Normalize filter values to array format for Grist API + normalized_filter = normalize_filter(filter) + + records = await client.get_records(table, filter=normalized_filter, sort=sort, limit=limit) return {"records": records} diff --git a/tests/integration/mock_grist/server.py b/tests/integration/mock_grist/server.py index 75dee91..fe0443e 100644 --- a/tests/integration/mock_grist/server.py +++ b/tests/integration/mock_grist/server.py @@ -35,6 +35,18 @@ MOCK_TABLES = { {"id": 2, "fields": {"Title": "Deploy", "Done": False}}, ], }, + "Orders": { + "columns": [ + {"id": "OrderNum", "fields": {"type": "Int"}}, + {"id": "Customer", "fields": {"type": "Ref:People"}}, + {"id": "Amount", "fields": {"type": "Numeric"}}, + ], + "records": [ + {"id": 1, "fields": {"OrderNum": 1001, "Customer": 1, "Amount": 100.0}}, + {"id": 2, "fields": {"OrderNum": 1002, "Customer": 2, "Amount": 200.0}}, + {"id": 3, "fields": {"OrderNum": 1003, "Customer": 1, "Amount": 150.0}}, + ], + }, } # Track requests for test assertions @@ -93,12 +105,40 @@ async def get_records(request): """GET /api/docs/{doc_id}/tables/{table_id}/records""" doc_id = request.path_params["doc_id"] table_id = request.path_params["table_id"] - log_request("GET", f"/api/docs/{doc_id}/tables/{table_id}/records") + filter_param = request.query_params.get("filter") + log_request("GET", f"/api/docs/{doc_id}/tables/{table_id}/records?filter={filter_param}") if table_id not in MOCK_TABLES: return JSONResponse({"error": "Table not found"}, status_code=404) - return JSONResponse({"records": MOCK_TABLES[table_id]["records"]}) + records = MOCK_TABLES[table_id]["records"] + + # Apply filtering if provided + if filter_param: + try: + filters = json.loads(filter_param) + # Validate filter format: all values must be arrays (Grist API requirement) + for key, values in filters.items(): + if not isinstance(values, list): + return JSONResponse( + {"error": f"Filter values must be arrays, got {type(values).__name__} for '{key}'"}, + status_code=400 + ) + # Apply filters: record matches if field value is in the filter list + filtered_records = [] + for record in records: + match = True + for key, allowed_values in filters.items(): + if record["fields"].get(key) not in allowed_values: + match = False + break + if match: + filtered_records.append(record) + records = filtered_records + except json.JSONDecodeError: + return JSONResponse({"error": "Invalid filter JSON"}, status_code=400) + + return JSONResponse({"records": records}) async def add_records(request): diff --git a/tests/integration/test_tools_integration.py b/tests/integration/test_tools_integration.py index 022217b..0447332 100644 --- a/tests/integration/test_tools_integration.py +++ b/tests/integration/test_tools_integration.py @@ -90,6 +90,36 @@ async def test_all_tools(services_ready): log = get_mock_request_log() assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log) + # Test get_records with Ref column filter + # This tests that single values are normalized to arrays for the Grist API + clear_mock_request_log() + result = await client.call_tool( + "get_records", + {"document": "test-doc", "table": "Orders", "filter": {"Customer": 1}} + ) + data = json.loads(result.content[0].text) + assert "records" in data + # Should return only orders for Customer 1 (orders 1 and 3) + assert len(data["records"]) == 2 + for record in data["records"]: + assert record["Customer"] == 1 + log = get_mock_request_log() + # Verify the filter was sent as array format + filter_requests = [e for e in log if "/records" in e["path"] and "filter=" in e["path"]] + assert len(filter_requests) >= 1 + # The filter value should be [1] not 1 + assert "[1]" in filter_requests[0]["path"] + + # Test get_records with multiple filter values + clear_mock_request_log() + result = await client.call_tool( + "get_records", + {"document": "test-doc", "table": "Orders", "filter": {"Customer": [1, 2]}} + ) + data = json.loads(result.content[0].text) + assert "records" in data + assert len(data["records"]) == 3 # All 3 orders (customers 1 and 2) + # Test sql_query clear_mock_request_log() result = await client.call_tool( diff --git a/tests/unit/test_filters.py b/tests/unit/test_filters.py new file mode 100644 index 0000000..81ab601 --- /dev/null +++ b/tests/unit/test_filters.py @@ -0,0 +1,89 @@ +"""Unit tests for filter normalization.""" + +import pytest + +from grist_mcp.tools.filters import normalize_filter, normalize_filter_value + + +class TestNormalizeFilterValue: + """Tests for normalize_filter_value function.""" + + def test_int_becomes_list(self): + assert normalize_filter_value(5) == [5] + + def test_string_becomes_list(self): + assert normalize_filter_value("foo") == ["foo"] + + def test_float_becomes_list(self): + assert normalize_filter_value(3.14) == [3.14] + + def test_list_unchanged(self): + assert normalize_filter_value([1, 2, 3]) == [1, 2, 3] + + def test_empty_list_unchanged(self): + assert normalize_filter_value([]) == [] + + def test_single_item_list_unchanged(self): + assert normalize_filter_value([42]) == [42] + + def test_mixed_type_list_unchanged(self): + assert normalize_filter_value([1, "foo", 3.14]) == [1, "foo", 3.14] + + +class TestNormalizeFilter: + """Tests for normalize_filter function.""" + + def test_none_returns_none(self): + assert normalize_filter(None) is None + + def test_empty_dict_returns_empty_dict(self): + assert normalize_filter({}) == {} + + def test_single_int_value_wrapped(self): + result = normalize_filter({"Transaction": 44}) + assert result == {"Transaction": [44]} + + def test_single_string_value_wrapped(self): + result = normalize_filter({"Status": "active"}) + assert result == {"Status": ["active"]} + + def test_list_value_unchanged(self): + result = normalize_filter({"Transaction": [44, 45, 46]}) + assert result == {"Transaction": [44, 45, 46]} + + def test_mixed_columns_all_normalized(self): + """Both ref and non-ref columns are normalized to arrays.""" + result = normalize_filter({ + "Transaction": 44, # Ref column (int) + "Debit": 500, # Non-ref column (int) + "Memo": "test", # Non-ref column (str) + }) + assert result == { + "Transaction": [44], + "Debit": [500], + "Memo": ["test"], + } + + def test_multiple_values_list_unchanged(self): + """Filter with multiple values passes through.""" + result = normalize_filter({ + "Status": ["pending", "active"], + "Priority": [1, 2, 3], + }) + assert result == { + "Status": ["pending", "active"], + "Priority": [1, 2, 3], + } + + def test_mixed_single_and_list_values(self): + """Mix of single values and lists.""" + result = normalize_filter({ + "Transaction": 44, # Single int + "Status": ["open", "closed"], # List + "Amount": 100.50, # Single float + }) + assert result == { + "Transaction": [44], + "Status": ["open", "closed"], + "Amount": [100.50], + } diff --git a/tests/unit/test_tools_read.py b/tests/unit/test_tools_read.py index a36e576..569d26c 100644 --- a/tests/unit/test_tools_read.py +++ b/tests/unit/test_tools_read.py @@ -75,6 +75,45 @@ async def test_get_records(agent, auth, mock_client): assert result == {"records": [{"id": 1, "Name": "Alice"}]} +@pytest.mark.asyncio +async def test_get_records_normalizes_filter(agent, auth, mock_client): + """Test that filter values are normalized to array format for Grist API.""" + mock_client.get_records.return_value = [{"id": 1, "Customer": 5}] + + await get_records( + agent, auth, "budget", "Orders", + filter={"Customer": 5, "Status": "active"}, + client=mock_client, + ) + + # Verify filter was normalized: single values wrapped in lists + mock_client.get_records.assert_called_once_with( + "Orders", + filter={"Customer": [5], "Status": ["active"]}, + sort=None, + limit=None, + ) + + +@pytest.mark.asyncio +async def test_get_records_preserves_list_filter(agent, auth, mock_client): + """Test that filter values already in list format are preserved.""" + mock_client.get_records.return_value = [] + + await get_records( + agent, auth, "budget", "Orders", + filter={"Customer": [5, 6, 7]}, + client=mock_client, + ) + + mock_client.get_records.assert_called_once_with( + "Orders", + filter={"Customer": [5, 6, 7]}, + sort=None, + limit=None, + ) + + @pytest.mark.asyncio async def test_sql_query(agent, auth, mock_client): result = await sql_query(agent, auth, "budget", "SELECT * FROM Table1", client=mock_client)