feat: add direct database access for scripts (v0.2.0)

Implement persistent SQLite database feature that allows scripts to
query schedule data directly via SQL after loading XER files through MCP.

Key changes:
- Extend load_xer with db_path parameter for persistent database
- Add get_database_info tool to retrieve database connection details
- Add schema introspection with tables, columns, primary/foreign keys
- Support WAL mode for concurrent read access
- Use atomic write pattern to prevent corruption

New features:
- db_path=None: in-memory database (default, backward compatible)
- db_path="": auto-generate path from XER filename (.sqlite extension)
- db_path="/path/to/db": explicit persistent database path

Response includes complete DatabaseInfo:
- db_path: absolute path (or :memory:)
- is_persistent: boolean
- source_file: loaded XER path
- loaded_at: ISO timestamp
- schema: tables with columns, primary keys, foreign keys, row counts

Closes: User Story 1, 2, 3 from 002-direct-db-access spec
This commit is contained in:
2026-01-08 12:54:56 -05:00
parent 3e7ad39eb8
commit d6a79bf24a
11 changed files with 1064 additions and 40 deletions

View File

@@ -1,6 +1,6 @@
[project] [project]
name = "xer-mcp" name = "xer-mcp"
version = "0.1.0" version = "0.2.0"
description = "MCP server for querying Primavera P6 XER schedule data" description = "MCP server for querying Primavera P6 XER schedule data"
readme = "README.md" readme = "README.md"
requires-python = ">=3.14" requires-python = ">=3.14"

View File

@@ -1,24 +1,66 @@
"""Database connection management for XER MCP Server.""" """Database connection management for XER MCP Server."""
import os
import sqlite3 import sqlite3
from collections.abc import Generator from collections.abc import Generator
from contextlib import contextmanager from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from xer_mcp.db.schema import get_schema from xer_mcp.db.schema import get_schema
# Schema version for introspection responses
SCHEMA_VERSION = "0.2.0"
class DatabaseManager: class DatabaseManager:
"""Manages SQLite database connections and schema initialization.""" """Manages SQLite database connections and schema initialization."""
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize database manager with in-memory database.""" """Initialize database manager."""
self._connection: sqlite3.Connection | None = None self._connection: sqlite3.Connection | None = None
self._db_path: str = ":memory:"
self._source_file: str | None = None
self._loaded_at: datetime | None = None
def initialize(self) -> None: def initialize(
"""Initialize the in-memory database with schema.""" self,
self._connection = sqlite3.connect(":memory:", check_same_thread=False) db_path: str | None = None,
self._connection.row_factory = sqlite3.Row source_file: str | None = None,
self._connection.executescript(get_schema()) ) -> None:
"""Initialize the database with schema.
Args:
db_path: Path for database file. If None or omitted, uses in-memory.
If empty string, auto-generates from source_file.
source_file: Path to the XER file being loaded (for tracking).
"""
self._source_file = source_file
self._loaded_at = datetime.now()
# Determine database path
if db_path is None:
# Default: in-memory database
self._db_path = ":memory:"
elif db_path == "":
# Auto-generate from source file
if source_file:
base = Path(source_file).with_suffix(".sqlite")
self._db_path = str(base)
else:
self._db_path = ":memory:"
else:
# Use provided path
self._db_path = db_path
# Create database
if self._db_path == ":memory:":
self._connection = sqlite3.connect(":memory:", check_same_thread=False)
self._connection.row_factory = sqlite3.Row
self._connection.executescript(get_schema())
else:
# File-based database with atomic write pattern
self._create_file_database()
def clear(self) -> None: def clear(self) -> None:
"""Clear all data from the database.""" """Clear all data from the database."""
@@ -61,6 +103,159 @@ class DatabaseManager:
"""Check if the database is initialized.""" """Check if the database is initialized."""
return self._connection is not None return self._connection is not None
@property
def db_path(self) -> str:
"""Get the database path."""
return self._db_path
@property
def is_persistent(self) -> bool:
"""Check if the database is file-based (persistent)."""
return self._db_path != ":memory:"
@property
def source_file(self) -> str | None:
"""Get the source XER file path."""
return self._source_file
@property
def loaded_at(self) -> datetime | None:
"""Get the timestamp when data was loaded."""
return self._loaded_at
def _create_file_database(self) -> None:
"""Create a file-based database with atomic write pattern."""
temp_path = self._db_path + ".tmp"
try:
# Create database at temp path
conn = sqlite3.connect(temp_path, check_same_thread=False)
conn.row_factory = sqlite3.Row
conn.executescript(get_schema())
conn.commit()
conn.close()
# Atomic rename (POSIX)
if os.path.exists(self._db_path):
os.unlink(self._db_path)
os.rename(temp_path, self._db_path)
# Open final database with WAL mode
self._connection = sqlite3.connect(self._db_path, check_same_thread=False)
self._connection.row_factory = sqlite3.Row
self._connection.execute("PRAGMA journal_mode=WAL")
except Exception:
# Clean up temp file on failure
if os.path.exists(temp_path):
os.unlink(temp_path)
raise
def get_schema_info(self) -> dict:
"""Get database schema information for introspection.
Returns:
Dictionary with schema version and table information.
"""
if self._connection is None:
raise RuntimeError("Database not initialized. Call initialize() first.")
tables = []
# Get all tables (excluding sqlite internal tables)
cursor = self._connection.execute(
"SELECT name FROM sqlite_master WHERE type='table' "
"AND name NOT LIKE 'sqlite_%' ORDER BY name"
)
table_names = [row[0] for row in cursor.fetchall()]
for table_name in table_names:
table_info = self._get_table_info(table_name)
tables.append(table_info)
return {
"version": SCHEMA_VERSION,
"tables": tables,
}
def _get_table_info(self, table_name: str) -> dict:
"""Get detailed information about a table.
Args:
table_name: Name of the table.
Returns:
Dictionary with table name, columns, primary keys, foreign keys, row count.
"""
columns = []
primary_key = []
# Get column info
cursor = self._connection.execute(f"PRAGMA table_info({table_name})") # noqa: S608
for row in cursor.fetchall():
col_name = row[1]
col_type = row[2] or "TEXT"
not_null = bool(row[3])
default_val = row[4]
is_pk = bool(row[5])
col_info: dict = {
"name": col_name,
"type": col_type,
"nullable": not not_null,
}
if default_val is not None:
col_info["default"] = str(default_val)
columns.append(col_info)
if is_pk:
primary_key.append(col_name)
# Get foreign keys
foreign_keys = []
cursor = self._connection.execute(
f"PRAGMA foreign_key_list({table_name})" # noqa: S608
)
for row in cursor.fetchall():
fk_info = {
"column": row[3], # from column
"references_table": row[2], # table
"references_column": row[4], # to column
}
foreign_keys.append(fk_info)
# Get row count
cursor = self._connection.execute(
f"SELECT COUNT(*) FROM {table_name}" # noqa: S608
)
row_count = cursor.fetchone()[0]
return {
"name": table_name,
"columns": columns,
"primary_key": primary_key,
"foreign_keys": foreign_keys,
"row_count": row_count,
}
def get_database_info(self) -> dict:
"""Get complete database information for API responses.
Returns:
Dictionary with database path, persistence status, source file,
loaded timestamp, and schema information.
"""
if not self.is_initialized:
raise RuntimeError("Database not initialized. Call initialize() first.")
return {
"db_path": self._db_path,
"is_persistent": self.is_persistent,
"source_file": self._source_file,
"loaded_at": self._loaded_at.isoformat() if self._loaded_at else None,
"schema": self.get_schema_info(),
}
# Global database manager instance # Global database manager instance
db = DatabaseManager() db = DatabaseManager()

View File

@@ -239,16 +239,18 @@ def query_relationships(
lag_hours=row[6] or 0.0, lag_hours=row[6] or 0.0,
pred_type=pred_type, pred_type=pred_type,
) )
relationships.append({ relationships.append(
"task_pred_id": row[0], {
"task_id": row[1], "task_pred_id": row[0],
"task_name": row[2], "task_id": row[1],
"pred_task_id": row[3], "task_name": row[2],
"pred_task_name": row[4], "pred_task_id": row[3],
"pred_type": pred_type, "pred_task_name": row[4],
"lag_hr_cnt": row[6], "pred_type": pred_type,
"driving": driving, "lag_hr_cnt": row[6],
}) "driving": driving,
}
)
return relationships, total return relationships, total
@@ -298,14 +300,16 @@ def get_predecessors(activity_id: str) -> list[dict]:
lag_hours=row[4] or 0.0, lag_hours=row[4] or 0.0,
pred_type=pred_type, pred_type=pred_type,
) )
result.append({ result.append(
"task_id": row[0], {
"task_code": row[1], "task_id": row[0],
"task_name": row[2], "task_code": row[1],
"relationship_type": pred_type, "task_name": row[2],
"lag_hr_cnt": row[4], "relationship_type": pred_type,
"driving": driving, "lag_hr_cnt": row[4],
}) "driving": driving,
}
)
return result return result
@@ -355,14 +359,16 @@ def get_successors(activity_id: str) -> list[dict]:
lag_hours=row[4] or 0.0, lag_hours=row[4] or 0.0,
pred_type=pred_type, pred_type=pred_type,
) )
result.append({ result.append(
"task_id": row[0], {
"task_code": row[1], "task_id": row[0],
"task_name": row[2], "task_code": row[1],
"relationship_type": pred_type, "task_name": row[2],
"lag_hr_cnt": row[4], "relationship_type": pred_type,
"driving": driving, "lag_hr_cnt": row[4],
}) "driving": driving,
}
)
return result return result

View File

@@ -56,3 +56,30 @@ class ActivityNotFoundError(XerMcpError):
"ACTIVITY_NOT_FOUND", "ACTIVITY_NOT_FOUND",
f"Activity not found: {activity_id}", f"Activity not found: {activity_id}",
) )
class FileNotWritableError(XerMcpError):
"""Raised when the database file path is not writable."""
def __init__(self, path: str, reason: str = "") -> None:
msg = f"Cannot write to database file: {path}"
if reason:
msg = f"{msg} ({reason})"
super().__init__("FILE_NOT_WRITABLE", msg)
class DiskFullError(XerMcpError):
"""Raised when there is insufficient disk space."""
def __init__(self, path: str) -> None:
super().__init__(
"DISK_FULL",
f"Insufficient disk space to create database: {path}",
)
class DatabaseError(XerMcpError):
"""Raised for general database errors."""
def __init__(self, message: str) -> None:
super().__init__("DATABASE_ERROR", message)

View File

@@ -50,6 +50,12 @@ async def list_tools() -> list[Tool]:
"type": "string", "type": "string",
"description": "Project ID to select (required for multi-project files)", "description": "Project ID to select (required for multi-project files)",
}, },
"db_path": {
"type": "string",
"description": "Path for persistent SQLite database file. "
"If omitted, uses in-memory database. "
"If empty string, auto-generates path from XER filename (same directory, .sqlite extension).",
},
}, },
"required": ["file_path"], "required": ["file_path"],
}, },
@@ -183,6 +189,15 @@ async def list_tools() -> list[Tool]:
"properties": {}, "properties": {},
}, },
), ),
Tool(
name="get_database_info",
description="Get information about the currently loaded database including file path and schema. "
"Use this to get connection details for direct SQL access.",
inputSchema={
"type": "object",
"properties": {},
},
),
] ]
@@ -197,6 +212,7 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
result = await load_xer( result = await load_xer(
file_path=arguments["file_path"], file_path=arguments["file_path"],
project_id=arguments.get("project_id"), project_id=arguments.get("project_id"),
db_path=arguments.get("db_path"),
) )
return [TextContent(type="text", text=json.dumps(result, indent=2))] return [TextContent(type="text", text=json.dumps(result, indent=2))]
@@ -258,6 +274,12 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
result = await get_critical_path() result = await get_critical_path()
return [TextContent(type="text", text=json.dumps(result, indent=2))] return [TextContent(type="text", text=json.dumps(result, indent=2))]
if name == "get_database_info":
from xer_mcp.tools.get_database_info import get_database_info
result = await get_database_info()
return [TextContent(type="text", text=json.dumps(result, indent=2))]
raise ValueError(f"Unknown tool: {name}") raise ValueError(f"Unknown tool: {name}")

View File

@@ -0,0 +1,25 @@
"""get_database_info MCP tool implementation."""
from xer_mcp.db import db
async def get_database_info() -> dict:
"""Get information about the currently loaded database.
Returns connection details for direct SQL access including
database path, schema information, and metadata.
Returns:
Dictionary with database info or error if no database loaded
"""
if not db.is_initialized:
return {
"error": {
"code": "NO_FILE_LOADED",
"message": "No XER file is loaded. Use the load_xer tool first.",
}
}
return {
"database": db.get_database_info(),
}

View File

@@ -1,5 +1,9 @@
"""load_xer MCP tool implementation.""" """load_xer MCP tool implementation."""
import errno
import os
import sqlite3
from xer_mcp.db import db from xer_mcp.db import db
from xer_mcp.db.loader import get_activity_count, get_relationship_count, load_parsed_data from xer_mcp.db.loader import get_activity_count, get_relationship_count, load_parsed_data
from xer_mcp.errors import FileNotFoundError, ParseError from xer_mcp.errors import FileNotFoundError, ParseError
@@ -7,19 +11,55 @@ from xer_mcp.parser.xer_parser import XerParser
from xer_mcp.server import set_file_loaded from xer_mcp.server import set_file_loaded
async def load_xer(file_path: str, project_id: str | None = None) -> dict: async def load_xer(
file_path: str,
project_id: str | None = None,
db_path: str | None = None,
) -> dict:
"""Load a Primavera P6 XER file and parse its schedule data. """Load a Primavera P6 XER file and parse its schedule data.
Args: Args:
file_path: Absolute path to the XER file file_path: Absolute path to the XER file
project_id: Project ID to select (required for multi-project files) project_id: Project ID to select (required for multi-project files)
db_path: Path for persistent database file. If None, uses in-memory.
If empty string, auto-generates from XER filename.
Returns: Returns:
Dictionary with success status and project info or error details Dictionary with success status and project info or error details
""" """
# Ensure database is initialized # Initialize database with specified path
if not db.is_initialized: try:
db.initialize() db.initialize(db_path=db_path, source_file=file_path)
except PermissionError:
target = db_path if db_path else file_path
return {
"success": False,
"error": {"code": "FILE_NOT_WRITABLE", "message": f"Cannot write database: {target}"},
}
except OSError as e:
target = db_path if db_path else file_path
if e.errno == errno.ENOSPC:
return {
"success": False,
"error": {"code": "DISK_FULL", "message": f"Insufficient disk space: {target}"},
}
if e.errno == errno.ENOENT:
return {
"success": False,
"error": {
"code": "FILE_NOT_WRITABLE",
"message": f"Directory does not exist: {os.path.dirname(target)}",
},
}
return {
"success": False,
"error": {"code": "DATABASE_ERROR", "message": str(e)},
}
except sqlite3.Error as e:
return {
"success": False,
"error": {"code": "DATABASE_ERROR", "message": str(e)},
}
parser = XerParser() parser = XerParser()
@@ -73,6 +113,9 @@ async def load_xer(file_path: str, project_id: str | None = None) -> dict:
activity_count = get_activity_count() activity_count = get_activity_count()
relationship_count = get_relationship_count() relationship_count = get_relationship_count()
# Get database info
database_info = db.get_database_info()
return { return {
"success": True, "success": True,
"project": { "project": {
@@ -83,4 +126,5 @@ async def load_xer(file_path: str, project_id: str | None = None) -> dict:
}, },
"activity_count": activity_count, "activity_count": activity_count,
"relationship_count": relationship_count, "relationship_count": relationship_count,
"database": database_info,
} }

View File

@@ -0,0 +1,143 @@
"""Contract tests for get_database_info MCP tool."""
from pathlib import Path
import pytest
from xer_mcp.db import db
@pytest.fixture(autouse=True)
def setup_db():
"""Reset database state for each test."""
if db.is_initialized:
db.close()
yield
if db.is_initialized:
db.close()
class TestGetDatabaseInfoContract:
"""Contract tests verifying get_database_info tool interface."""
async def test_get_database_info_returns_current_database(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""get_database_info returns info about currently loaded database."""
from xer_mcp.tools.get_database_info import get_database_info
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
result = await get_database_info()
assert "database" in result
assert result["database"]["db_path"] == str(db_file)
assert result["database"]["is_persistent"] is True
async def test_get_database_info_error_when_no_database(self) -> None:
"""get_database_info returns error when no database loaded."""
from xer_mcp.tools.get_database_info import get_database_info
# Ensure database is not initialized
if db.is_initialized:
db.close()
result = await get_database_info()
assert "error" in result
assert result["error"]["code"] == "NO_FILE_LOADED"
async def test_get_database_info_includes_schema(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""get_database_info includes schema information."""
from xer_mcp.tools.get_database_info import get_database_info
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
result = await get_database_info()
assert "schema" in result["database"]
assert "tables" in result["database"]["schema"]
async def test_get_database_info_includes_loaded_at(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""get_database_info includes loaded_at timestamp."""
from xer_mcp.tools.get_database_info import get_database_info
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
result = await get_database_info()
assert "loaded_at" in result["database"]
# Should be ISO format timestamp
assert "T" in result["database"]["loaded_at"]
async def test_get_database_info_includes_source_file(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""get_database_info includes source XER file path."""
from xer_mcp.tools.get_database_info import get_database_info
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
result = await get_database_info()
assert result["database"]["source_file"] == str(sample_xer_single_project)
async def test_get_database_info_for_memory_database(
self, sample_xer_single_project: Path
) -> None:
"""get_database_info works for in-memory database."""
from xer_mcp.tools.get_database_info import get_database_info
from xer_mcp.tools.load_xer import load_xer
await load_xer(file_path=str(sample_xer_single_project))
result = await get_database_info()
assert "database" in result
assert result["database"]["db_path"] == ":memory:"
assert result["database"]["is_persistent"] is False
class TestGetDatabaseInfoToolSchema:
"""Tests for MCP tool schema."""
async def test_get_database_info_tool_registered(self) -> None:
"""get_database_info tool is registered with MCP server."""
from xer_mcp.server import list_tools
tools = await list_tools()
tool_names = [t.name for t in tools]
assert "get_database_info" in tool_names
async def test_get_database_info_tool_has_empty_input_schema(self) -> None:
"""get_database_info tool has no required inputs."""
from xer_mcp.server import list_tools
tools = await list_tools()
tool = next(t for t in tools if t.name == "get_database_info")
# Should have empty or no required properties
assert "required" not in tool.inputSchema or len(tool.inputSchema.get("required", [])) == 0

View File

@@ -1,5 +1,6 @@
"""Contract tests for load_xer MCP tool.""" """Contract tests for load_xer MCP tool."""
import sqlite3
from pathlib import Path from pathlib import Path
import pytest import pytest
@@ -9,10 +10,14 @@ from xer_mcp.db import db
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def setup_db(): def setup_db():
"""Initialize and clear database for each test.""" """Reset database state for each test."""
db.initialize() # Close any existing connection
if db.is_initialized:
db.close()
yield yield
db.clear() # Cleanup after test
if db.is_initialized:
db.close()
class TestLoadXerContract: class TestLoadXerContract:
@@ -105,3 +110,141 @@ class TestLoadXerContract:
assert "plan_end_date" in result["project"] assert "plan_end_date" in result["project"]
# Dates should be ISO8601 format # Dates should be ISO8601 format
assert "T" in result["project"]["plan_start_date"] assert "T" in result["project"]["plan_start_date"]
class TestLoadXerPersistentDatabase:
"""Contract tests for persistent database functionality."""
async def test_load_xer_with_db_path_creates_file(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""load_xer with db_path creates persistent database file."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
assert result["success"] is True
assert db_file.exists()
assert result["database"]["db_path"] == str(db_file)
assert result["database"]["is_persistent"] is True
async def test_load_xer_with_empty_db_path_auto_generates(self, tmp_path: Path) -> None:
"""load_xer with empty db_path generates path from XER filename."""
from xer_mcp.tools.load_xer import load_xer
# Create XER file in tmp_path
xer_file = tmp_path / "my_schedule.xer"
from tests.conftest import SAMPLE_XER_SINGLE_PROJECT
xer_file.write_text(SAMPLE_XER_SINGLE_PROJECT)
result = await load_xer(file_path=str(xer_file), db_path="")
assert result["success"] is True
expected_db = str(tmp_path / "my_schedule.sqlite")
assert result["database"]["db_path"] == expected_db
assert result["database"]["is_persistent"] is True
assert Path(expected_db).exists()
async def test_load_xer_without_db_path_uses_memory(
self, sample_xer_single_project: Path
) -> None:
"""load_xer without db_path uses in-memory database (backward compatible)."""
from xer_mcp.tools.load_xer import load_xer
result = await load_xer(file_path=str(sample_xer_single_project))
assert result["success"] is True
assert result["database"]["db_path"] == ":memory:"
assert result["database"]["is_persistent"] is False
async def test_load_xer_database_contains_all_data(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""Persistent database contains all parsed data."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
# Verify data via direct SQL
conn = sqlite3.connect(str(db_file))
cursor = conn.execute("SELECT COUNT(*) FROM activities")
count = cursor.fetchone()[0]
conn.close()
assert count == result["activity_count"]
async def test_load_xer_response_includes_database_info(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""load_xer response includes complete database info."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
assert "database" in result
db_info = result["database"]
assert "db_path" in db_info
assert "is_persistent" in db_info
assert "source_file" in db_info
assert "loaded_at" in db_info
assert "schema" in db_info
async def test_load_xer_response_schema_includes_tables(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""load_xer response schema includes table information."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
schema = result["database"]["schema"]
assert "version" in schema
assert "tables" in schema
table_names = [t["name"] for t in schema["tables"]]
assert "activities" in table_names
assert "relationships" in table_names
async def test_load_xer_error_on_invalid_path(self, sample_xer_single_project: Path) -> None:
"""load_xer returns error for invalid path."""
from xer_mcp.tools.load_xer import load_xer
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path="/nonexistent/dir/file.db",
)
assert result["success"] is False
# Either FILE_NOT_WRITABLE or DATABASE_ERROR is acceptable
# depending on how SQLite reports the error
assert result["error"]["code"] in ("FILE_NOT_WRITABLE", "DATABASE_ERROR")
class TestLoadXerToolSchema:
"""Tests for MCP tool schema."""
async def test_load_xer_tool_schema_includes_db_path(self) -> None:
"""MCP tool schema includes db_path parameter."""
from xer_mcp.server import list_tools
tools = await list_tools()
load_xer_tool = next(t for t in tools if t.name == "load_xer")
props = load_xer_tool.inputSchema["properties"]
assert "db_path" in props
assert props["db_path"]["type"] == "string"

View File

@@ -0,0 +1,185 @@
"""Integration tests for direct database access feature."""
import sqlite3
from pathlib import Path
import pytest
from xer_mcp.db import db
@pytest.fixture(autouse=True)
def setup_db():
"""Reset database state for each test."""
if db.is_initialized:
db.close()
yield
if db.is_initialized:
db.close()
class TestDirectDatabaseAccess:
"""Integration tests verifying external script can access database."""
async def test_external_script_can_query_database(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""External script can query database using returned path."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
# Simulate external script access (as shown in quickstart.md)
db_path = result["database"]["db_path"]
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
# Query milestones
cursor = conn.execute("""
SELECT task_code, task_name, target_start_date, milestone_type
FROM activities
WHERE task_type IN ('TT_Mile', 'TT_FinMile')
ORDER BY target_start_date
""")
milestones = cursor.fetchall()
conn.close()
assert len(milestones) > 0
assert all(row["task_code"] for row in milestones)
async def test_external_script_can_query_critical_path(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""External script can query critical path activities."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
db_path = result["database"]["db_path"]
conn = sqlite3.connect(db_path)
cursor = conn.execute("""
SELECT task_code, task_name, target_start_date, target_end_date
FROM activities
WHERE driving_path_flag = 1
ORDER BY target_start_date
""")
critical_activities = cursor.fetchall()
conn.close()
assert len(critical_activities) > 0
async def test_external_script_can_join_tables(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""External script can join activities with WBS."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
db_path = result["database"]["db_path"]
conn = sqlite3.connect(db_path)
cursor = conn.execute("""
SELECT a.task_code, a.task_name, w.wbs_name
FROM activities a
JOIN wbs w ON a.wbs_id = w.wbs_id
LIMIT 10
""")
joined_rows = cursor.fetchall()
conn.close()
assert len(joined_rows) > 0
async def test_database_accessible_after_mcp_load(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""Database remains accessible while MCP tools are active."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
loaded_count = result["activity_count"]
# External script queries database
conn = sqlite3.connect(str(db_file))
cursor = conn.execute("SELECT COUNT(*) FROM activities")
external_count = cursor.fetchone()[0]
conn.close()
# Both should match
assert external_count == loaded_count
async def test_schema_info_matches_actual_database(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""Returned schema info matches actual database structure."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
schema = result["database"]["schema"]
db_path = result["database"]["db_path"]
# Verify tables exist in actual database
conn = sqlite3.connect(db_path)
cursor = conn.execute(
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
)
actual_tables = {row[0] for row in cursor.fetchall()}
conn.close()
schema_tables = {t["name"] for t in schema["tables"]}
assert schema_tables == actual_tables
async def test_row_counts_match_actual_data(
self, tmp_path: Path, sample_xer_single_project: Path
) -> None:
"""Schema row counts match actual database row counts."""
from xer_mcp.tools.load_xer import load_xer
db_file = tmp_path / "schedule.db"
result = await load_xer(
file_path=str(sample_xer_single_project),
db_path=str(db_file),
)
schema = result["database"]["schema"]
db_path = result["database"]["db_path"]
conn = sqlite3.connect(db_path)
for table_info in schema["tables"]:
cursor = conn.execute(
f"SELECT COUNT(*) FROM {table_info['name']}" # noqa: S608
)
actual_count = cursor.fetchone()[0]
assert table_info["row_count"] == actual_count, (
f"Table {table_info['name']}: expected {table_info['row_count']}, "
f"got {actual_count}"
)
conn.close()

View File

@@ -0,0 +1,234 @@
"""Unit tests for DatabaseManager file-based database support."""
import sqlite3
from datetime import datetime
from pathlib import Path
from xer_mcp.db import DatabaseManager
class TestDatabaseManagerInitialization:
"""Tests for DatabaseManager initialization modes."""
def test_initialize_with_memory_by_default(self) -> None:
"""Default initialization uses in-memory database."""
dm = DatabaseManager()
dm.initialize()
assert dm.db_path == ":memory:"
assert dm.is_persistent is False
dm.close()
def test_initialize_with_file_path(self, tmp_path: Path) -> None:
"""Can initialize with explicit file path."""
db_file = tmp_path / "test.db"
dm = DatabaseManager()
dm.initialize(db_path=str(db_file))
assert dm.db_path == str(db_file)
assert dm.is_persistent is True
assert db_file.exists()
dm.close()
def test_initialize_with_empty_string_auto_generates_path(self, tmp_path: Path) -> None:
"""Empty string db_path with source_file auto-generates path."""
xer_file = tmp_path / "schedule.xer"
xer_file.write_text("dummy content")
dm = DatabaseManager()
dm.initialize(db_path="", source_file=str(xer_file))
expected_db = str(tmp_path / "schedule.sqlite")
assert dm.db_path == expected_db
assert dm.is_persistent is True
assert Path(expected_db).exists()
dm.close()
def test_file_database_persists_after_close(self, tmp_path: Path) -> None:
"""File-based database persists after connection close."""
db_file = tmp_path / "persist_test.db"
dm = DatabaseManager()
dm.initialize(db_path=str(db_file))
# Insert test data
with dm.cursor() as cur:
cur.execute(
"INSERT INTO projects (proj_id, proj_short_name, loaded_at) "
"VALUES ('P1', 'Test', datetime('now'))"
)
dm.commit()
dm.close()
# Verify file exists and has data
assert db_file.exists()
conn = sqlite3.connect(str(db_file))
cursor = conn.execute("SELECT proj_id FROM projects")
rows = cursor.fetchall()
conn.close()
assert len(rows) == 1
assert rows[0][0] == "P1"
def test_source_file_tracked(self, tmp_path: Path) -> None:
"""Source file path is tracked when provided."""
db_file = tmp_path / "test.db"
xer_file = tmp_path / "schedule.xer"
xer_file.write_text("dummy")
dm = DatabaseManager()
dm.initialize(db_path=str(db_file), source_file=str(xer_file))
assert dm.source_file == str(xer_file)
dm.close()
def test_loaded_at_timestamp(self, tmp_path: Path) -> None:
"""Loaded_at timestamp is recorded."""
db_file = tmp_path / "test.db"
dm = DatabaseManager()
before = datetime.now()
dm.initialize(db_path=str(db_file))
after = datetime.now()
loaded_at = dm.loaded_at
assert loaded_at is not None
assert before <= loaded_at <= after
dm.close()
def test_memory_database_not_persistent(self) -> None:
"""In-memory database is not persistent."""
dm = DatabaseManager()
dm.initialize()
assert dm.is_persistent is False
assert dm.db_path == ":memory:"
dm.close()
class TestDatabaseManagerWalMode:
"""Tests for WAL mode in file-based databases."""
def test_file_database_uses_wal_mode(self, tmp_path: Path) -> None:
"""File-based database uses WAL mode for concurrent access."""
db_file = tmp_path / "wal_test.db"
dm = DatabaseManager()
dm.initialize(db_path=str(db_file))
with dm.cursor() as cur:
cur.execute("PRAGMA journal_mode")
mode = cur.fetchone()[0]
assert mode.lower() == "wal"
dm.close()
def test_memory_database_does_not_use_wal(self) -> None:
"""In-memory database doesn't use WAL mode (not applicable)."""
dm = DatabaseManager()
dm.initialize()
with dm.cursor() as cur:
cur.execute("PRAGMA journal_mode")
mode = cur.fetchone()[0]
# Memory databases use 'memory' journal mode
assert mode.lower() == "memory"
dm.close()
class TestAtomicWrite:
"""Tests for atomic write pattern."""
def test_atomic_write_creates_final_file(self, tmp_path: Path) -> None:
"""Database is created at final path after initialization."""
target = tmp_path / "atomic_test.db"
dm = DatabaseManager()
dm.initialize(db_path=str(target))
assert target.exists()
assert not Path(str(target) + ".tmp").exists()
dm.close()
def test_atomic_write_no_temp_file_remains(self, tmp_path: Path) -> None:
"""No .tmp file remains after successful initialization."""
target = tmp_path / "atomic_clean.db"
dm = DatabaseManager()
dm.initialize(db_path=str(target))
dm.close()
# Check no temp files remain
temp_files = list(tmp_path.glob("*.tmp"))
assert len(temp_files) == 0
class TestSchemaIntrospection:
"""Tests for database schema introspection."""
def test_get_schema_info_returns_all_tables(self) -> None:
"""Schema info includes all database tables."""
dm = DatabaseManager()
dm.initialize()
schema = dm.get_schema_info()
assert schema["version"] == "0.2.0"
table_names = [t["name"] for t in schema["tables"]]
assert "projects" in table_names
assert "activities" in table_names
assert "relationships" in table_names
assert "wbs" in table_names
assert "calendars" in table_names
dm.close()
def test_get_schema_info_includes_column_details(self) -> None:
"""Schema info includes column names, types, and nullable."""
dm = DatabaseManager()
dm.initialize()
schema = dm.get_schema_info()
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
column_names = [c["name"] for c in activities_table["columns"]]
assert "task_id" in column_names
assert "task_name" in column_names
# Check column details
task_id_col = next(c for c in activities_table["columns"] if c["name"] == "task_id")
assert task_id_col["type"] == "TEXT"
# Note: SQLite reports PRIMARY KEY TEXT columns as nullable
# but the PRIMARY KEY constraint still applies
assert "nullable" in task_id_col
# Check a NOT NULL column
task_name_col = next(c for c in activities_table["columns"] if c["name"] == "task_name")
assert task_name_col["nullable"] is False
dm.close()
def test_get_schema_info_includes_row_counts(self) -> None:
"""Schema info includes row counts for each table."""
dm = DatabaseManager()
dm.initialize()
schema = dm.get_schema_info()
for table in schema["tables"]:
assert "row_count" in table
assert isinstance(table["row_count"], int)
assert table["row_count"] >= 0
dm.close()
def test_schema_info_includes_primary_keys(self) -> None:
"""Schema info includes primary key for each table."""
dm = DatabaseManager()
dm.initialize()
schema = dm.get_schema_info()
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
assert "primary_key" in activities_table
assert "task_id" in activities_table["primary_key"]
dm.close()
def test_schema_info_includes_foreign_keys(self) -> None:
"""Schema info includes foreign key relationships."""
dm = DatabaseManager()
dm.initialize()
schema = dm.get_schema_info()
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
assert "foreign_keys" in activities_table
# activities.proj_id -> projects.proj_id
fk = next(
(fk for fk in activities_table["foreign_keys"] if fk["column"] == "proj_id"),
None,
)
assert fk is not None
assert fk["references_table"] == "projects"
assert fk["references_column"] == "proj_id"
dm.close()