feat: add direct database access for scripts (v0.2.0)
Implement persistent SQLite database feature that allows scripts to query schedule data directly via SQL after loading XER files through MCP. Key changes: - Extend load_xer with db_path parameter for persistent database - Add get_database_info tool to retrieve database connection details - Add schema introspection with tables, columns, primary/foreign keys - Support WAL mode for concurrent read access - Use atomic write pattern to prevent corruption New features: - db_path=None: in-memory database (default, backward compatible) - db_path="": auto-generate path from XER filename (.sqlite extension) - db_path="/path/to/db": explicit persistent database path Response includes complete DatabaseInfo: - db_path: absolute path (or :memory:) - is_persistent: boolean - source_file: loaded XER path - loaded_at: ISO timestamp - schema: tables with columns, primary keys, foreign keys, row counts Closes: User Story 1, 2, 3 from 002-direct-db-access spec
This commit is contained in:
234
tests/unit/test_db_manager.py
Normal file
234
tests/unit/test_db_manager.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Unit tests for DatabaseManager file-based database support."""
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from xer_mcp.db import DatabaseManager
|
||||
|
||||
|
||||
class TestDatabaseManagerInitialization:
|
||||
"""Tests for DatabaseManager initialization modes."""
|
||||
|
||||
def test_initialize_with_memory_by_default(self) -> None:
|
||||
"""Default initialization uses in-memory database."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
assert dm.db_path == ":memory:"
|
||||
assert dm.is_persistent is False
|
||||
dm.close()
|
||||
|
||||
def test_initialize_with_file_path(self, tmp_path: Path) -> None:
|
||||
"""Can initialize with explicit file path."""
|
||||
db_file = tmp_path / "test.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(db_file))
|
||||
assert dm.db_path == str(db_file)
|
||||
assert dm.is_persistent is True
|
||||
assert db_file.exists()
|
||||
dm.close()
|
||||
|
||||
def test_initialize_with_empty_string_auto_generates_path(self, tmp_path: Path) -> None:
|
||||
"""Empty string db_path with source_file auto-generates path."""
|
||||
xer_file = tmp_path / "schedule.xer"
|
||||
xer_file.write_text("dummy content")
|
||||
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path="", source_file=str(xer_file))
|
||||
expected_db = str(tmp_path / "schedule.sqlite")
|
||||
assert dm.db_path == expected_db
|
||||
assert dm.is_persistent is True
|
||||
assert Path(expected_db).exists()
|
||||
dm.close()
|
||||
|
||||
def test_file_database_persists_after_close(self, tmp_path: Path) -> None:
|
||||
"""File-based database persists after connection close."""
|
||||
db_file = tmp_path / "persist_test.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(db_file))
|
||||
|
||||
# Insert test data
|
||||
with dm.cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT INTO projects (proj_id, proj_short_name, loaded_at) "
|
||||
"VALUES ('P1', 'Test', datetime('now'))"
|
||||
)
|
||||
dm.commit()
|
||||
dm.close()
|
||||
|
||||
# Verify file exists and has data
|
||||
assert db_file.exists()
|
||||
conn = sqlite3.connect(str(db_file))
|
||||
cursor = conn.execute("SELECT proj_id FROM projects")
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
assert len(rows) == 1
|
||||
assert rows[0][0] == "P1"
|
||||
|
||||
def test_source_file_tracked(self, tmp_path: Path) -> None:
|
||||
"""Source file path is tracked when provided."""
|
||||
db_file = tmp_path / "test.db"
|
||||
xer_file = tmp_path / "schedule.xer"
|
||||
xer_file.write_text("dummy")
|
||||
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(db_file), source_file=str(xer_file))
|
||||
assert dm.source_file == str(xer_file)
|
||||
dm.close()
|
||||
|
||||
def test_loaded_at_timestamp(self, tmp_path: Path) -> None:
|
||||
"""Loaded_at timestamp is recorded."""
|
||||
db_file = tmp_path / "test.db"
|
||||
dm = DatabaseManager()
|
||||
before = datetime.now()
|
||||
dm.initialize(db_path=str(db_file))
|
||||
after = datetime.now()
|
||||
|
||||
loaded_at = dm.loaded_at
|
||||
assert loaded_at is not None
|
||||
assert before <= loaded_at <= after
|
||||
dm.close()
|
||||
|
||||
def test_memory_database_not_persistent(self) -> None:
|
||||
"""In-memory database is not persistent."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
assert dm.is_persistent is False
|
||||
assert dm.db_path == ":memory:"
|
||||
dm.close()
|
||||
|
||||
|
||||
class TestDatabaseManagerWalMode:
|
||||
"""Tests for WAL mode in file-based databases."""
|
||||
|
||||
def test_file_database_uses_wal_mode(self, tmp_path: Path) -> None:
|
||||
"""File-based database uses WAL mode for concurrent access."""
|
||||
db_file = tmp_path / "wal_test.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(db_file))
|
||||
|
||||
with dm.cursor() as cur:
|
||||
cur.execute("PRAGMA journal_mode")
|
||||
mode = cur.fetchone()[0]
|
||||
assert mode.lower() == "wal"
|
||||
dm.close()
|
||||
|
||||
def test_memory_database_does_not_use_wal(self) -> None:
|
||||
"""In-memory database doesn't use WAL mode (not applicable)."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
|
||||
with dm.cursor() as cur:
|
||||
cur.execute("PRAGMA journal_mode")
|
||||
mode = cur.fetchone()[0]
|
||||
# Memory databases use 'memory' journal mode
|
||||
assert mode.lower() == "memory"
|
||||
dm.close()
|
||||
|
||||
|
||||
class TestAtomicWrite:
|
||||
"""Tests for atomic write pattern."""
|
||||
|
||||
def test_atomic_write_creates_final_file(self, tmp_path: Path) -> None:
|
||||
"""Database is created at final path after initialization."""
|
||||
target = tmp_path / "atomic_test.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(target))
|
||||
assert target.exists()
|
||||
assert not Path(str(target) + ".tmp").exists()
|
||||
dm.close()
|
||||
|
||||
def test_atomic_write_no_temp_file_remains(self, tmp_path: Path) -> None:
|
||||
"""No .tmp file remains after successful initialization."""
|
||||
target = tmp_path / "atomic_clean.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(target))
|
||||
dm.close()
|
||||
|
||||
# Check no temp files remain
|
||||
temp_files = list(tmp_path.glob("*.tmp"))
|
||||
assert len(temp_files) == 0
|
||||
|
||||
|
||||
class TestSchemaIntrospection:
|
||||
"""Tests for database schema introspection."""
|
||||
|
||||
def test_get_schema_info_returns_all_tables(self) -> None:
|
||||
"""Schema info includes all database tables."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
assert schema["version"] == "0.2.0"
|
||||
table_names = [t["name"] for t in schema["tables"]]
|
||||
assert "projects" in table_names
|
||||
assert "activities" in table_names
|
||||
assert "relationships" in table_names
|
||||
assert "wbs" in table_names
|
||||
assert "calendars" in table_names
|
||||
dm.close()
|
||||
|
||||
def test_get_schema_info_includes_column_details(self) -> None:
|
||||
"""Schema info includes column names, types, and nullable."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
|
||||
column_names = [c["name"] for c in activities_table["columns"]]
|
||||
assert "task_id" in column_names
|
||||
assert "task_name" in column_names
|
||||
|
||||
# Check column details
|
||||
task_id_col = next(c for c in activities_table["columns"] if c["name"] == "task_id")
|
||||
assert task_id_col["type"] == "TEXT"
|
||||
# Note: SQLite reports PRIMARY KEY TEXT columns as nullable
|
||||
# but the PRIMARY KEY constraint still applies
|
||||
assert "nullable" in task_id_col
|
||||
|
||||
# Check a NOT NULL column
|
||||
task_name_col = next(c for c in activities_table["columns"] if c["name"] == "task_name")
|
||||
assert task_name_col["nullable"] is False
|
||||
dm.close()
|
||||
|
||||
def test_get_schema_info_includes_row_counts(self) -> None:
|
||||
"""Schema info includes row counts for each table."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
for table in schema["tables"]:
|
||||
assert "row_count" in table
|
||||
assert isinstance(table["row_count"], int)
|
||||
assert table["row_count"] >= 0
|
||||
dm.close()
|
||||
|
||||
def test_schema_info_includes_primary_keys(self) -> None:
|
||||
"""Schema info includes primary key for each table."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
|
||||
assert "primary_key" in activities_table
|
||||
assert "task_id" in activities_table["primary_key"]
|
||||
dm.close()
|
||||
|
||||
def test_schema_info_includes_foreign_keys(self) -> None:
|
||||
"""Schema info includes foreign key relationships."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
|
||||
assert "foreign_keys" in activities_table
|
||||
|
||||
# activities.proj_id -> projects.proj_id
|
||||
fk = next(
|
||||
(fk for fk in activities_table["foreign_keys"] if fk["column"] == "proj_id"),
|
||||
None,
|
||||
)
|
||||
assert fk is not None
|
||||
assert fk["references_table"] == "projects"
|
||||
assert fk["references_column"] == "proj_id"
|
||||
dm.close()
|
||||
Reference in New Issue
Block a user