From d6a79bf24a4eddd9cb2310fe3940d2c4cbacbe42 Mon Sep 17 00:00:00 2001 From: Bill Ballou Date: Thu, 8 Jan 2026 12:54:56 -0500 Subject: [PATCH] feat: add direct database access for scripts (v0.2.0) Implement persistent SQLite database feature that allows scripts to query schedule data directly via SQL after loading XER files through MCP. Key changes: - Extend load_xer with db_path parameter for persistent database - Add get_database_info tool to retrieve database connection details - Add schema introspection with tables, columns, primary/foreign keys - Support WAL mode for concurrent read access - Use atomic write pattern to prevent corruption New features: - db_path=None: in-memory database (default, backward compatible) - db_path="": auto-generate path from XER filename (.sqlite extension) - db_path="/path/to/db": explicit persistent database path Response includes complete DatabaseInfo: - db_path: absolute path (or :memory:) - is_persistent: boolean - source_file: loaded XER path - loaded_at: ISO timestamp - schema: tables with columns, primary keys, foreign keys, row counts Closes: User Story 1, 2, 3 from 002-direct-db-access spec --- pyproject.toml | 2 +- src/xer_mcp/db/__init__.py | 207 +++++++++++++++++- src/xer_mcp/db/queries.py | 58 ++--- src/xer_mcp/errors.py | 27 +++ src/xer_mcp/server.py | 22 ++ src/xer_mcp/tools/get_database_info.py | 25 +++ src/xer_mcp/tools/load_xer.py | 52 ++++- tests/contract/test_get_database_info.py | 143 +++++++++++++ tests/contract/test_load_xer.py | 149 ++++++++++++- tests/integration/test_direct_db_access.py | 185 ++++++++++++++++ tests/unit/test_db_manager.py | 234 +++++++++++++++++++++ 11 files changed, 1064 insertions(+), 40 deletions(-) create mode 100644 src/xer_mcp/tools/get_database_info.py create mode 100644 tests/contract/test_get_database_info.py create mode 100644 tests/integration/test_direct_db_access.py create mode 100644 tests/unit/test_db_manager.py diff --git a/pyproject.toml b/pyproject.toml index 91ccabb..1601c80 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "xer-mcp" -version = "0.1.0" +version = "0.2.0" description = "MCP server for querying Primavera P6 XER schedule data" readme = "README.md" requires-python = ">=3.14" diff --git a/src/xer_mcp/db/__init__.py b/src/xer_mcp/db/__init__.py index f911a0d..e2f7283 100644 --- a/src/xer_mcp/db/__init__.py +++ b/src/xer_mcp/db/__init__.py @@ -1,24 +1,66 @@ """Database connection management for XER MCP Server.""" +import os import sqlite3 from collections.abc import Generator from contextlib import contextmanager +from datetime import datetime +from pathlib import Path from xer_mcp.db.schema import get_schema +# Schema version for introspection responses +SCHEMA_VERSION = "0.2.0" + class DatabaseManager: """Manages SQLite database connections and schema initialization.""" def __init__(self) -> None: - """Initialize database manager with in-memory database.""" + """Initialize database manager.""" self._connection: sqlite3.Connection | None = None + self._db_path: str = ":memory:" + self._source_file: str | None = None + self._loaded_at: datetime | None = None - def initialize(self) -> None: - """Initialize the in-memory database with schema.""" - self._connection = sqlite3.connect(":memory:", check_same_thread=False) - self._connection.row_factory = sqlite3.Row - self._connection.executescript(get_schema()) + def initialize( + self, + db_path: str | None = None, + source_file: str | None = None, + ) -> None: + """Initialize the database with schema. + + Args: + db_path: Path for database file. If None or omitted, uses in-memory. + If empty string, auto-generates from source_file. + source_file: Path to the XER file being loaded (for tracking). + """ + self._source_file = source_file + self._loaded_at = datetime.now() + + # Determine database path + if db_path is None: + # Default: in-memory database + self._db_path = ":memory:" + elif db_path == "": + # Auto-generate from source file + if source_file: + base = Path(source_file).with_suffix(".sqlite") + self._db_path = str(base) + else: + self._db_path = ":memory:" + else: + # Use provided path + self._db_path = db_path + + # Create database + if self._db_path == ":memory:": + self._connection = sqlite3.connect(":memory:", check_same_thread=False) + self._connection.row_factory = sqlite3.Row + self._connection.executescript(get_schema()) + else: + # File-based database with atomic write pattern + self._create_file_database() def clear(self) -> None: """Clear all data from the database.""" @@ -61,6 +103,159 @@ class DatabaseManager: """Check if the database is initialized.""" return self._connection is not None + @property + def db_path(self) -> str: + """Get the database path.""" + return self._db_path + + @property + def is_persistent(self) -> bool: + """Check if the database is file-based (persistent).""" + return self._db_path != ":memory:" + + @property + def source_file(self) -> str | None: + """Get the source XER file path.""" + return self._source_file + + @property + def loaded_at(self) -> datetime | None: + """Get the timestamp when data was loaded.""" + return self._loaded_at + + def _create_file_database(self) -> None: + """Create a file-based database with atomic write pattern.""" + temp_path = self._db_path + ".tmp" + + try: + # Create database at temp path + conn = sqlite3.connect(temp_path, check_same_thread=False) + conn.row_factory = sqlite3.Row + conn.executescript(get_schema()) + conn.commit() + conn.close() + + # Atomic rename (POSIX) + if os.path.exists(self._db_path): + os.unlink(self._db_path) + os.rename(temp_path, self._db_path) + + # Open final database with WAL mode + self._connection = sqlite3.connect(self._db_path, check_same_thread=False) + self._connection.row_factory = sqlite3.Row + self._connection.execute("PRAGMA journal_mode=WAL") + except Exception: + # Clean up temp file on failure + if os.path.exists(temp_path): + os.unlink(temp_path) + raise + + def get_schema_info(self) -> dict: + """Get database schema information for introspection. + + Returns: + Dictionary with schema version and table information. + """ + if self._connection is None: + raise RuntimeError("Database not initialized. Call initialize() first.") + + tables = [] + + # Get all tables (excluding sqlite internal tables) + cursor = self._connection.execute( + "SELECT name FROM sqlite_master WHERE type='table' " + "AND name NOT LIKE 'sqlite_%' ORDER BY name" + ) + table_names = [row[0] for row in cursor.fetchall()] + + for table_name in table_names: + table_info = self._get_table_info(table_name) + tables.append(table_info) + + return { + "version": SCHEMA_VERSION, + "tables": tables, + } + + def _get_table_info(self, table_name: str) -> dict: + """Get detailed information about a table. + + Args: + table_name: Name of the table. + + Returns: + Dictionary with table name, columns, primary keys, foreign keys, row count. + """ + columns = [] + primary_key = [] + + # Get column info + cursor = self._connection.execute(f"PRAGMA table_info({table_name})") # noqa: S608 + for row in cursor.fetchall(): + col_name = row[1] + col_type = row[2] or "TEXT" + not_null = bool(row[3]) + default_val = row[4] + is_pk = bool(row[5]) + + col_info: dict = { + "name": col_name, + "type": col_type, + "nullable": not not_null, + } + if default_val is not None: + col_info["default"] = str(default_val) + + columns.append(col_info) + + if is_pk: + primary_key.append(col_name) + + # Get foreign keys + foreign_keys = [] + cursor = self._connection.execute( + f"PRAGMA foreign_key_list({table_name})" # noqa: S608 + ) + for row in cursor.fetchall(): + fk_info = { + "column": row[3], # from column + "references_table": row[2], # table + "references_column": row[4], # to column + } + foreign_keys.append(fk_info) + + # Get row count + cursor = self._connection.execute( + f"SELECT COUNT(*) FROM {table_name}" # noqa: S608 + ) + row_count = cursor.fetchone()[0] + + return { + "name": table_name, + "columns": columns, + "primary_key": primary_key, + "foreign_keys": foreign_keys, + "row_count": row_count, + } + + def get_database_info(self) -> dict: + """Get complete database information for API responses. + + Returns: + Dictionary with database path, persistence status, source file, + loaded timestamp, and schema information. + """ + if not self.is_initialized: + raise RuntimeError("Database not initialized. Call initialize() first.") + + return { + "db_path": self._db_path, + "is_persistent": self.is_persistent, + "source_file": self._source_file, + "loaded_at": self._loaded_at.isoformat() if self._loaded_at else None, + "schema": self.get_schema_info(), + } + # Global database manager instance db = DatabaseManager() diff --git a/src/xer_mcp/db/queries.py b/src/xer_mcp/db/queries.py index 4b47d32..dd56ed4 100644 --- a/src/xer_mcp/db/queries.py +++ b/src/xer_mcp/db/queries.py @@ -239,16 +239,18 @@ def query_relationships( lag_hours=row[6] or 0.0, pred_type=pred_type, ) - relationships.append({ - "task_pred_id": row[0], - "task_id": row[1], - "task_name": row[2], - "pred_task_id": row[3], - "pred_task_name": row[4], - "pred_type": pred_type, - "lag_hr_cnt": row[6], - "driving": driving, - }) + relationships.append( + { + "task_pred_id": row[0], + "task_id": row[1], + "task_name": row[2], + "pred_task_id": row[3], + "pred_task_name": row[4], + "pred_type": pred_type, + "lag_hr_cnt": row[6], + "driving": driving, + } + ) return relationships, total @@ -298,14 +300,16 @@ def get_predecessors(activity_id: str) -> list[dict]: lag_hours=row[4] or 0.0, pred_type=pred_type, ) - result.append({ - "task_id": row[0], - "task_code": row[1], - "task_name": row[2], - "relationship_type": pred_type, - "lag_hr_cnt": row[4], - "driving": driving, - }) + result.append( + { + "task_id": row[0], + "task_code": row[1], + "task_name": row[2], + "relationship_type": pred_type, + "lag_hr_cnt": row[4], + "driving": driving, + } + ) return result @@ -355,14 +359,16 @@ def get_successors(activity_id: str) -> list[dict]: lag_hours=row[4] or 0.0, pred_type=pred_type, ) - result.append({ - "task_id": row[0], - "task_code": row[1], - "task_name": row[2], - "relationship_type": pred_type, - "lag_hr_cnt": row[4], - "driving": driving, - }) + result.append( + { + "task_id": row[0], + "task_code": row[1], + "task_name": row[2], + "relationship_type": pred_type, + "lag_hr_cnt": row[4], + "driving": driving, + } + ) return result diff --git a/src/xer_mcp/errors.py b/src/xer_mcp/errors.py index a9af745..3c1fd8b 100644 --- a/src/xer_mcp/errors.py +++ b/src/xer_mcp/errors.py @@ -56,3 +56,30 @@ class ActivityNotFoundError(XerMcpError): "ACTIVITY_NOT_FOUND", f"Activity not found: {activity_id}", ) + + +class FileNotWritableError(XerMcpError): + """Raised when the database file path is not writable.""" + + def __init__(self, path: str, reason: str = "") -> None: + msg = f"Cannot write to database file: {path}" + if reason: + msg = f"{msg} ({reason})" + super().__init__("FILE_NOT_WRITABLE", msg) + + +class DiskFullError(XerMcpError): + """Raised when there is insufficient disk space.""" + + def __init__(self, path: str) -> None: + super().__init__( + "DISK_FULL", + f"Insufficient disk space to create database: {path}", + ) + + +class DatabaseError(XerMcpError): + """Raised for general database errors.""" + + def __init__(self, message: str) -> None: + super().__init__("DATABASE_ERROR", message) diff --git a/src/xer_mcp/server.py b/src/xer_mcp/server.py index 3d9b7b1..a4934c1 100644 --- a/src/xer_mcp/server.py +++ b/src/xer_mcp/server.py @@ -50,6 +50,12 @@ async def list_tools() -> list[Tool]: "type": "string", "description": "Project ID to select (required for multi-project files)", }, + "db_path": { + "type": "string", + "description": "Path for persistent SQLite database file. " + "If omitted, uses in-memory database. " + "If empty string, auto-generates path from XER filename (same directory, .sqlite extension).", + }, }, "required": ["file_path"], }, @@ -183,6 +189,15 @@ async def list_tools() -> list[Tool]: "properties": {}, }, ), + Tool( + name="get_database_info", + description="Get information about the currently loaded database including file path and schema. " + "Use this to get connection details for direct SQL access.", + inputSchema={ + "type": "object", + "properties": {}, + }, + ), ] @@ -197,6 +212,7 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: result = await load_xer( file_path=arguments["file_path"], project_id=arguments.get("project_id"), + db_path=arguments.get("db_path"), ) return [TextContent(type="text", text=json.dumps(result, indent=2))] @@ -258,6 +274,12 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]: result = await get_critical_path() return [TextContent(type="text", text=json.dumps(result, indent=2))] + if name == "get_database_info": + from xer_mcp.tools.get_database_info import get_database_info + + result = await get_database_info() + return [TextContent(type="text", text=json.dumps(result, indent=2))] + raise ValueError(f"Unknown tool: {name}") diff --git a/src/xer_mcp/tools/get_database_info.py b/src/xer_mcp/tools/get_database_info.py new file mode 100644 index 0000000..2b5db87 --- /dev/null +++ b/src/xer_mcp/tools/get_database_info.py @@ -0,0 +1,25 @@ +"""get_database_info MCP tool implementation.""" + +from xer_mcp.db import db + + +async def get_database_info() -> dict: + """Get information about the currently loaded database. + + Returns connection details for direct SQL access including + database path, schema information, and metadata. + + Returns: + Dictionary with database info or error if no database loaded + """ + if not db.is_initialized: + return { + "error": { + "code": "NO_FILE_LOADED", + "message": "No XER file is loaded. Use the load_xer tool first.", + } + } + + return { + "database": db.get_database_info(), + } diff --git a/src/xer_mcp/tools/load_xer.py b/src/xer_mcp/tools/load_xer.py index fe281d6..204d32d 100644 --- a/src/xer_mcp/tools/load_xer.py +++ b/src/xer_mcp/tools/load_xer.py @@ -1,5 +1,9 @@ """load_xer MCP tool implementation.""" +import errno +import os +import sqlite3 + from xer_mcp.db import db from xer_mcp.db.loader import get_activity_count, get_relationship_count, load_parsed_data from xer_mcp.errors import FileNotFoundError, ParseError @@ -7,19 +11,55 @@ from xer_mcp.parser.xer_parser import XerParser from xer_mcp.server import set_file_loaded -async def load_xer(file_path: str, project_id: str | None = None) -> dict: +async def load_xer( + file_path: str, + project_id: str | None = None, + db_path: str | None = None, +) -> dict: """Load a Primavera P6 XER file and parse its schedule data. Args: file_path: Absolute path to the XER file project_id: Project ID to select (required for multi-project files) + db_path: Path for persistent database file. If None, uses in-memory. + If empty string, auto-generates from XER filename. Returns: Dictionary with success status and project info or error details """ - # Ensure database is initialized - if not db.is_initialized: - db.initialize() + # Initialize database with specified path + try: + db.initialize(db_path=db_path, source_file=file_path) + except PermissionError: + target = db_path if db_path else file_path + return { + "success": False, + "error": {"code": "FILE_NOT_WRITABLE", "message": f"Cannot write database: {target}"}, + } + except OSError as e: + target = db_path if db_path else file_path + if e.errno == errno.ENOSPC: + return { + "success": False, + "error": {"code": "DISK_FULL", "message": f"Insufficient disk space: {target}"}, + } + if e.errno == errno.ENOENT: + return { + "success": False, + "error": { + "code": "FILE_NOT_WRITABLE", + "message": f"Directory does not exist: {os.path.dirname(target)}", + }, + } + return { + "success": False, + "error": {"code": "DATABASE_ERROR", "message": str(e)}, + } + except sqlite3.Error as e: + return { + "success": False, + "error": {"code": "DATABASE_ERROR", "message": str(e)}, + } parser = XerParser() @@ -73,6 +113,9 @@ async def load_xer(file_path: str, project_id: str | None = None) -> dict: activity_count = get_activity_count() relationship_count = get_relationship_count() + # Get database info + database_info = db.get_database_info() + return { "success": True, "project": { @@ -83,4 +126,5 @@ async def load_xer(file_path: str, project_id: str | None = None) -> dict: }, "activity_count": activity_count, "relationship_count": relationship_count, + "database": database_info, } diff --git a/tests/contract/test_get_database_info.py b/tests/contract/test_get_database_info.py new file mode 100644 index 0000000..dd4d5d1 --- /dev/null +++ b/tests/contract/test_get_database_info.py @@ -0,0 +1,143 @@ +"""Contract tests for get_database_info MCP tool.""" + +from pathlib import Path + +import pytest + +from xer_mcp.db import db + + +@pytest.fixture(autouse=True) +def setup_db(): + """Reset database state for each test.""" + if db.is_initialized: + db.close() + yield + if db.is_initialized: + db.close() + + +class TestGetDatabaseInfoContract: + """Contract tests verifying get_database_info tool interface.""" + + async def test_get_database_info_returns_current_database( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """get_database_info returns info about currently loaded database.""" + from xer_mcp.tools.get_database_info import get_database_info + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + result = await get_database_info() + + assert "database" in result + assert result["database"]["db_path"] == str(db_file) + assert result["database"]["is_persistent"] is True + + async def test_get_database_info_error_when_no_database(self) -> None: + """get_database_info returns error when no database loaded.""" + from xer_mcp.tools.get_database_info import get_database_info + + # Ensure database is not initialized + if db.is_initialized: + db.close() + + result = await get_database_info() + + assert "error" in result + assert result["error"]["code"] == "NO_FILE_LOADED" + + async def test_get_database_info_includes_schema( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """get_database_info includes schema information.""" + from xer_mcp.tools.get_database_info import get_database_info + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + result = await get_database_info() + + assert "schema" in result["database"] + assert "tables" in result["database"]["schema"] + + async def test_get_database_info_includes_loaded_at( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """get_database_info includes loaded_at timestamp.""" + from xer_mcp.tools.get_database_info import get_database_info + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + result = await get_database_info() + + assert "loaded_at" in result["database"] + # Should be ISO format timestamp + assert "T" in result["database"]["loaded_at"] + + async def test_get_database_info_includes_source_file( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """get_database_info includes source XER file path.""" + from xer_mcp.tools.get_database_info import get_database_info + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + result = await get_database_info() + + assert result["database"]["source_file"] == str(sample_xer_single_project) + + async def test_get_database_info_for_memory_database( + self, sample_xer_single_project: Path + ) -> None: + """get_database_info works for in-memory database.""" + from xer_mcp.tools.get_database_info import get_database_info + from xer_mcp.tools.load_xer import load_xer + + await load_xer(file_path=str(sample_xer_single_project)) + + result = await get_database_info() + + assert "database" in result + assert result["database"]["db_path"] == ":memory:" + assert result["database"]["is_persistent"] is False + + +class TestGetDatabaseInfoToolSchema: + """Tests for MCP tool schema.""" + + async def test_get_database_info_tool_registered(self) -> None: + """get_database_info tool is registered with MCP server.""" + from xer_mcp.server import list_tools + + tools = await list_tools() + tool_names = [t.name for t in tools] + assert "get_database_info" in tool_names + + async def test_get_database_info_tool_has_empty_input_schema(self) -> None: + """get_database_info tool has no required inputs.""" + from xer_mcp.server import list_tools + + tools = await list_tools() + tool = next(t for t in tools if t.name == "get_database_info") + # Should have empty or no required properties + assert "required" not in tool.inputSchema or len(tool.inputSchema.get("required", [])) == 0 diff --git a/tests/contract/test_load_xer.py b/tests/contract/test_load_xer.py index 6f1eafa..3514d4d 100644 --- a/tests/contract/test_load_xer.py +++ b/tests/contract/test_load_xer.py @@ -1,5 +1,6 @@ """Contract tests for load_xer MCP tool.""" +import sqlite3 from pathlib import Path import pytest @@ -9,10 +10,14 @@ from xer_mcp.db import db @pytest.fixture(autouse=True) def setup_db(): - """Initialize and clear database for each test.""" - db.initialize() + """Reset database state for each test.""" + # Close any existing connection + if db.is_initialized: + db.close() yield - db.clear() + # Cleanup after test + if db.is_initialized: + db.close() class TestLoadXerContract: @@ -105,3 +110,141 @@ class TestLoadXerContract: assert "plan_end_date" in result["project"] # Dates should be ISO8601 format assert "T" in result["project"]["plan_start_date"] + + +class TestLoadXerPersistentDatabase: + """Contract tests for persistent database functionality.""" + + async def test_load_xer_with_db_path_creates_file( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """load_xer with db_path creates persistent database file.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + assert result["success"] is True + assert db_file.exists() + assert result["database"]["db_path"] == str(db_file) + assert result["database"]["is_persistent"] is True + + async def test_load_xer_with_empty_db_path_auto_generates(self, tmp_path: Path) -> None: + """load_xer with empty db_path generates path from XER filename.""" + from xer_mcp.tools.load_xer import load_xer + + # Create XER file in tmp_path + xer_file = tmp_path / "my_schedule.xer" + from tests.conftest import SAMPLE_XER_SINGLE_PROJECT + + xer_file.write_text(SAMPLE_XER_SINGLE_PROJECT) + + result = await load_xer(file_path=str(xer_file), db_path="") + + assert result["success"] is True + expected_db = str(tmp_path / "my_schedule.sqlite") + assert result["database"]["db_path"] == expected_db + assert result["database"]["is_persistent"] is True + assert Path(expected_db).exists() + + async def test_load_xer_without_db_path_uses_memory( + self, sample_xer_single_project: Path + ) -> None: + """load_xer without db_path uses in-memory database (backward compatible).""" + from xer_mcp.tools.load_xer import load_xer + + result = await load_xer(file_path=str(sample_xer_single_project)) + + assert result["success"] is True + assert result["database"]["db_path"] == ":memory:" + assert result["database"]["is_persistent"] is False + + async def test_load_xer_database_contains_all_data( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """Persistent database contains all parsed data.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + # Verify data via direct SQL + conn = sqlite3.connect(str(db_file)) + cursor = conn.execute("SELECT COUNT(*) FROM activities") + count = cursor.fetchone()[0] + conn.close() + + assert count == result["activity_count"] + + async def test_load_xer_response_includes_database_info( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """load_xer response includes complete database info.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + assert "database" in result + db_info = result["database"] + assert "db_path" in db_info + assert "is_persistent" in db_info + assert "source_file" in db_info + assert "loaded_at" in db_info + assert "schema" in db_info + + async def test_load_xer_response_schema_includes_tables( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """load_xer response schema includes table information.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + schema = result["database"]["schema"] + assert "version" in schema + assert "tables" in schema + table_names = [t["name"] for t in schema["tables"]] + assert "activities" in table_names + assert "relationships" in table_names + + async def test_load_xer_error_on_invalid_path(self, sample_xer_single_project: Path) -> None: + """load_xer returns error for invalid path.""" + from xer_mcp.tools.load_xer import load_xer + + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path="/nonexistent/dir/file.db", + ) + + assert result["success"] is False + # Either FILE_NOT_WRITABLE or DATABASE_ERROR is acceptable + # depending on how SQLite reports the error + assert result["error"]["code"] in ("FILE_NOT_WRITABLE", "DATABASE_ERROR") + + +class TestLoadXerToolSchema: + """Tests for MCP tool schema.""" + + async def test_load_xer_tool_schema_includes_db_path(self) -> None: + """MCP tool schema includes db_path parameter.""" + from xer_mcp.server import list_tools + + tools = await list_tools() + load_xer_tool = next(t for t in tools if t.name == "load_xer") + props = load_xer_tool.inputSchema["properties"] + assert "db_path" in props + assert props["db_path"]["type"] == "string" diff --git a/tests/integration/test_direct_db_access.py b/tests/integration/test_direct_db_access.py new file mode 100644 index 0000000..298af0c --- /dev/null +++ b/tests/integration/test_direct_db_access.py @@ -0,0 +1,185 @@ +"""Integration tests for direct database access feature.""" + +import sqlite3 +from pathlib import Path + +import pytest + +from xer_mcp.db import db + + +@pytest.fixture(autouse=True) +def setup_db(): + """Reset database state for each test.""" + if db.is_initialized: + db.close() + yield + if db.is_initialized: + db.close() + + +class TestDirectDatabaseAccess: + """Integration tests verifying external script can access database.""" + + async def test_external_script_can_query_database( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """External script can query database using returned path.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + # Simulate external script access (as shown in quickstart.md) + db_path = result["database"]["db_path"] + + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + + # Query milestones + cursor = conn.execute(""" + SELECT task_code, task_name, target_start_date, milestone_type + FROM activities + WHERE task_type IN ('TT_Mile', 'TT_FinMile') + ORDER BY target_start_date + """) + + milestones = cursor.fetchall() + conn.close() + + assert len(milestones) > 0 + assert all(row["task_code"] for row in milestones) + + async def test_external_script_can_query_critical_path( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """External script can query critical path activities.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + db_path = result["database"]["db_path"] + + conn = sqlite3.connect(db_path) + cursor = conn.execute(""" + SELECT task_code, task_name, target_start_date, target_end_date + FROM activities + WHERE driving_path_flag = 1 + ORDER BY target_start_date + """) + + critical_activities = cursor.fetchall() + conn.close() + + assert len(critical_activities) > 0 + + async def test_external_script_can_join_tables( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """External script can join activities with WBS.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + db_path = result["database"]["db_path"] + + conn = sqlite3.connect(db_path) + cursor = conn.execute(""" + SELECT a.task_code, a.task_name, w.wbs_name + FROM activities a + JOIN wbs w ON a.wbs_id = w.wbs_id + LIMIT 10 + """) + + joined_rows = cursor.fetchall() + conn.close() + + assert len(joined_rows) > 0 + + async def test_database_accessible_after_mcp_load( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """Database remains accessible while MCP tools are active.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + loaded_count = result["activity_count"] + + # External script queries database + conn = sqlite3.connect(str(db_file)) + cursor = conn.execute("SELECT COUNT(*) FROM activities") + external_count = cursor.fetchone()[0] + conn.close() + + # Both should match + assert external_count == loaded_count + + async def test_schema_info_matches_actual_database( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """Returned schema info matches actual database structure.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + schema = result["database"]["schema"] + db_path = result["database"]["db_path"] + + # Verify tables exist in actual database + conn = sqlite3.connect(db_path) + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'" + ) + actual_tables = {row[0] for row in cursor.fetchall()} + conn.close() + + schema_tables = {t["name"] for t in schema["tables"]} + assert schema_tables == actual_tables + + async def test_row_counts_match_actual_data( + self, tmp_path: Path, sample_xer_single_project: Path + ) -> None: + """Schema row counts match actual database row counts.""" + from xer_mcp.tools.load_xer import load_xer + + db_file = tmp_path / "schedule.db" + result = await load_xer( + file_path=str(sample_xer_single_project), + db_path=str(db_file), + ) + + schema = result["database"]["schema"] + db_path = result["database"]["db_path"] + + conn = sqlite3.connect(db_path) + + for table_info in schema["tables"]: + cursor = conn.execute( + f"SELECT COUNT(*) FROM {table_info['name']}" # noqa: S608 + ) + actual_count = cursor.fetchone()[0] + assert table_info["row_count"] == actual_count, ( + f"Table {table_info['name']}: expected {table_info['row_count']}, " + f"got {actual_count}" + ) + + conn.close() diff --git a/tests/unit/test_db_manager.py b/tests/unit/test_db_manager.py new file mode 100644 index 0000000..140bf25 --- /dev/null +++ b/tests/unit/test_db_manager.py @@ -0,0 +1,234 @@ +"""Unit tests for DatabaseManager file-based database support.""" + +import sqlite3 +from datetime import datetime +from pathlib import Path + +from xer_mcp.db import DatabaseManager + + +class TestDatabaseManagerInitialization: + """Tests for DatabaseManager initialization modes.""" + + def test_initialize_with_memory_by_default(self) -> None: + """Default initialization uses in-memory database.""" + dm = DatabaseManager() + dm.initialize() + assert dm.db_path == ":memory:" + assert dm.is_persistent is False + dm.close() + + def test_initialize_with_file_path(self, tmp_path: Path) -> None: + """Can initialize with explicit file path.""" + db_file = tmp_path / "test.db" + dm = DatabaseManager() + dm.initialize(db_path=str(db_file)) + assert dm.db_path == str(db_file) + assert dm.is_persistent is True + assert db_file.exists() + dm.close() + + def test_initialize_with_empty_string_auto_generates_path(self, tmp_path: Path) -> None: + """Empty string db_path with source_file auto-generates path.""" + xer_file = tmp_path / "schedule.xer" + xer_file.write_text("dummy content") + + dm = DatabaseManager() + dm.initialize(db_path="", source_file=str(xer_file)) + expected_db = str(tmp_path / "schedule.sqlite") + assert dm.db_path == expected_db + assert dm.is_persistent is True + assert Path(expected_db).exists() + dm.close() + + def test_file_database_persists_after_close(self, tmp_path: Path) -> None: + """File-based database persists after connection close.""" + db_file = tmp_path / "persist_test.db" + dm = DatabaseManager() + dm.initialize(db_path=str(db_file)) + + # Insert test data + with dm.cursor() as cur: + cur.execute( + "INSERT INTO projects (proj_id, proj_short_name, loaded_at) " + "VALUES ('P1', 'Test', datetime('now'))" + ) + dm.commit() + dm.close() + + # Verify file exists and has data + assert db_file.exists() + conn = sqlite3.connect(str(db_file)) + cursor = conn.execute("SELECT proj_id FROM projects") + rows = cursor.fetchall() + conn.close() + assert len(rows) == 1 + assert rows[0][0] == "P1" + + def test_source_file_tracked(self, tmp_path: Path) -> None: + """Source file path is tracked when provided.""" + db_file = tmp_path / "test.db" + xer_file = tmp_path / "schedule.xer" + xer_file.write_text("dummy") + + dm = DatabaseManager() + dm.initialize(db_path=str(db_file), source_file=str(xer_file)) + assert dm.source_file == str(xer_file) + dm.close() + + def test_loaded_at_timestamp(self, tmp_path: Path) -> None: + """Loaded_at timestamp is recorded.""" + db_file = tmp_path / "test.db" + dm = DatabaseManager() + before = datetime.now() + dm.initialize(db_path=str(db_file)) + after = datetime.now() + + loaded_at = dm.loaded_at + assert loaded_at is not None + assert before <= loaded_at <= after + dm.close() + + def test_memory_database_not_persistent(self) -> None: + """In-memory database is not persistent.""" + dm = DatabaseManager() + dm.initialize() + assert dm.is_persistent is False + assert dm.db_path == ":memory:" + dm.close() + + +class TestDatabaseManagerWalMode: + """Tests for WAL mode in file-based databases.""" + + def test_file_database_uses_wal_mode(self, tmp_path: Path) -> None: + """File-based database uses WAL mode for concurrent access.""" + db_file = tmp_path / "wal_test.db" + dm = DatabaseManager() + dm.initialize(db_path=str(db_file)) + + with dm.cursor() as cur: + cur.execute("PRAGMA journal_mode") + mode = cur.fetchone()[0] + assert mode.lower() == "wal" + dm.close() + + def test_memory_database_does_not_use_wal(self) -> None: + """In-memory database doesn't use WAL mode (not applicable).""" + dm = DatabaseManager() + dm.initialize() + + with dm.cursor() as cur: + cur.execute("PRAGMA journal_mode") + mode = cur.fetchone()[0] + # Memory databases use 'memory' journal mode + assert mode.lower() == "memory" + dm.close() + + +class TestAtomicWrite: + """Tests for atomic write pattern.""" + + def test_atomic_write_creates_final_file(self, tmp_path: Path) -> None: + """Database is created at final path after initialization.""" + target = tmp_path / "atomic_test.db" + dm = DatabaseManager() + dm.initialize(db_path=str(target)) + assert target.exists() + assert not Path(str(target) + ".tmp").exists() + dm.close() + + def test_atomic_write_no_temp_file_remains(self, tmp_path: Path) -> None: + """No .tmp file remains after successful initialization.""" + target = tmp_path / "atomic_clean.db" + dm = DatabaseManager() + dm.initialize(db_path=str(target)) + dm.close() + + # Check no temp files remain + temp_files = list(tmp_path.glob("*.tmp")) + assert len(temp_files) == 0 + + +class TestSchemaIntrospection: + """Tests for database schema introspection.""" + + def test_get_schema_info_returns_all_tables(self) -> None: + """Schema info includes all database tables.""" + dm = DatabaseManager() + dm.initialize() + schema = dm.get_schema_info() + + assert schema["version"] == "0.2.0" + table_names = [t["name"] for t in schema["tables"]] + assert "projects" in table_names + assert "activities" in table_names + assert "relationships" in table_names + assert "wbs" in table_names + assert "calendars" in table_names + dm.close() + + def test_get_schema_info_includes_column_details(self) -> None: + """Schema info includes column names, types, and nullable.""" + dm = DatabaseManager() + dm.initialize() + schema = dm.get_schema_info() + + activities_table = next(t for t in schema["tables"] if t["name"] == "activities") + column_names = [c["name"] for c in activities_table["columns"]] + assert "task_id" in column_names + assert "task_name" in column_names + + # Check column details + task_id_col = next(c for c in activities_table["columns"] if c["name"] == "task_id") + assert task_id_col["type"] == "TEXT" + # Note: SQLite reports PRIMARY KEY TEXT columns as nullable + # but the PRIMARY KEY constraint still applies + assert "nullable" in task_id_col + + # Check a NOT NULL column + task_name_col = next(c for c in activities_table["columns"] if c["name"] == "task_name") + assert task_name_col["nullable"] is False + dm.close() + + def test_get_schema_info_includes_row_counts(self) -> None: + """Schema info includes row counts for each table.""" + dm = DatabaseManager() + dm.initialize() + schema = dm.get_schema_info() + + for table in schema["tables"]: + assert "row_count" in table + assert isinstance(table["row_count"], int) + assert table["row_count"] >= 0 + dm.close() + + def test_schema_info_includes_primary_keys(self) -> None: + """Schema info includes primary key for each table.""" + dm = DatabaseManager() + dm.initialize() + schema = dm.get_schema_info() + + activities_table = next(t for t in schema["tables"] if t["name"] == "activities") + assert "primary_key" in activities_table + assert "task_id" in activities_table["primary_key"] + dm.close() + + def test_schema_info_includes_foreign_keys(self) -> None: + """Schema info includes foreign key relationships.""" + dm = DatabaseManager() + dm.initialize() + schema = dm.get_schema_info() + + activities_table = next(t for t in schema["tables"] if t["name"] == "activities") + assert "foreign_keys" in activities_table + + # activities.proj_id -> projects.proj_id + fk = next( + (fk for fk in activities_table["foreign_keys"] if fk["column"] == "proj_id"), + None, + ) + assert fk is not None + assert fk["references_table"] == "projects" + assert fk["references_column"] == "proj_id" + dm.close()