feat: normalize filter values to array format for Grist API
The Grist API requires all filter values to be arrays. This change adds automatic normalization of filter values in get_records, wrapping single values in lists before sending to the API. This fixes 400 errors when filtering on Ref columns with single integer IDs. Changes: - Add filters.py module with normalize_filter function - Update get_records to normalize filters before API call - Add Orders table with Ref column to mock Grist server - Add filter validation to mock server (rejects non-array values) - Fix shell script shebangs for portability (#!/usr/bin/env bash)
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/usr/bin/env bash
|
||||||
# scripts/get-test-instance-id.sh
|
# scripts/get-test-instance-id.sh
|
||||||
# Generate a unique instance ID from git branch for parallel test isolation
|
# Generate a unique instance ID from git branch for parallel test isolation
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
#!/bin/bash
|
#!/usr/bin/env bash
|
||||||
# scripts/run-integration-tests.sh
|
# scripts/run-integration-tests.sh
|
||||||
# Run integration tests with branch isolation and dynamic port discovery
|
# Run integration tests with branch isolation and dynamic port discovery
|
||||||
set -e
|
set -e
|
||||||
|
|||||||
37
src/grist_mcp/tools/filters.py
Normal file
37
src/grist_mcp/tools/filters.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Filter normalization for Grist API queries."""
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_filter_value(value: Any) -> list:
|
||||||
|
"""Ensure a filter value is a list.
|
||||||
|
|
||||||
|
Grist API expects filter values to be arrays.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Single value or list of values.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Value wrapped in list, or original list if already a list.
|
||||||
|
"""
|
||||||
|
if isinstance(value, list):
|
||||||
|
return value
|
||||||
|
return [value]
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_filter(filter: dict | None) -> dict | None:
|
||||||
|
"""Normalize filter values to array format for Grist API.
|
||||||
|
|
||||||
|
Grist expects all filter values to be arrays. This function
|
||||||
|
wraps single values in lists.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
filter: Filter dict with column names as keys.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Normalized filter dict, or None if input was None.
|
||||||
|
"""
|
||||||
|
if not filter:
|
||||||
|
return filter
|
||||||
|
|
||||||
|
return {key: normalize_filter_value(value) for key, value in filter.items()}
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from grist_mcp.auth import Agent, Authenticator, Permission
|
from grist_mcp.auth import Agent, Authenticator, Permission
|
||||||
from grist_mcp.grist_client import GristClient
|
from grist_mcp.grist_client import GristClient
|
||||||
|
from grist_mcp.tools.filters import normalize_filter
|
||||||
|
|
||||||
|
|
||||||
async def list_tables(
|
async def list_tables(
|
||||||
@@ -56,7 +57,10 @@ async def get_records(
|
|||||||
doc = auth.get_document(document)
|
doc = auth.get_document(document)
|
||||||
client = GristClient(doc)
|
client = GristClient(doc)
|
||||||
|
|
||||||
records = await client.get_records(table, filter=filter, sort=sort, limit=limit)
|
# Normalize filter values to array format for Grist API
|
||||||
|
normalized_filter = normalize_filter(filter)
|
||||||
|
|
||||||
|
records = await client.get_records(table, filter=normalized_filter, sort=sort, limit=limit)
|
||||||
return {"records": records}
|
return {"records": records}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,18 @@ MOCK_TABLES = {
|
|||||||
{"id": 2, "fields": {"Title": "Deploy", "Done": False}},
|
{"id": 2, "fields": {"Title": "Deploy", "Done": False}},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
|
"Orders": {
|
||||||
|
"columns": [
|
||||||
|
{"id": "OrderNum", "fields": {"type": "Int"}},
|
||||||
|
{"id": "Customer", "fields": {"type": "Ref:People"}},
|
||||||
|
{"id": "Amount", "fields": {"type": "Numeric"}},
|
||||||
|
],
|
||||||
|
"records": [
|
||||||
|
{"id": 1, "fields": {"OrderNum": 1001, "Customer": 1, "Amount": 100.0}},
|
||||||
|
{"id": 2, "fields": {"OrderNum": 1002, "Customer": 2, "Amount": 200.0}},
|
||||||
|
{"id": 3, "fields": {"OrderNum": 1003, "Customer": 1, "Amount": 150.0}},
|
||||||
|
],
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Track requests for test assertions
|
# Track requests for test assertions
|
||||||
@@ -93,12 +105,40 @@ async def get_records(request):
|
|||||||
"""GET /api/docs/{doc_id}/tables/{table_id}/records"""
|
"""GET /api/docs/{doc_id}/tables/{table_id}/records"""
|
||||||
doc_id = request.path_params["doc_id"]
|
doc_id = request.path_params["doc_id"]
|
||||||
table_id = request.path_params["table_id"]
|
table_id = request.path_params["table_id"]
|
||||||
log_request("GET", f"/api/docs/{doc_id}/tables/{table_id}/records")
|
filter_param = request.query_params.get("filter")
|
||||||
|
log_request("GET", f"/api/docs/{doc_id}/tables/{table_id}/records?filter={filter_param}")
|
||||||
|
|
||||||
if table_id not in MOCK_TABLES:
|
if table_id not in MOCK_TABLES:
|
||||||
return JSONResponse({"error": "Table not found"}, status_code=404)
|
return JSONResponse({"error": "Table not found"}, status_code=404)
|
||||||
|
|
||||||
return JSONResponse({"records": MOCK_TABLES[table_id]["records"]})
|
records = MOCK_TABLES[table_id]["records"]
|
||||||
|
|
||||||
|
# Apply filtering if provided
|
||||||
|
if filter_param:
|
||||||
|
try:
|
||||||
|
filters = json.loads(filter_param)
|
||||||
|
# Validate filter format: all values must be arrays (Grist API requirement)
|
||||||
|
for key, values in filters.items():
|
||||||
|
if not isinstance(values, list):
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": f"Filter values must be arrays, got {type(values).__name__} for '{key}'"},
|
||||||
|
status_code=400
|
||||||
|
)
|
||||||
|
# Apply filters: record matches if field value is in the filter list
|
||||||
|
filtered_records = []
|
||||||
|
for record in records:
|
||||||
|
match = True
|
||||||
|
for key, allowed_values in filters.items():
|
||||||
|
if record["fields"].get(key) not in allowed_values:
|
||||||
|
match = False
|
||||||
|
break
|
||||||
|
if match:
|
||||||
|
filtered_records.append(record)
|
||||||
|
records = filtered_records
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return JSONResponse({"error": "Invalid filter JSON"}, status_code=400)
|
||||||
|
|
||||||
|
return JSONResponse({"records": records})
|
||||||
|
|
||||||
|
|
||||||
async def add_records(request):
|
async def add_records(request):
|
||||||
|
|||||||
@@ -90,6 +90,36 @@ async def test_all_tools(services_ready):
|
|||||||
log = get_mock_request_log()
|
log = get_mock_request_log()
|
||||||
assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log)
|
assert any("/records" in entry["path"] and entry["method"] == "GET" for entry in log)
|
||||||
|
|
||||||
|
# Test get_records with Ref column filter
|
||||||
|
# This tests that single values are normalized to arrays for the Grist API
|
||||||
|
clear_mock_request_log()
|
||||||
|
result = await client.call_tool(
|
||||||
|
"get_records",
|
||||||
|
{"document": "test-doc", "table": "Orders", "filter": {"Customer": 1}}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "records" in data
|
||||||
|
# Should return only orders for Customer 1 (orders 1 and 3)
|
||||||
|
assert len(data["records"]) == 2
|
||||||
|
for record in data["records"]:
|
||||||
|
assert record["Customer"] == 1
|
||||||
|
log = get_mock_request_log()
|
||||||
|
# Verify the filter was sent as array format
|
||||||
|
filter_requests = [e for e in log if "/records" in e["path"] and "filter=" in e["path"]]
|
||||||
|
assert len(filter_requests) >= 1
|
||||||
|
# The filter value should be [1] not 1
|
||||||
|
assert "[1]" in filter_requests[0]["path"]
|
||||||
|
|
||||||
|
# Test get_records with multiple filter values
|
||||||
|
clear_mock_request_log()
|
||||||
|
result = await client.call_tool(
|
||||||
|
"get_records",
|
||||||
|
{"document": "test-doc", "table": "Orders", "filter": {"Customer": [1, 2]}}
|
||||||
|
)
|
||||||
|
data = json.loads(result.content[0].text)
|
||||||
|
assert "records" in data
|
||||||
|
assert len(data["records"]) == 3 # All 3 orders (customers 1 and 2)
|
||||||
|
|
||||||
# Test sql_query
|
# Test sql_query
|
||||||
clear_mock_request_log()
|
clear_mock_request_log()
|
||||||
result = await client.call_tool(
|
result = await client.call_tool(
|
||||||
|
|||||||
89
tests/unit/test_filters.py
Normal file
89
tests/unit/test_filters.py
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
"""Unit tests for filter normalization."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from grist_mcp.tools.filters import normalize_filter, normalize_filter_value
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeFilterValue:
|
||||||
|
"""Tests for normalize_filter_value function."""
|
||||||
|
|
||||||
|
def test_int_becomes_list(self):
|
||||||
|
assert normalize_filter_value(5) == [5]
|
||||||
|
|
||||||
|
def test_string_becomes_list(self):
|
||||||
|
assert normalize_filter_value("foo") == ["foo"]
|
||||||
|
|
||||||
|
def test_float_becomes_list(self):
|
||||||
|
assert normalize_filter_value(3.14) == [3.14]
|
||||||
|
|
||||||
|
def test_list_unchanged(self):
|
||||||
|
assert normalize_filter_value([1, 2, 3]) == [1, 2, 3]
|
||||||
|
|
||||||
|
def test_empty_list_unchanged(self):
|
||||||
|
assert normalize_filter_value([]) == []
|
||||||
|
|
||||||
|
def test_single_item_list_unchanged(self):
|
||||||
|
assert normalize_filter_value([42]) == [42]
|
||||||
|
|
||||||
|
def test_mixed_type_list_unchanged(self):
|
||||||
|
assert normalize_filter_value([1, "foo", 3.14]) == [1, "foo", 3.14]
|
||||||
|
|
||||||
|
|
||||||
|
class TestNormalizeFilter:
|
||||||
|
"""Tests for normalize_filter function."""
|
||||||
|
|
||||||
|
def test_none_returns_none(self):
|
||||||
|
assert normalize_filter(None) is None
|
||||||
|
|
||||||
|
def test_empty_dict_returns_empty_dict(self):
|
||||||
|
assert normalize_filter({}) == {}
|
||||||
|
|
||||||
|
def test_single_int_value_wrapped(self):
|
||||||
|
result = normalize_filter({"Transaction": 44})
|
||||||
|
assert result == {"Transaction": [44]}
|
||||||
|
|
||||||
|
def test_single_string_value_wrapped(self):
|
||||||
|
result = normalize_filter({"Status": "active"})
|
||||||
|
assert result == {"Status": ["active"]}
|
||||||
|
|
||||||
|
def test_list_value_unchanged(self):
|
||||||
|
result = normalize_filter({"Transaction": [44, 45, 46]})
|
||||||
|
assert result == {"Transaction": [44, 45, 46]}
|
||||||
|
|
||||||
|
def test_mixed_columns_all_normalized(self):
|
||||||
|
"""Both ref and non-ref columns are normalized to arrays."""
|
||||||
|
result = normalize_filter({
|
||||||
|
"Transaction": 44, # Ref column (int)
|
||||||
|
"Debit": 500, # Non-ref column (int)
|
||||||
|
"Memo": "test", # Non-ref column (str)
|
||||||
|
})
|
||||||
|
assert result == {
|
||||||
|
"Transaction": [44],
|
||||||
|
"Debit": [500],
|
||||||
|
"Memo": ["test"],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_multiple_values_list_unchanged(self):
|
||||||
|
"""Filter with multiple values passes through."""
|
||||||
|
result = normalize_filter({
|
||||||
|
"Status": ["pending", "active"],
|
||||||
|
"Priority": [1, 2, 3],
|
||||||
|
})
|
||||||
|
assert result == {
|
||||||
|
"Status": ["pending", "active"],
|
||||||
|
"Priority": [1, 2, 3],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_mixed_single_and_list_values(self):
|
||||||
|
"""Mix of single values and lists."""
|
||||||
|
result = normalize_filter({
|
||||||
|
"Transaction": 44, # Single int
|
||||||
|
"Status": ["open", "closed"], # List
|
||||||
|
"Amount": 100.50, # Single float
|
||||||
|
})
|
||||||
|
assert result == {
|
||||||
|
"Transaction": [44],
|
||||||
|
"Status": ["open", "closed"],
|
||||||
|
"Amount": [100.50],
|
||||||
|
}
|
||||||
@@ -75,6 +75,45 @@ async def test_get_records(agent, auth, mock_client):
|
|||||||
assert result == {"records": [{"id": 1, "Name": "Alice"}]}
|
assert result == {"records": [{"id": 1, "Name": "Alice"}]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_records_normalizes_filter(agent, auth, mock_client):
|
||||||
|
"""Test that filter values are normalized to array format for Grist API."""
|
||||||
|
mock_client.get_records.return_value = [{"id": 1, "Customer": 5}]
|
||||||
|
|
||||||
|
await get_records(
|
||||||
|
agent, auth, "budget", "Orders",
|
||||||
|
filter={"Customer": 5, "Status": "active"},
|
||||||
|
client=mock_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify filter was normalized: single values wrapped in lists
|
||||||
|
mock_client.get_records.assert_called_once_with(
|
||||||
|
"Orders",
|
||||||
|
filter={"Customer": [5], "Status": ["active"]},
|
||||||
|
sort=None,
|
||||||
|
limit=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_records_preserves_list_filter(agent, auth, mock_client):
|
||||||
|
"""Test that filter values already in list format are preserved."""
|
||||||
|
mock_client.get_records.return_value = []
|
||||||
|
|
||||||
|
await get_records(
|
||||||
|
agent, auth, "budget", "Orders",
|
||||||
|
filter={"Customer": [5, 6, 7]},
|
||||||
|
client=mock_client,
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_client.get_records.assert_called_once_with(
|
||||||
|
"Orders",
|
||||||
|
filter={"Customer": [5, 6, 7]},
|
||||||
|
sort=None,
|
||||||
|
limit=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sql_query(agent, auth, mock_client):
|
async def test_sql_query(agent, auth, mock_client):
|
||||||
result = await sql_query(agent, auth, "budget", "SELECT * FROM Table1", client=mock_client)
|
result = await sql_query(agent, auth, "budget", "SELECT * FROM Table1", client=mock_client)
|
||||||
|
|||||||
Reference in New Issue
Block a user