feat: add direct database access for scripts (v0.2.0)

Implement a persistent SQLite database feature that allows scripts to
query schedule data directly via SQL after loading XER files through MCP.

Key changes:
- Extend load_xer with db_path parameter for persistent database
- Add get_database_info tool to retrieve database connection details
- Add schema introspection with tables, columns, primary/foreign keys
- Support WAL mode for concurrent read access
- Use atomic write pattern to prevent corruption

New features:
- db_path=None: in-memory database (default, backward compatible)
- db_path="": auto-generate path from XER filename (.sqlite extension)
- db_path="/path/to/db": explicit persistent database path

Response includes complete DatabaseInfo:
- db_path: absolute path (or :memory:)
- is_persistent: boolean
- source_file: loaded XER path
- loaded_at: ISO timestamp
- schema: tables with columns, primary keys, foreign keys, row counts

Closes: User Stories 1, 2, and 3 from the 002-direct-db-access spec
This commit is contained in:
2026-01-08 12:54:56 -05:00
parent 3e7ad39eb8
commit d6a79bf24a
11 changed files with 1064 additions and 40 deletions

View File

@@ -1,6 +1,6 @@
[project]
name = "xer-mcp"
version = "0.1.0"
version = "0.2.0"
description = "MCP server for querying Primavera P6 XER schedule data"
readme = "README.md"
requires-python = ">=3.14"

View File

@@ -1,24 +1,66 @@
"""Database connection management for XER MCP Server."""
import os
import sqlite3
from collections.abc import Generator
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from xer_mcp.db.schema import get_schema
# Schema version for introspection responses
SCHEMA_VERSION = "0.2.0"
class DatabaseManager:
"""Manages SQLite database connections and schema initialization."""
def __init__(self) -> None:
    """Create a manager that does not hold a connection yet.

    The path starts out as SQLite's ``:memory:`` marker; the source file
    and load timestamp stay unset until ``initialize()`` records them.
    """
    self._db_path: str = ":memory:"
    self._connection: sqlite3.Connection | None = None
    self._loaded_at: datetime | None = None
    self._source_file: str | None = None
def initialize(self) -> None:
"""Initialize the in-memory database with schema."""
def initialize(
    self,
    db_path: str | None = None,
    source_file: str | None = None,
) -> None:
    """Initialize the database with schema.

    A connection left over from an earlier call is closed before the new
    database is created, so repeated loads do not accumulate open handles.

    Args:
        db_path: Path for database file. If None (or the literal
            ``":memory:"``), uses an in-memory database. If empty string,
            auto-generates the path from source_file by swapping its suffix
            for ``.sqlite``; without a source_file this also falls back to
            in-memory. File paths are stored in absolute form so callers in
            a different working directory can open the same file.
        source_file: Path to the XER file being loaded (for tracking).

    Raises:
        OSError: If the database file cannot be created.
        sqlite3.Error: If SQLite fails to open the database or apply the schema.
    """
    # load_xer re-runs initialize() for every load; drop the previous
    # connection first instead of silently replacing the reference.
    if self._connection is not None:
        self._connection.close()
        self._connection = None

    self._source_file = source_file
    # Naive local time on purpose: existing callers compare against datetime.now().
    self._loaded_at = datetime.now()

    # Determine database path
    if db_path is None or db_path == ":memory:":
        # Default: in-memory database (the explicit marker must not be
        # turned into a file path by the absolute-path normalization below)
        self._db_path = ":memory:"
    elif db_path == "":
        # Auto-generate from source file
        if source_file:
            self._db_path = os.path.abspath(Path(source_file).with_suffix(".sqlite"))
        else:
            self._db_path = ":memory:"
    else:
        # Use provided path
        self._db_path = os.path.abspath(db_path)

    # Create database
    if self._db_path == ":memory:":
        self._connection = sqlite3.connect(":memory:", check_same_thread=False)
        self._connection.row_factory = sqlite3.Row
        self._connection.executescript(get_schema())
    else:
        # File-based database with atomic write pattern
        self._create_file_database()
def clear(self) -> None:
"""Clear all data from the database."""
@@ -61,6 +103,159 @@ class DatabaseManager:
"""Check if the database is initialized."""
return self._connection is not None
@property
def db_path(self) -> str:
    """Get the database path.

    This is the value stored in ``_db_path``: the ``":memory:"`` marker set
    by ``__init__`` until ``initialize()`` assigns something else.
    """
    return self._db_path
@property
def is_persistent(self) -> bool:
    """Report whether the database is backed by a file instead of memory."""
    in_memory = self._db_path == ":memory:"
    return not in_memory
@property
def source_file(self) -> str | None:
    """Get the source XER file path.

    Returns the value passed to ``initialize()`` as ``source_file``, or None
    when nothing has been recorded.
    """
    return self._source_file
@property
def loaded_at(self) -> datetime | None:
    """Get the timestamp when data was loaded.

    The timestamp is taken with ``datetime.now()`` inside ``initialize()``;
    it is None before the first call.
    """
    return self._loaded_at
def _create_file_database(self) -> None:
    """Create the file-based database at ``self._db_path``.

    The schema is first written to a sibling ``<db_path>.tmp`` file, which is
    then moved over the final path with ``os.replace`` so a half-built
    database never appears under the final name. The final file is reopened
    in WAL mode and kept as the manager's connection.

    Raises:
        OSError: If the temp file cannot be removed or moved into place.
        sqlite3.Error: If SQLite cannot open a file or apply the schema.
    """
    temp_path = self._db_path + ".tmp"

    # A connection still held by this manager may point at the very file
    # being replaced; release it before touching anything on disk.
    if self._connection is not None:
        self._connection.close()
        self._connection = None

    try:
        # A temp file left behind by a crashed run must not be reused:
        # it may already contain tables (or data) from that run.
        if os.path.exists(temp_path):
            os.unlink(temp_path)

        # Create database at temp path; always close it, even if the
        # schema script fails, so the handle does not leak.
        builder = sqlite3.connect(temp_path, check_same_thread=False)
        try:
            builder.executescript(get_schema())
            builder.commit()
        finally:
            builder.close()

        # WAL/shared-memory sidecars belong to the database being replaced.
        # If they survive, SQLite may pair them with the new file.
        for suffix in ("-wal", "-shm"):
            try:
                os.unlink(self._db_path + suffix)
            except FileNotFoundError:
                pass  # nothing stale to remove

        # Single atomic step; also overwrites an existing target on Windows,
        # so there is no window in which the final path is missing.
        os.replace(temp_path, self._db_path)

        # Open final database with WAL mode
        connection = sqlite3.connect(self._db_path, check_same_thread=False)
        try:
            connection.row_factory = sqlite3.Row
            connection.execute("PRAGMA journal_mode=WAL")
        except Exception:
            connection.close()
            raise
        self._connection = connection
    except Exception:
        # Clean up temp file on failure
        if os.path.exists(temp_path):
            os.unlink(temp_path)
        raise
def get_schema_info(self) -> dict:
    """Describe the database schema for introspection responses.

    Returns:
        Dictionary holding the schema version under ``"version"`` and, under
        ``"tables"``, one entry per user table (SQLite-internal tables are
        filtered out), ordered by table name.

    Raises:
        RuntimeError: If no connection has been opened yet.
    """
    connection = self._connection
    if connection is None:
        raise RuntimeError("Database not initialized. Call initialize() first.")

    # User tables only; names starting with "sqlite_" are SQLite's own.
    listing = connection.execute(
        "SELECT name FROM sqlite_master WHERE type='table' "
        "AND name NOT LIKE 'sqlite_%' ORDER BY name"
    )
    names = [record[0] for record in listing.fetchall()]

    return {
        "version": SCHEMA_VERSION,
        "tables": [self._get_table_info(name) for name in names],
    }
def _get_table_info(self, table_name: str) -> dict:
    """Get detailed information about a table.

    Args:
        table_name: Name of the table.

    Returns:
        Dictionary with table name, columns, primary keys, foreign keys, row count.
        ``primary_key`` lists the key columns in key order (which can differ
        from column order for composite keys).

    Raises:
        RuntimeError: If no connection has been opened yet.
    """
    if self._connection is None:
        raise RuntimeError("Database not initialized. Call initialize() first.")

    # Identifiers cannot be bound as parameters, so quote the name instead:
    # this keeps names with spaces, keywords or quotes from breaking the SQL.
    quoted = '"' + table_name.replace('"', '""') + '"'

    columns = []
    key_positions = []

    # Get column info; rows are (cid, name, type, notnull, dflt_value, pk)
    cursor = self._connection.execute(f"PRAGMA table_info({quoted})")  # noqa: S608
    for row in cursor.fetchall():
        col_name = row[1]
        col_info: dict = {
            "name": col_name,
            "type": row[2] or "TEXT",
            "nullable": not bool(row[3]),
        }
        if row[4] is not None:
            col_info["default"] = str(row[4])
        columns.append(col_info)
        # pk is 0 for non-key columns, otherwise the 1-based position
        # of the column inside the primary key.
        if row[5]:
            key_positions.append((row[5], col_name))

    primary_key = [name for _, name in sorted(key_positions)]

    # Get foreign keys; rows are (id, seq, table, from, to, ...)
    cursor = self._connection.execute(
        f"PRAGMA foreign_key_list({quoted})"  # noqa: S608
    )
    foreign_keys = [
        {
            "column": row[3],  # from column
            "references_table": row[2],  # table
            "references_column": row[4],  # to column
        }
        for row in cursor.fetchall()
    ]

    # Get row count
    cursor = self._connection.execute(
        f"SELECT COUNT(*) FROM {quoted}"  # noqa: S608
    )
    row_count = cursor.fetchone()[0]

    return {
        "name": table_name,
        "columns": columns,
        "primary_key": primary_key,
        "foreign_keys": foreign_keys,
        "row_count": row_count,
    }
def get_database_info(self) -> dict:
    """Assemble the database description used in API responses.

    Returns:
        Dictionary with the database path, whether it is persistent, the
        source file, the load timestamp as an ISO string (None when no
        timestamp is recorded) and the schema information.

    Raises:
        RuntimeError: If the database has not been initialized.
    """
    if not self.is_initialized:
        raise RuntimeError("Database not initialized. Call initialize() first.")

    info: dict = {}
    info["db_path"] = self._db_path
    info["is_persistent"] = self.is_persistent
    info["source_file"] = self._source_file
    stamp = self._loaded_at
    info["loaded_at"] = stamp.isoformat() if stamp else None
    info["schema"] = self.get_schema_info()
    return info
# Global database manager instance
db = DatabaseManager()

View File

@@ -239,7 +239,8 @@ def query_relationships(
lag_hours=row[6] or 0.0,
pred_type=pred_type,
)
relationships.append({
relationships.append(
{
"task_pred_id": row[0],
"task_id": row[1],
"task_name": row[2],
@@ -248,7 +249,8 @@ def query_relationships(
"pred_type": pred_type,
"lag_hr_cnt": row[6],
"driving": driving,
})
}
)
return relationships, total
@@ -298,14 +300,16 @@ def get_predecessors(activity_id: str) -> list[dict]:
lag_hours=row[4] or 0.0,
pred_type=pred_type,
)
result.append({
result.append(
{
"task_id": row[0],
"task_code": row[1],
"task_name": row[2],
"relationship_type": pred_type,
"lag_hr_cnt": row[4],
"driving": driving,
})
}
)
return result
@@ -355,14 +359,16 @@ def get_successors(activity_id: str) -> list[dict]:
lag_hours=row[4] or 0.0,
pred_type=pred_type,
)
result.append({
result.append(
{
"task_id": row[0],
"task_code": row[1],
"task_name": row[2],
"relationship_type": pred_type,
"lag_hr_cnt": row[4],
"driving": driving,
})
}
)
return result

View File

@@ -56,3 +56,30 @@ class ActivityNotFoundError(XerMcpError):
"ACTIVITY_NOT_FOUND",
f"Activity not found: {activity_id}",
)
class FileNotWritableError(XerMcpError):
    """Signals that the database file cannot be written at the given path."""

    def __init__(self, path: str, reason: str = "") -> None:
        """Build the error message for ``path``, appending ``reason`` in parentheses if given."""
        msg = f"Cannot write to database file: {path}"
        super().__init__(
            "FILE_NOT_WRITABLE",
            f"{msg} ({reason})" if reason else msg,
        )
class DiskFullError(XerMcpError):
    """Signals that there is not enough disk space for the database."""

    def __init__(self, path: str) -> None:
        """Build the error for the database location ``path``."""
        message = f"Insufficient disk space to create database: {path}"
        super().__init__("DISK_FULL", message)
class DatabaseError(XerMcpError):
    """Raised for general database errors.

    The caller's message is handed to ``XerMcpError`` together with the
    fixed code ``"DATABASE_ERROR"``.
    """

    def __init__(self, message: str) -> None:
        # Only the message varies; the error code is constant for this class.
        super().__init__("DATABASE_ERROR", message)

View File

@@ -50,6 +50,12 @@ async def list_tools() -> list[Tool]:
"type": "string",
"description": "Project ID to select (required for multi-project files)",
},
"db_path": {
"type": "string",
"description": "Path for persistent SQLite database file. "
"If omitted, uses in-memory database. "
"If empty string, auto-generates path from XER filename (same directory, .sqlite extension).",
},
},
"required": ["file_path"],
},
@@ -183,6 +189,15 @@ async def list_tools() -> list[Tool]:
"properties": {},
},
),
Tool(
name="get_database_info",
description="Get information about the currently loaded database including file path and schema. "
"Use this to get connection details for direct SQL access.",
inputSchema={
"type": "object",
"properties": {},
},
),
]
@@ -197,6 +212,7 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
result = await load_xer(
file_path=arguments["file_path"],
project_id=arguments.get("project_id"),
db_path=arguments.get("db_path"),
)
return [TextContent(type="text", text=json.dumps(result, indent=2))]
@@ -258,6 +274,12 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
result = await get_critical_path()
return [TextContent(type="text", text=json.dumps(result, indent=2))]
if name == "get_database_info":
from xer_mcp.tools.get_database_info import get_database_info
result = await get_database_info()
return [TextContent(type="text", text=json.dumps(result, indent=2))]
raise ValueError(f"Unknown tool: {name}")

View File

@@ -0,0 +1,25 @@
"""get_database_info MCP tool implementation."""
from xer_mcp.db import db
async def get_database_info() -> dict:
    """Describe the currently loaded database for direct SQL access.

    Returns:
        ``{"database": ...}`` with the manager's database info when a
        database is initialized, otherwise ``{"error": ...}`` carrying the
        ``NO_FILE_LOADED`` code and a hint to call ``load_xer`` first.
    """
    if db.is_initialized:
        return {"database": db.get_database_info()}

    no_file_error = {
        "code": "NO_FILE_LOADED",
        "message": "No XER file is loaded. Use the load_xer tool first.",
    }
    return {"error": no_file_error}

View File

@@ -1,5 +1,9 @@
"""load_xer MCP tool implementation."""
import errno
import os
import sqlite3
from xer_mcp.db import db
from xer_mcp.db.loader import get_activity_count, get_relationship_count, load_parsed_data
from xer_mcp.errors import FileNotFoundError, ParseError
@@ -7,19 +11,55 @@ from xer_mcp.parser.xer_parser import XerParser
from xer_mcp.server import set_file_loaded
async def load_xer(file_path: str, project_id: str | None = None) -> dict:
async def load_xer(
file_path: str,
project_id: str | None = None,
db_path: str | None = None,
) -> dict:
"""Load a Primavera P6 XER file and parse its schedule data.
Args:
file_path: Absolute path to the XER file
project_id: Project ID to select (required for multi-project files)
db_path: Path for persistent database file. If None, uses in-memory.
If empty string, auto-generates from XER filename.
Returns:
Dictionary with success status and project info or error details
"""
# Ensure database is initialized
if not db.is_initialized:
db.initialize()
# Initialize database with specified path
try:
db.initialize(db_path=db_path, source_file=file_path)
except PermissionError:
target = db_path if db_path else file_path
return {
"success": False,
"error": {"code": "FILE_NOT_WRITABLE", "message": f"Cannot write database: {target}"},
}
except OSError as e:
target = db_path if db_path else file_path
if e.errno == errno.ENOSPC:
return {
"success": False,
"error": {"code": "DISK_FULL", "message": f"Insufficient disk space: {target}"},
}
if e.errno == errno.ENOENT:
return {
"success": False,
"error": {
"code": "FILE_NOT_WRITABLE",
"message": f"Directory does not exist: {os.path.dirname(target)}",
},
}
return {
"success": False,
"error": {"code": "DATABASE_ERROR", "message": str(e)},
}
except sqlite3.Error as e:
return {
"success": False,
"error": {"code": "DATABASE_ERROR", "message": str(e)},
}
parser = XerParser()
@@ -73,6 +113,9 @@ async def load_xer(file_path: str, project_id: str | None = None) -> dict:
activity_count = get_activity_count()
relationship_count = get_relationship_count()
# Get database info
database_info = db.get_database_info()
return {
"success": True,
"project": {
@@ -83,4 +126,5 @@ async def load_xer(file_path: str, project_id: str | None = None) -> dict:
},
"activity_count": activity_count,
"relationship_count": relationship_count,
"database": database_info,
}

View File

@@ -0,0 +1,143 @@
"""Contract tests for get_database_info MCP tool."""
from pathlib import Path
import pytest
from xer_mcp.db import db
@pytest.fixture(autouse=True)
def setup_db():
    """Start each test with the global database closed and close it again afterwards."""

    def _close_if_open() -> None:
        if db.is_initialized:
            db.close()

    _close_if_open()
    yield
    _close_if_open()
class TestGetDatabaseInfoContract:
    """Contract tests for the response shape of the get_database_info tool."""

    @staticmethod
    async def _info_after_file_load(tmp_path: Path, xer_path: Path) -> tuple[dict, Path]:
        """Load ``xer_path`` into ``tmp_path/schedule.db``; return (tool result, db file)."""
        from xer_mcp.tools.get_database_info import get_database_info
        from xer_mcp.tools.load_xer import load_xer

        target = tmp_path / "schedule.db"
        await load_xer(file_path=str(xer_path), db_path=str(target))
        return await get_database_info(), target

    async def test_get_database_info_returns_current_database(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """The tool reports the database that was just loaded."""
        result, target = await self._info_after_file_load(tmp_path, sample_xer_single_project)

        assert "database" in result
        described = result["database"]
        assert described["db_path"] == str(target)
        assert described["is_persistent"] is True

    async def test_get_database_info_error_when_no_database(self) -> None:
        """Without a loaded database the tool answers with NO_FILE_LOADED."""
        from xer_mcp.tools.get_database_info import get_database_info

        # Make sure nothing is open, independent of the fixture
        if db.is_initialized:
            db.close()

        result = await get_database_info()

        assert "error" in result
        assert result["error"]["code"] == "NO_FILE_LOADED"

    async def test_get_database_info_includes_schema(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """The database entry carries schema information with tables."""
        result, _ = await self._info_after_file_load(tmp_path, sample_xer_single_project)

        assert "schema" in result["database"]
        assert "tables" in result["database"]["schema"]

    async def test_get_database_info_includes_loaded_at(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """The database entry carries an ISO-style loaded_at timestamp."""
        result, _ = await self._info_after_file_load(tmp_path, sample_xer_single_project)

        assert "loaded_at" in result["database"]
        # ISO format separates date and time with "T"
        assert "T" in result["database"]["loaded_at"]

    async def test_get_database_info_includes_source_file(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """The database entry names the XER file it was loaded from."""
        result, _ = await self._info_after_file_load(tmp_path, sample_xer_single_project)

        assert result["database"]["source_file"] == str(sample_xer_single_project)

    async def test_get_database_info_for_memory_database(
        self, sample_xer_single_project: Path
    ) -> None:
        """The tool also works when the load used the in-memory default."""
        from xer_mcp.tools.get_database_info import get_database_info
        from xer_mcp.tools.load_xer import load_xer

        await load_xer(file_path=str(sample_xer_single_project))
        result = await get_database_info()

        assert "database" in result
        described = result["database"]
        assert described["db_path"] == ":memory:"
        assert described["is_persistent"] is False
class TestGetDatabaseInfoToolSchema:
    """Checks on how get_database_info is registered with the MCP server."""

    async def test_get_database_info_tool_registered(self) -> None:
        """The server's tool list contains get_database_info."""
        from xer_mcp.server import list_tools

        registered = [tool.name for tool in await list_tools()]
        assert "get_database_info" in registered

    async def test_get_database_info_tool_has_empty_input_schema(self) -> None:
        """The tool declares no required inputs."""
        from xer_mcp.server import list_tools

        matching = (tool for tool in await list_tools() if tool.name == "get_database_info")
        tool = next(matching)
        # A missing "required" key counts the same as an empty list
        required = tool.inputSchema.get("required", [])
        assert len(required) == 0

View File

@@ -1,5 +1,6 @@
"""Contract tests for load_xer MCP tool."""
import sqlite3
from pathlib import Path
import pytest
@@ -9,10 +10,14 @@ from xer_mcp.db import db
@pytest.fixture(autouse=True)
def setup_db():
"""Initialize and clear database for each test."""
db.initialize()
"""Reset database state for each test."""
# Close any existing connection
if db.is_initialized:
db.close()
yield
db.clear()
# Cleanup after test
if db.is_initialized:
db.close()
class TestLoadXerContract:
@@ -105,3 +110,141 @@ class TestLoadXerContract:
assert "plan_end_date" in result["project"]
# Dates should be ISO8601 format
assert "T" in result["project"]["plan_start_date"]
class TestLoadXerPersistentDatabase:
    """Contract tests covering the db_path modes of load_xer."""

    @staticmethod
    async def _load(xer_path: Path, **options) -> dict:
        """Call load_xer for ``xer_path``, forwarding any extra keyword options."""
        from xer_mcp.tools.load_xer import load_xer

        return await load_xer(file_path=str(xer_path), **options)

    async def test_load_xer_with_db_path_creates_file(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """An explicit db_path produces a database file at that path."""
        target = tmp_path / "schedule.db"
        result = await self._load(sample_xer_single_project, db_path=str(target))

        assert result["success"] is True
        assert target.exists()
        assert result["database"]["db_path"] == str(target)
        assert result["database"]["is_persistent"] is True

    async def test_load_xer_with_empty_db_path_auto_generates(self, tmp_path: Path) -> None:
        """An empty db_path derives the database path from the XER filename."""
        from tests.conftest import SAMPLE_XER_SINGLE_PROJECT

        # The XER file lives in tmp_path so the generated database lands there too
        xer_file = tmp_path / "my_schedule.xer"
        xer_file.write_text(SAMPLE_XER_SINGLE_PROJECT)

        result = await self._load(xer_file, db_path="")

        assert result["success"] is True
        expected_db = str(tmp_path / "my_schedule.sqlite")
        assert result["database"]["db_path"] == expected_db
        assert result["database"]["is_persistent"] is True
        assert Path(expected_db).exists()

    async def test_load_xer_without_db_path_uses_memory(
        self, sample_xer_single_project: Path
    ) -> None:
        """Omitting db_path keeps the in-memory behaviour (backward compatible)."""
        result = await self._load(sample_xer_single_project)

        assert result["success"] is True
        assert result["database"]["db_path"] == ":memory:"
        assert result["database"]["is_persistent"] is False

    async def test_load_xer_database_contains_all_data(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """The persistent database holds as many activities as load_xer reported."""
        target = tmp_path / "schedule.db"
        result = await self._load(sample_xer_single_project, db_path=str(target))

        # Count rows through an independent connection, as a script would
        external = sqlite3.connect(str(target))
        stored = external.execute("SELECT COUNT(*) FROM activities").fetchone()[0]
        external.close()

        assert stored == result["activity_count"]

    async def test_load_xer_response_includes_database_info(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """The response carries every field of the database description."""
        target = tmp_path / "schedule.db"
        result = await self._load(sample_xer_single_project, db_path=str(target))

        assert "database" in result
        db_info = result["database"]
        for field in ("db_path", "is_persistent", "source_file", "loaded_at", "schema"):
            assert field in db_info

    async def test_load_xer_response_schema_includes_tables(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """The schema in the response lists the core tables."""
        target = tmp_path / "schedule.db"
        result = await self._load(sample_xer_single_project, db_path=str(target))

        schema = result["database"]["schema"]
        assert "version" in schema
        assert "tables" in schema

        listed = [table["name"] for table in schema["tables"]]
        assert "activities" in listed
        assert "relationships" in listed

    async def test_load_xer_error_on_invalid_path(self, sample_xer_single_project: Path) -> None:
        """A db_path inside a missing directory yields an error response."""
        result = await self._load(
            sample_xer_single_project,
            db_path="/nonexistent/dir/file.db",
        )

        assert result["success"] is False
        # Which code comes back depends on how SQLite reports the failure,
        # so both are accepted
        assert result["error"]["code"] in ("FILE_NOT_WRITABLE", "DATABASE_ERROR")
class TestLoadXerToolSchema:
    """Checks on the MCP input schema published for load_xer."""

    async def test_load_xer_tool_schema_includes_db_path(self) -> None:
        """The load_xer input schema declares db_path as a string property."""
        from xer_mcp.server import list_tools

        candidates = (tool for tool in await list_tools() if tool.name == "load_xer")
        properties = next(candidates).inputSchema["properties"]

        assert "db_path" in properties
        assert properties["db_path"]["type"] == "string"

View File

@@ -0,0 +1,185 @@
"""Integration tests for direct database access feature."""
import sqlite3
from pathlib import Path
import pytest
from xer_mcp.db import db
@pytest.fixture(autouse=True)
def setup_db():
    """Start each test with the global database closed and close it again afterwards."""

    def _close_if_open() -> None:
        if db.is_initialized:
            db.close()

    _close_if_open()
    yield
    _close_if_open()
class TestDirectDatabaseAccess:
    """Integration tests verifying external script can access database.

    Every external connection is closed in a ``finally`` block so that a
    failing assertion cannot leave a handle open on the temporary database.
    """

    @staticmethod
    async def _load_to_file(tmp_path: Path, xer_path: Path) -> tuple[dict, Path]:
        """Load ``xer_path`` into ``tmp_path/schedule.db``; return (result, db file)."""
        from xer_mcp.tools.load_xer import load_xer

        db_file = tmp_path / "schedule.db"
        result = await load_xer(file_path=str(xer_path), db_path=str(db_file))
        return result, db_file

    @staticmethod
    def _query(db_path: str, sql: str, *, named_rows: bool = False) -> list:
        """Run ``sql`` on a fresh connection to ``db_path`` and return all rows.

        With ``named_rows`` the rows are ``sqlite3.Row`` objects (access by
        column name). The connection is always closed before returning.
        """
        conn = sqlite3.connect(db_path)
        try:
            if named_rows:
                conn.row_factory = sqlite3.Row
            return conn.execute(sql).fetchall()
        finally:
            conn.close()

    async def test_external_script_can_query_database(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """External script can query database using returned path."""
        result, _ = await self._load_to_file(tmp_path, sample_xer_single_project)

        # Simulate external script access (as shown in quickstart.md)
        milestones = self._query(
            result["database"]["db_path"],
            """
            SELECT task_code, task_name, target_start_date, milestone_type
            FROM activities
            WHERE task_type IN ('TT_Mile', 'TT_FinMile')
            ORDER BY target_start_date
            """,
            named_rows=True,
        )

        assert len(milestones) > 0
        assert all(row["task_code"] for row in milestones)

    async def test_external_script_can_query_critical_path(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """External script can query critical path activities."""
        result, _ = await self._load_to_file(tmp_path, sample_xer_single_project)

        critical_activities = self._query(
            result["database"]["db_path"],
            """
            SELECT task_code, task_name, target_start_date, target_end_date
            FROM activities
            WHERE driving_path_flag = 1
            ORDER BY target_start_date
            """,
        )

        assert len(critical_activities) > 0

    async def test_external_script_can_join_tables(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """External script can join activities with WBS."""
        result, _ = await self._load_to_file(tmp_path, sample_xer_single_project)

        joined_rows = self._query(
            result["database"]["db_path"],
            """
            SELECT a.task_code, a.task_name, w.wbs_name
            FROM activities a
            JOIN wbs w ON a.wbs_id = w.wbs_id
            LIMIT 10
            """,
        )

        assert len(joined_rows) > 0

    async def test_database_accessible_after_mcp_load(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """Database remains accessible while MCP tools are active."""
        result, db_file = await self._load_to_file(tmp_path, sample_xer_single_project)
        loaded_count = result["activity_count"]

        # External script queries database
        external_count = self._query(str(db_file), "SELECT COUNT(*) FROM activities")[0][0]

        # Both should match
        assert external_count == loaded_count

    async def test_schema_info_matches_actual_database(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """Returned schema info matches actual database structure."""
        result, _ = await self._load_to_file(tmp_path, sample_xer_single_project)
        schema = result["database"]["schema"]

        # Verify tables exist in actual database
        rows = self._query(
            result["database"]["db_path"],
            "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'",
        )
        actual_tables = {row[0] for row in rows}

        schema_tables = {t["name"] for t in schema["tables"]}
        assert schema_tables == actual_tables

    async def test_row_counts_match_actual_data(
        self, tmp_path: Path, sample_xer_single_project: Path
    ) -> None:
        """Schema row counts match actual database row counts."""
        result, _ = await self._load_to_file(tmp_path, sample_xer_single_project)
        schema = result["database"]["schema"]

        # One connection for all tables; the finally block closes it even
        # when one of the per-table assertions fails.
        conn = sqlite3.connect(result["database"]["db_path"])
        try:
            for table_info in schema["tables"]:
                cursor = conn.execute(
                    f"SELECT COUNT(*) FROM {table_info['name']}"  # noqa: S608
                )
                actual_count = cursor.fetchone()[0]
                assert table_info["row_count"] == actual_count, (
                    f"Table {table_info['name']}: expected {table_info['row_count']}, "
                    f"got {actual_count}"
                )
        finally:
            conn.close()

View File

@@ -0,0 +1,234 @@
"""Unit tests for DatabaseManager file-based database support."""
import sqlite3
from datetime import datetime
from pathlib import Path
from xer_mcp.db import DatabaseManager
class TestDatabaseManagerInitialization:
    """Covers the ways a DatabaseManager can be initialized."""

    @staticmethod
    def _initialized(**options) -> "DatabaseManager":
        """Return a new manager after calling ``initialize(**options)`` on it."""
        manager = DatabaseManager()
        manager.initialize(**options)
        return manager

    def test_initialize_with_memory_by_default(self) -> None:
        """With no arguments the manager uses an in-memory database."""
        manager = self._initialized()
        assert manager.db_path == ":memory:"
        assert manager.is_persistent is False
        manager.close()

    def test_initialize_with_file_path(self, tmp_path: Path) -> None:
        """An explicit path creates a persistent database file there."""
        target = tmp_path / "test.db"
        manager = self._initialized(db_path=str(target))
        assert manager.db_path == str(target)
        assert manager.is_persistent is True
        assert target.exists()
        manager.close()

    def test_initialize_with_empty_string_auto_generates_path(self, tmp_path: Path) -> None:
        """An empty db_path plus a source_file derives a .sqlite path next to it."""
        xer_file = tmp_path / "schedule.xer"
        xer_file.write_text("dummy content")

        manager = self._initialized(db_path="", source_file=str(xer_file))

        expected_db = str(tmp_path / "schedule.sqlite")
        assert manager.db_path == expected_db
        assert manager.is_persistent is True
        assert Path(expected_db).exists()
        manager.close()

    def test_file_database_persists_after_close(self, tmp_path: Path) -> None:
        """Rows written through the manager are still in the file after close."""
        target = tmp_path / "persist_test.db"
        manager = self._initialized(db_path=str(target))

        # Write one project row through the manager
        with manager.cursor() as cur:
            cur.execute(
                "INSERT INTO projects (proj_id, proj_short_name, loaded_at) "
                "VALUES ('P1', 'Test', datetime('now'))"
            )
        manager.commit()
        manager.close()

        # Read it back through an unrelated connection
        assert target.exists()
        reader = sqlite3.connect(str(target))
        rows = reader.execute("SELECT proj_id FROM projects").fetchall()
        reader.close()

        assert len(rows) == 1
        assert rows[0][0] == "P1"

    def test_source_file_tracked(self, tmp_path: Path) -> None:
        """The source file handed to initialize is exposed afterwards."""
        target = tmp_path / "test.db"
        xer_file = tmp_path / "schedule.xer"
        xer_file.write_text("dummy")

        manager = self._initialized(db_path=str(target), source_file=str(xer_file))
        assert manager.source_file == str(xer_file)
        manager.close()

    def test_loaded_at_timestamp(self, tmp_path: Path) -> None:
        """loaded_at falls between the moments just before and after initialize."""
        target = tmp_path / "test.db"
        manager = DatabaseManager()

        before = datetime.now()
        manager.initialize(db_path=str(target))
        after = datetime.now()

        recorded = manager.loaded_at
        assert recorded is not None
        assert before <= recorded <= after
        manager.close()

    def test_memory_database_not_persistent(self) -> None:
        """An in-memory database reports itself as non-persistent."""
        manager = self._initialized()
        assert manager.is_persistent is False
        assert manager.db_path == ":memory:"
        manager.close()
class TestDatabaseManagerWalMode:
    """Checks which journal mode each kind of database reports."""

    @staticmethod
    def _journal_mode(manager) -> str:
        """Return the lower-cased result of ``PRAGMA journal_mode`` for ``manager``."""
        with manager.cursor() as cur:
            cur.execute("PRAGMA journal_mode")
            reported = cur.fetchone()[0]
        return reported.lower()

    def test_file_database_uses_wal_mode(self, tmp_path: Path) -> None:
        """A file-based database reports WAL journal mode."""
        manager = DatabaseManager()
        manager.initialize(db_path=str(tmp_path / "wal_test.db"))
        assert self._journal_mode(manager) == "wal"
        manager.close()

    def test_memory_database_does_not_use_wal(self) -> None:
        """An in-memory database reports the 'memory' journal mode, not WAL."""
        manager = DatabaseManager()
        manager.initialize()
        assert self._journal_mode(manager) == "memory"
        manager.close()
class TestAtomicWrite:
    """Checks the on-disk result of the temp-file-then-rename creation."""

    def test_atomic_write_creates_final_file(self, tmp_path: Path) -> None:
        """After initialization the final file exists and its .tmp sibling does not."""
        final_path = tmp_path / "atomic_test.db"
        manager = DatabaseManager()
        manager.initialize(db_path=str(final_path))

        leftover = Path(str(final_path) + ".tmp")
        assert final_path.exists()
        assert not leftover.exists()
        manager.close()

    def test_atomic_write_no_temp_file_remains(self, tmp_path: Path) -> None:
        """No *.tmp file is left in the directory after a successful initialization."""
        final_path = tmp_path / "atomic_clean.db"
        manager = DatabaseManager()
        manager.initialize(db_path=str(final_path))
        manager.close()

        # Look for any temp file, not only the one next to final_path
        stray = list(tmp_path.glob("*.tmp"))
        assert len(stray) == 0
class TestSchemaIntrospection:
    """Checks the structure returned by DatabaseManager.get_schema_info."""

    @staticmethod
    def _memory_schema() -> tuple:
        """Return (manager, schema) for a freshly initialized in-memory database."""
        manager = DatabaseManager()
        manager.initialize()
        return manager, manager.get_schema_info()

    @staticmethod
    def _table(schema: dict, name: str) -> dict:
        """Return the entry of ``schema["tables"]`` whose name is ``name``."""
        return next(entry for entry in schema["tables"] if entry["name"] == name)

    def test_get_schema_info_returns_all_tables(self) -> None:
        """The schema lists every core table and the expected version."""
        manager, schema = self._memory_schema()

        assert schema["version"] == "0.2.0"
        listed = [entry["name"] for entry in schema["tables"]]
        for expected in ("projects", "activities", "relationships", "wbs", "calendars"):
            assert expected in listed
        manager.close()

    def test_get_schema_info_includes_column_details(self) -> None:
        """Column entries expose name, type and nullability."""
        manager, schema = self._memory_schema()
        columns = self._table(schema, "activities")["columns"]
        by_name = [column["name"] for column in columns]

        assert "task_id" in by_name
        assert "task_name" in by_name

        task_id_col = next(column for column in columns if column["name"] == "task_id")
        assert task_id_col["type"] == "TEXT"
        # SQLite reports a TEXT PRIMARY KEY column as nullable even though
        # the key constraint applies, so only the key's presence is checked
        assert "nullable" in task_id_col

        # A column declared NOT NULL
        task_name_col = next(column for column in columns if column["name"] == "task_name")
        assert task_name_col["nullable"] is False
        manager.close()

    def test_get_schema_info_includes_row_counts(self) -> None:
        """Every table entry has a non-negative integer row_count."""
        manager, schema = self._memory_schema()

        for entry in schema["tables"]:
            assert "row_count" in entry
            assert isinstance(entry["row_count"], int)
            assert entry["row_count"] >= 0
        manager.close()

    def test_schema_info_includes_primary_keys(self) -> None:
        """The activities entry names task_id in its primary key."""
        manager, schema = self._memory_schema()
        activities = self._table(schema, "activities")

        assert "primary_key" in activities
        assert "task_id" in activities["primary_key"]
        manager.close()

    def test_schema_info_includes_foreign_keys(self) -> None:
        """The activities entry links proj_id to projects.proj_id."""
        manager, schema = self._memory_schema()
        activities = self._table(schema, "activities")

        assert "foreign_keys" in activities
        # activities.proj_id -> projects.proj_id
        link = None
        for candidate in activities["foreign_keys"]:
            if candidate["column"] == "proj_id":
                link = candidate
                break
        assert link is not None
        assert link["references_table"] == "projects"
        assert link["references_column"] == "proj_id"
        manager.close()