feat: add direct database access for scripts (v0.2.0)
Implement persistent SQLite database feature that allows scripts to query schedule data directly via SQL after loading XER files through MCP. Key changes: - Extend load_xer with db_path parameter for persistent database - Add get_database_info tool to retrieve database connection details - Add schema introspection with tables, columns, primary/foreign keys - Support WAL mode for concurrent read access - Use atomic write pattern to prevent corruption New features: - db_path=None: in-memory database (default, backward compatible) - db_path="": auto-generate path from XER filename (.sqlite extension) - db_path="/path/to/db": explicit persistent database path Response includes complete DatabaseInfo: - db_path: absolute path (or :memory:) - is_persistent: boolean - source_file: loaded XER path - loaded_at: ISO timestamp - schema: tables with columns, primary keys, foreign keys, row counts Closes: User Story 1, 2, 3 from 002-direct-db-access spec
This commit is contained in:
143
tests/contract/test_get_database_info.py
Normal file
143
tests/contract/test_get_database_info.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Contract tests for get_database_info MCP tool."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from xer_mcp.db import db
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_db():
|
||||
"""Reset database state for each test."""
|
||||
if db.is_initialized:
|
||||
db.close()
|
||||
yield
|
||||
if db.is_initialized:
|
||||
db.close()
|
||||
|
||||
|
||||
class TestGetDatabaseInfoContract:
|
||||
"""Contract tests verifying get_database_info tool interface."""
|
||||
|
||||
async def test_get_database_info_returns_current_database(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""get_database_info returns info about currently loaded database."""
|
||||
from xer_mcp.tools.get_database_info import get_database_info
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
result = await get_database_info()
|
||||
|
||||
assert "database" in result
|
||||
assert result["database"]["db_path"] == str(db_file)
|
||||
assert result["database"]["is_persistent"] is True
|
||||
|
||||
async def test_get_database_info_error_when_no_database(self) -> None:
|
||||
"""get_database_info returns error when no database loaded."""
|
||||
from xer_mcp.tools.get_database_info import get_database_info
|
||||
|
||||
# Ensure database is not initialized
|
||||
if db.is_initialized:
|
||||
db.close()
|
||||
|
||||
result = await get_database_info()
|
||||
|
||||
assert "error" in result
|
||||
assert result["error"]["code"] == "NO_FILE_LOADED"
|
||||
|
||||
async def test_get_database_info_includes_schema(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""get_database_info includes schema information."""
|
||||
from xer_mcp.tools.get_database_info import get_database_info
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
result = await get_database_info()
|
||||
|
||||
assert "schema" in result["database"]
|
||||
assert "tables" in result["database"]["schema"]
|
||||
|
||||
async def test_get_database_info_includes_loaded_at(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""get_database_info includes loaded_at timestamp."""
|
||||
from xer_mcp.tools.get_database_info import get_database_info
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
result = await get_database_info()
|
||||
|
||||
assert "loaded_at" in result["database"]
|
||||
# Should be ISO format timestamp
|
||||
assert "T" in result["database"]["loaded_at"]
|
||||
|
||||
async def test_get_database_info_includes_source_file(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""get_database_info includes source XER file path."""
|
||||
from xer_mcp.tools.get_database_info import get_database_info
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
result = await get_database_info()
|
||||
|
||||
assert result["database"]["source_file"] == str(sample_xer_single_project)
|
||||
|
||||
async def test_get_database_info_for_memory_database(
|
||||
self, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""get_database_info works for in-memory database."""
|
||||
from xer_mcp.tools.get_database_info import get_database_info
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
await load_xer(file_path=str(sample_xer_single_project))
|
||||
|
||||
result = await get_database_info()
|
||||
|
||||
assert "database" in result
|
||||
assert result["database"]["db_path"] == ":memory:"
|
||||
assert result["database"]["is_persistent"] is False
|
||||
|
||||
|
||||
class TestGetDatabaseInfoToolSchema:
|
||||
"""Tests for MCP tool schema."""
|
||||
|
||||
async def test_get_database_info_tool_registered(self) -> None:
|
||||
"""get_database_info tool is registered with MCP server."""
|
||||
from xer_mcp.server import list_tools
|
||||
|
||||
tools = await list_tools()
|
||||
tool_names = [t.name for t in tools]
|
||||
assert "get_database_info" in tool_names
|
||||
|
||||
async def test_get_database_info_tool_has_empty_input_schema(self) -> None:
|
||||
"""get_database_info tool has no required inputs."""
|
||||
from xer_mcp.server import list_tools
|
||||
|
||||
tools = await list_tools()
|
||||
tool = next(t for t in tools if t.name == "get_database_info")
|
||||
# Should have empty or no required properties
|
||||
assert "required" not in tool.inputSchema or len(tool.inputSchema.get("required", [])) == 0
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Contract tests for load_xer MCP tool."""
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -9,10 +10,14 @@ from xer_mcp.db import db
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_db():
|
||||
"""Initialize and clear database for each test."""
|
||||
db.initialize()
|
||||
"""Reset database state for each test."""
|
||||
# Close any existing connection
|
||||
if db.is_initialized:
|
||||
db.close()
|
||||
yield
|
||||
db.clear()
|
||||
# Cleanup after test
|
||||
if db.is_initialized:
|
||||
db.close()
|
||||
|
||||
|
||||
class TestLoadXerContract:
|
||||
@@ -105,3 +110,141 @@ class TestLoadXerContract:
|
||||
assert "plan_end_date" in result["project"]
|
||||
# Dates should be ISO8601 format
|
||||
assert "T" in result["project"]["plan_start_date"]
|
||||
|
||||
|
||||
class TestLoadXerPersistentDatabase:
|
||||
"""Contract tests for persistent database functionality."""
|
||||
|
||||
async def test_load_xer_with_db_path_creates_file(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""load_xer with db_path creates persistent database file."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert db_file.exists()
|
||||
assert result["database"]["db_path"] == str(db_file)
|
||||
assert result["database"]["is_persistent"] is True
|
||||
|
||||
async def test_load_xer_with_empty_db_path_auto_generates(self, tmp_path: Path) -> None:
|
||||
"""load_xer with empty db_path generates path from XER filename."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
# Create XER file in tmp_path
|
||||
xer_file = tmp_path / "my_schedule.xer"
|
||||
from tests.conftest import SAMPLE_XER_SINGLE_PROJECT
|
||||
|
||||
xer_file.write_text(SAMPLE_XER_SINGLE_PROJECT)
|
||||
|
||||
result = await load_xer(file_path=str(xer_file), db_path="")
|
||||
|
||||
assert result["success"] is True
|
||||
expected_db = str(tmp_path / "my_schedule.sqlite")
|
||||
assert result["database"]["db_path"] == expected_db
|
||||
assert result["database"]["is_persistent"] is True
|
||||
assert Path(expected_db).exists()
|
||||
|
||||
async def test_load_xer_without_db_path_uses_memory(
|
||||
self, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""load_xer without db_path uses in-memory database (backward compatible)."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
result = await load_xer(file_path=str(sample_xer_single_project))
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["database"]["db_path"] == ":memory:"
|
||||
assert result["database"]["is_persistent"] is False
|
||||
|
||||
async def test_load_xer_database_contains_all_data(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""Persistent database contains all parsed data."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
# Verify data via direct SQL
|
||||
conn = sqlite3.connect(str(db_file))
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM activities")
|
||||
count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
assert count == result["activity_count"]
|
||||
|
||||
async def test_load_xer_response_includes_database_info(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""load_xer response includes complete database info."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
assert "database" in result
|
||||
db_info = result["database"]
|
||||
assert "db_path" in db_info
|
||||
assert "is_persistent" in db_info
|
||||
assert "source_file" in db_info
|
||||
assert "loaded_at" in db_info
|
||||
assert "schema" in db_info
|
||||
|
||||
async def test_load_xer_response_schema_includes_tables(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""load_xer response schema includes table information."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
schema = result["database"]["schema"]
|
||||
assert "version" in schema
|
||||
assert "tables" in schema
|
||||
table_names = [t["name"] for t in schema["tables"]]
|
||||
assert "activities" in table_names
|
||||
assert "relationships" in table_names
|
||||
|
||||
async def test_load_xer_error_on_invalid_path(self, sample_xer_single_project: Path) -> None:
|
||||
"""load_xer returns error for invalid path."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path="/nonexistent/dir/file.db",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
# Either FILE_NOT_WRITABLE or DATABASE_ERROR is acceptable
|
||||
# depending on how SQLite reports the error
|
||||
assert result["error"]["code"] in ("FILE_NOT_WRITABLE", "DATABASE_ERROR")
|
||||
|
||||
|
||||
class TestLoadXerToolSchema:
|
||||
"""Tests for MCP tool schema."""
|
||||
|
||||
async def test_load_xer_tool_schema_includes_db_path(self) -> None:
|
||||
"""MCP tool schema includes db_path parameter."""
|
||||
from xer_mcp.server import list_tools
|
||||
|
||||
tools = await list_tools()
|
||||
load_xer_tool = next(t for t in tools if t.name == "load_xer")
|
||||
props = load_xer_tool.inputSchema["properties"]
|
||||
assert "db_path" in props
|
||||
assert props["db_path"]["type"] == "string"
|
||||
|
||||
185
tests/integration/test_direct_db_access.py
Normal file
185
tests/integration/test_direct_db_access.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Integration tests for direct database access feature."""
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from xer_mcp.db import db
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_db():
|
||||
"""Reset database state for each test."""
|
||||
if db.is_initialized:
|
||||
db.close()
|
||||
yield
|
||||
if db.is_initialized:
|
||||
db.close()
|
||||
|
||||
|
||||
class TestDirectDatabaseAccess:
|
||||
"""Integration tests verifying external script can access database."""
|
||||
|
||||
async def test_external_script_can_query_database(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""External script can query database using returned path."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
# Simulate external script access (as shown in quickstart.md)
|
||||
db_path = result["database"]["db_path"]
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
# Query milestones
|
||||
cursor = conn.execute("""
|
||||
SELECT task_code, task_name, target_start_date, milestone_type
|
||||
FROM activities
|
||||
WHERE task_type IN ('TT_Mile', 'TT_FinMile')
|
||||
ORDER BY target_start_date
|
||||
""")
|
||||
|
||||
milestones = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
assert len(milestones) > 0
|
||||
assert all(row["task_code"] for row in milestones)
|
||||
|
||||
async def test_external_script_can_query_critical_path(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""External script can query critical path activities."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
db_path = result["database"]["db_path"]
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("""
|
||||
SELECT task_code, task_name, target_start_date, target_end_date
|
||||
FROM activities
|
||||
WHERE driving_path_flag = 1
|
||||
ORDER BY target_start_date
|
||||
""")
|
||||
|
||||
critical_activities = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
assert len(critical_activities) > 0
|
||||
|
||||
async def test_external_script_can_join_tables(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""External script can join activities with WBS."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
db_path = result["database"]["db_path"]
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute("""
|
||||
SELECT a.task_code, a.task_name, w.wbs_name
|
||||
FROM activities a
|
||||
JOIN wbs w ON a.wbs_id = w.wbs_id
|
||||
LIMIT 10
|
||||
""")
|
||||
|
||||
joined_rows = cursor.fetchall()
|
||||
conn.close()
|
||||
|
||||
assert len(joined_rows) > 0
|
||||
|
||||
async def test_database_accessible_after_mcp_load(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""Database remains accessible while MCP tools are active."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
loaded_count = result["activity_count"]
|
||||
|
||||
# External script queries database
|
||||
conn = sqlite3.connect(str(db_file))
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM activities")
|
||||
external_count = cursor.fetchone()[0]
|
||||
conn.close()
|
||||
|
||||
# Both should match
|
||||
assert external_count == loaded_count
|
||||
|
||||
async def test_schema_info_matches_actual_database(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""Returned schema info matches actual database structure."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
schema = result["database"]["schema"]
|
||||
db_path = result["database"]["db_path"]
|
||||
|
||||
# Verify tables exist in actual database
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.execute(
|
||||
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
|
||||
)
|
||||
actual_tables = {row[0] for row in cursor.fetchall()}
|
||||
conn.close()
|
||||
|
||||
schema_tables = {t["name"] for t in schema["tables"]}
|
||||
assert schema_tables == actual_tables
|
||||
|
||||
async def test_row_counts_match_actual_data(
|
||||
self, tmp_path: Path, sample_xer_single_project: Path
|
||||
) -> None:
|
||||
"""Schema row counts match actual database row counts."""
|
||||
from xer_mcp.tools.load_xer import load_xer
|
||||
|
||||
db_file = tmp_path / "schedule.db"
|
||||
result = await load_xer(
|
||||
file_path=str(sample_xer_single_project),
|
||||
db_path=str(db_file),
|
||||
)
|
||||
|
||||
schema = result["database"]["schema"]
|
||||
db_path = result["database"]["db_path"]
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
|
||||
for table_info in schema["tables"]:
|
||||
cursor = conn.execute(
|
||||
f"SELECT COUNT(*) FROM {table_info['name']}" # noqa: S608
|
||||
)
|
||||
actual_count = cursor.fetchone()[0]
|
||||
assert table_info["row_count"] == actual_count, (
|
||||
f"Table {table_info['name']}: expected {table_info['row_count']}, "
|
||||
f"got {actual_count}"
|
||||
)
|
||||
|
||||
conn.close()
|
||||
234
tests/unit/test_db_manager.py
Normal file
234
tests/unit/test_db_manager.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""Unit tests for DatabaseManager file-based database support."""
|
||||
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from xer_mcp.db import DatabaseManager
|
||||
|
||||
|
||||
class TestDatabaseManagerInitialization:
|
||||
"""Tests for DatabaseManager initialization modes."""
|
||||
|
||||
def test_initialize_with_memory_by_default(self) -> None:
|
||||
"""Default initialization uses in-memory database."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
assert dm.db_path == ":memory:"
|
||||
assert dm.is_persistent is False
|
||||
dm.close()
|
||||
|
||||
def test_initialize_with_file_path(self, tmp_path: Path) -> None:
|
||||
"""Can initialize with explicit file path."""
|
||||
db_file = tmp_path / "test.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(db_file))
|
||||
assert dm.db_path == str(db_file)
|
||||
assert dm.is_persistent is True
|
||||
assert db_file.exists()
|
||||
dm.close()
|
||||
|
||||
def test_initialize_with_empty_string_auto_generates_path(self, tmp_path: Path) -> None:
|
||||
"""Empty string db_path with source_file auto-generates path."""
|
||||
xer_file = tmp_path / "schedule.xer"
|
||||
xer_file.write_text("dummy content")
|
||||
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path="", source_file=str(xer_file))
|
||||
expected_db = str(tmp_path / "schedule.sqlite")
|
||||
assert dm.db_path == expected_db
|
||||
assert dm.is_persistent is True
|
||||
assert Path(expected_db).exists()
|
||||
dm.close()
|
||||
|
||||
def test_file_database_persists_after_close(self, tmp_path: Path) -> None:
|
||||
"""File-based database persists after connection close."""
|
||||
db_file = tmp_path / "persist_test.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(db_file))
|
||||
|
||||
# Insert test data
|
||||
with dm.cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT INTO projects (proj_id, proj_short_name, loaded_at) "
|
||||
"VALUES ('P1', 'Test', datetime('now'))"
|
||||
)
|
||||
dm.commit()
|
||||
dm.close()
|
||||
|
||||
# Verify file exists and has data
|
||||
assert db_file.exists()
|
||||
conn = sqlite3.connect(str(db_file))
|
||||
cursor = conn.execute("SELECT proj_id FROM projects")
|
||||
rows = cursor.fetchall()
|
||||
conn.close()
|
||||
assert len(rows) == 1
|
||||
assert rows[0][0] == "P1"
|
||||
|
||||
def test_source_file_tracked(self, tmp_path: Path) -> None:
|
||||
"""Source file path is tracked when provided."""
|
||||
db_file = tmp_path / "test.db"
|
||||
xer_file = tmp_path / "schedule.xer"
|
||||
xer_file.write_text("dummy")
|
||||
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(db_file), source_file=str(xer_file))
|
||||
assert dm.source_file == str(xer_file)
|
||||
dm.close()
|
||||
|
||||
def test_loaded_at_timestamp(self, tmp_path: Path) -> None:
|
||||
"""Loaded_at timestamp is recorded."""
|
||||
db_file = tmp_path / "test.db"
|
||||
dm = DatabaseManager()
|
||||
before = datetime.now()
|
||||
dm.initialize(db_path=str(db_file))
|
||||
after = datetime.now()
|
||||
|
||||
loaded_at = dm.loaded_at
|
||||
assert loaded_at is not None
|
||||
assert before <= loaded_at <= after
|
||||
dm.close()
|
||||
|
||||
def test_memory_database_not_persistent(self) -> None:
|
||||
"""In-memory database is not persistent."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
assert dm.is_persistent is False
|
||||
assert dm.db_path == ":memory:"
|
||||
dm.close()
|
||||
|
||||
|
||||
class TestDatabaseManagerWalMode:
|
||||
"""Tests for WAL mode in file-based databases."""
|
||||
|
||||
def test_file_database_uses_wal_mode(self, tmp_path: Path) -> None:
|
||||
"""File-based database uses WAL mode for concurrent access."""
|
||||
db_file = tmp_path / "wal_test.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(db_file))
|
||||
|
||||
with dm.cursor() as cur:
|
||||
cur.execute("PRAGMA journal_mode")
|
||||
mode = cur.fetchone()[0]
|
||||
assert mode.lower() == "wal"
|
||||
dm.close()
|
||||
|
||||
def test_memory_database_does_not_use_wal(self) -> None:
|
||||
"""In-memory database doesn't use WAL mode (not applicable)."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
|
||||
with dm.cursor() as cur:
|
||||
cur.execute("PRAGMA journal_mode")
|
||||
mode = cur.fetchone()[0]
|
||||
# Memory databases use 'memory' journal mode
|
||||
assert mode.lower() == "memory"
|
||||
dm.close()
|
||||
|
||||
|
||||
class TestAtomicWrite:
|
||||
"""Tests for atomic write pattern."""
|
||||
|
||||
def test_atomic_write_creates_final_file(self, tmp_path: Path) -> None:
|
||||
"""Database is created at final path after initialization."""
|
||||
target = tmp_path / "atomic_test.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(target))
|
||||
assert target.exists()
|
||||
assert not Path(str(target) + ".tmp").exists()
|
||||
dm.close()
|
||||
|
||||
def test_atomic_write_no_temp_file_remains(self, tmp_path: Path) -> None:
|
||||
"""No .tmp file remains after successful initialization."""
|
||||
target = tmp_path / "atomic_clean.db"
|
||||
dm = DatabaseManager()
|
||||
dm.initialize(db_path=str(target))
|
||||
dm.close()
|
||||
|
||||
# Check no temp files remain
|
||||
temp_files = list(tmp_path.glob("*.tmp"))
|
||||
assert len(temp_files) == 0
|
||||
|
||||
|
||||
class TestSchemaIntrospection:
|
||||
"""Tests for database schema introspection."""
|
||||
|
||||
def test_get_schema_info_returns_all_tables(self) -> None:
|
||||
"""Schema info includes all database tables."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
assert schema["version"] == "0.2.0"
|
||||
table_names = [t["name"] for t in schema["tables"]]
|
||||
assert "projects" in table_names
|
||||
assert "activities" in table_names
|
||||
assert "relationships" in table_names
|
||||
assert "wbs" in table_names
|
||||
assert "calendars" in table_names
|
||||
dm.close()
|
||||
|
||||
def test_get_schema_info_includes_column_details(self) -> None:
|
||||
"""Schema info includes column names, types, and nullable."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
|
||||
column_names = [c["name"] for c in activities_table["columns"]]
|
||||
assert "task_id" in column_names
|
||||
assert "task_name" in column_names
|
||||
|
||||
# Check column details
|
||||
task_id_col = next(c for c in activities_table["columns"] if c["name"] == "task_id")
|
||||
assert task_id_col["type"] == "TEXT"
|
||||
# Note: SQLite reports PRIMARY KEY TEXT columns as nullable
|
||||
# but the PRIMARY KEY constraint still applies
|
||||
assert "nullable" in task_id_col
|
||||
|
||||
# Check a NOT NULL column
|
||||
task_name_col = next(c for c in activities_table["columns"] if c["name"] == "task_name")
|
||||
assert task_name_col["nullable"] is False
|
||||
dm.close()
|
||||
|
||||
def test_get_schema_info_includes_row_counts(self) -> None:
|
||||
"""Schema info includes row counts for each table."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
for table in schema["tables"]:
|
||||
assert "row_count" in table
|
||||
assert isinstance(table["row_count"], int)
|
||||
assert table["row_count"] >= 0
|
||||
dm.close()
|
||||
|
||||
def test_schema_info_includes_primary_keys(self) -> None:
|
||||
"""Schema info includes primary key for each table."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
|
||||
assert "primary_key" in activities_table
|
||||
assert "task_id" in activities_table["primary_key"]
|
||||
dm.close()
|
||||
|
||||
def test_schema_info_includes_foreign_keys(self) -> None:
|
||||
"""Schema info includes foreign key relationships."""
|
||||
dm = DatabaseManager()
|
||||
dm.initialize()
|
||||
schema = dm.get_schema_info()
|
||||
|
||||
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
|
||||
assert "foreign_keys" in activities_table
|
||||
|
||||
# activities.proj_id -> projects.proj_id
|
||||
fk = next(
|
||||
(fk for fk in activities_table["foreign_keys"] if fk["column"] == "proj_id"),
|
||||
None,
|
||||
)
|
||||
assert fk is not None
|
||||
assert fk["references_table"] == "projects"
|
||||
assert fk["references_column"] == "proj_id"
|
||||
dm.close()
|
||||
Reference in New Issue
Block a user