feat: add direct database access for scripts (v0.2.0)
Implement persistent SQLite database feature that allows scripts to query schedule data directly via SQL after loading XER files through MCP. Key changes: - Extend load_xer with db_path parameter for persistent database - Add get_database_info tool to retrieve database connection details - Add schema introspection with tables, columns, primary/foreign keys - Support WAL mode for concurrent read access - Use atomic write pattern to prevent corruption New features: - db_path=None: in-memory database (default, backward compatible) - db_path="": auto-generate path from XER filename (.sqlite extension) - db_path="/path/to/db": explicit persistent database path Response includes complete DatabaseInfo: - db_path: absolute path (or :memory:) - is_persistent: boolean - source_file: loaded XER path - loaded_at: ISO timestamp - schema: tables with columns, primary keys, foreign keys, row counts Closes: User Story 1, 2, 3 from 002-direct-db-access spec
This commit is contained in:
@@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "xer-mcp"
|
name = "xer-mcp"
|
||||||
version = "0.1.0"
|
version = "0.2.0"
|
||||||
description = "MCP server for querying Primavera P6 XER schedule data"
|
description = "MCP server for querying Primavera P6 XER schedule data"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.14"
|
requires-python = ">=3.14"
|
||||||
|
|||||||
@@ -1,24 +1,66 @@
|
|||||||
"""Database connection management for XER MCP Server."""
|
"""Database connection management for XER MCP Server."""
|
||||||
|
|
||||||
|
import os
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from xer_mcp.db.schema import get_schema
|
from xer_mcp.db.schema import get_schema
|
||||||
|
|
||||||
|
# Schema version for introspection responses
|
||||||
|
SCHEMA_VERSION = "0.2.0"
|
||||||
|
|
||||||
|
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
"""Manages SQLite database connections and schema initialization."""
|
"""Manages SQLite database connections and schema initialization."""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize database manager with in-memory database."""
|
"""Initialize database manager."""
|
||||||
self._connection: sqlite3.Connection | None = None
|
self._connection: sqlite3.Connection | None = None
|
||||||
|
self._db_path: str = ":memory:"
|
||||||
|
self._source_file: str | None = None
|
||||||
|
self._loaded_at: datetime | None = None
|
||||||
|
|
||||||
def initialize(self) -> None:
|
def initialize(
|
||||||
"""Initialize the in-memory database with schema."""
|
self,
|
||||||
self._connection = sqlite3.connect(":memory:", check_same_thread=False)
|
db_path: str | None = None,
|
||||||
self._connection.row_factory = sqlite3.Row
|
source_file: str | None = None,
|
||||||
self._connection.executescript(get_schema())
|
) -> None:
|
||||||
|
"""Initialize the database with schema.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: Path for database file. If None or omitted, uses in-memory.
|
||||||
|
If empty string, auto-generates from source_file.
|
||||||
|
source_file: Path to the XER file being loaded (for tracking).
|
||||||
|
"""
|
||||||
|
self._source_file = source_file
|
||||||
|
self._loaded_at = datetime.now()
|
||||||
|
|
||||||
|
# Determine database path
|
||||||
|
if db_path is None:
|
||||||
|
# Default: in-memory database
|
||||||
|
self._db_path = ":memory:"
|
||||||
|
elif db_path == "":
|
||||||
|
# Auto-generate from source file
|
||||||
|
if source_file:
|
||||||
|
base = Path(source_file).with_suffix(".sqlite")
|
||||||
|
self._db_path = str(base)
|
||||||
|
else:
|
||||||
|
self._db_path = ":memory:"
|
||||||
|
else:
|
||||||
|
# Use provided path
|
||||||
|
self._db_path = db_path
|
||||||
|
|
||||||
|
# Create database
|
||||||
|
if self._db_path == ":memory:":
|
||||||
|
self._connection = sqlite3.connect(":memory:", check_same_thread=False)
|
||||||
|
self._connection.row_factory = sqlite3.Row
|
||||||
|
self._connection.executescript(get_schema())
|
||||||
|
else:
|
||||||
|
# File-based database with atomic write pattern
|
||||||
|
self._create_file_database()
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear all data from the database."""
|
"""Clear all data from the database."""
|
||||||
@@ -61,6 +103,159 @@ class DatabaseManager:
|
|||||||
"""Check if the database is initialized."""
|
"""Check if the database is initialized."""
|
||||||
return self._connection is not None
|
return self._connection is not None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db_path(self) -> str:
|
||||||
|
"""Get the database path."""
|
||||||
|
return self._db_path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_persistent(self) -> bool:
|
||||||
|
"""Check if the database is file-based (persistent)."""
|
||||||
|
return self._db_path != ":memory:"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def source_file(self) -> str | None:
|
||||||
|
"""Get the source XER file path."""
|
||||||
|
return self._source_file
|
||||||
|
|
||||||
|
@property
|
||||||
|
def loaded_at(self) -> datetime | None:
|
||||||
|
"""Get the timestamp when data was loaded."""
|
||||||
|
return self._loaded_at
|
||||||
|
|
||||||
|
def _create_file_database(self) -> None:
|
||||||
|
"""Create a file-based database with atomic write pattern."""
|
||||||
|
temp_path = self._db_path + ".tmp"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create database at temp path
|
||||||
|
conn = sqlite3.connect(temp_path, check_same_thread=False)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
conn.executescript(get_schema())
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# Atomic rename (POSIX)
|
||||||
|
if os.path.exists(self._db_path):
|
||||||
|
os.unlink(self._db_path)
|
||||||
|
os.rename(temp_path, self._db_path)
|
||||||
|
|
||||||
|
# Open final database with WAL mode
|
||||||
|
self._connection = sqlite3.connect(self._db_path, check_same_thread=False)
|
||||||
|
self._connection.row_factory = sqlite3.Row
|
||||||
|
self._connection.execute("PRAGMA journal_mode=WAL")
|
||||||
|
except Exception:
|
||||||
|
# Clean up temp file on failure
|
||||||
|
if os.path.exists(temp_path):
|
||||||
|
os.unlink(temp_path)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def get_schema_info(self) -> dict:
|
||||||
|
"""Get database schema information for introspection.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with schema version and table information.
|
||||||
|
"""
|
||||||
|
if self._connection is None:
|
||||||
|
raise RuntimeError("Database not initialized. Call initialize() first.")
|
||||||
|
|
||||||
|
tables = []
|
||||||
|
|
||||||
|
# Get all tables (excluding sqlite internal tables)
|
||||||
|
cursor = self._connection.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' "
|
||||||
|
"AND name NOT LIKE 'sqlite_%' ORDER BY name"
|
||||||
|
)
|
||||||
|
table_names = [row[0] for row in cursor.fetchall()]
|
||||||
|
|
||||||
|
for table_name in table_names:
|
||||||
|
table_info = self._get_table_info(table_name)
|
||||||
|
tables.append(table_info)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"version": SCHEMA_VERSION,
|
||||||
|
"tables": tables,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_table_info(self, table_name: str) -> dict:
|
||||||
|
"""Get detailed information about a table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: Name of the table.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with table name, columns, primary keys, foreign keys, row count.
|
||||||
|
"""
|
||||||
|
columns = []
|
||||||
|
primary_key = []
|
||||||
|
|
||||||
|
# Get column info
|
||||||
|
cursor = self._connection.execute(f"PRAGMA table_info({table_name})") # noqa: S608
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
col_name = row[1]
|
||||||
|
col_type = row[2] or "TEXT"
|
||||||
|
not_null = bool(row[3])
|
||||||
|
default_val = row[4]
|
||||||
|
is_pk = bool(row[5])
|
||||||
|
|
||||||
|
col_info: dict = {
|
||||||
|
"name": col_name,
|
||||||
|
"type": col_type,
|
||||||
|
"nullable": not not_null,
|
||||||
|
}
|
||||||
|
if default_val is not None:
|
||||||
|
col_info["default"] = str(default_val)
|
||||||
|
|
||||||
|
columns.append(col_info)
|
||||||
|
|
||||||
|
if is_pk:
|
||||||
|
primary_key.append(col_name)
|
||||||
|
|
||||||
|
# Get foreign keys
|
||||||
|
foreign_keys = []
|
||||||
|
cursor = self._connection.execute(
|
||||||
|
f"PRAGMA foreign_key_list({table_name})" # noqa: S608
|
||||||
|
)
|
||||||
|
for row in cursor.fetchall():
|
||||||
|
fk_info = {
|
||||||
|
"column": row[3], # from column
|
||||||
|
"references_table": row[2], # table
|
||||||
|
"references_column": row[4], # to column
|
||||||
|
}
|
||||||
|
foreign_keys.append(fk_info)
|
||||||
|
|
||||||
|
# Get row count
|
||||||
|
cursor = self._connection.execute(
|
||||||
|
f"SELECT COUNT(*) FROM {table_name}" # noqa: S608
|
||||||
|
)
|
||||||
|
row_count = cursor.fetchone()[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"name": table_name,
|
||||||
|
"columns": columns,
|
||||||
|
"primary_key": primary_key,
|
||||||
|
"foreign_keys": foreign_keys,
|
||||||
|
"row_count": row_count,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_database_info(self) -> dict:
|
||||||
|
"""Get complete database information for API responses.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with database path, persistence status, source file,
|
||||||
|
loaded timestamp, and schema information.
|
||||||
|
"""
|
||||||
|
if not self.is_initialized:
|
||||||
|
raise RuntimeError("Database not initialized. Call initialize() first.")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"db_path": self._db_path,
|
||||||
|
"is_persistent": self.is_persistent,
|
||||||
|
"source_file": self._source_file,
|
||||||
|
"loaded_at": self._loaded_at.isoformat() if self._loaded_at else None,
|
||||||
|
"schema": self.get_schema_info(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Global database manager instance
|
# Global database manager instance
|
||||||
db = DatabaseManager()
|
db = DatabaseManager()
|
||||||
|
|||||||
@@ -239,16 +239,18 @@ def query_relationships(
|
|||||||
lag_hours=row[6] or 0.0,
|
lag_hours=row[6] or 0.0,
|
||||||
pred_type=pred_type,
|
pred_type=pred_type,
|
||||||
)
|
)
|
||||||
relationships.append({
|
relationships.append(
|
||||||
"task_pred_id": row[0],
|
{
|
||||||
"task_id": row[1],
|
"task_pred_id": row[0],
|
||||||
"task_name": row[2],
|
"task_id": row[1],
|
||||||
"pred_task_id": row[3],
|
"task_name": row[2],
|
||||||
"pred_task_name": row[4],
|
"pred_task_id": row[3],
|
||||||
"pred_type": pred_type,
|
"pred_task_name": row[4],
|
||||||
"lag_hr_cnt": row[6],
|
"pred_type": pred_type,
|
||||||
"driving": driving,
|
"lag_hr_cnt": row[6],
|
||||||
})
|
"driving": driving,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return relationships, total
|
return relationships, total
|
||||||
|
|
||||||
@@ -298,14 +300,16 @@ def get_predecessors(activity_id: str) -> list[dict]:
|
|||||||
lag_hours=row[4] or 0.0,
|
lag_hours=row[4] or 0.0,
|
||||||
pred_type=pred_type,
|
pred_type=pred_type,
|
||||||
)
|
)
|
||||||
result.append({
|
result.append(
|
||||||
"task_id": row[0],
|
{
|
||||||
"task_code": row[1],
|
"task_id": row[0],
|
||||||
"task_name": row[2],
|
"task_code": row[1],
|
||||||
"relationship_type": pred_type,
|
"task_name": row[2],
|
||||||
"lag_hr_cnt": row[4],
|
"relationship_type": pred_type,
|
||||||
"driving": driving,
|
"lag_hr_cnt": row[4],
|
||||||
})
|
"driving": driving,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -355,14 +359,16 @@ def get_successors(activity_id: str) -> list[dict]:
|
|||||||
lag_hours=row[4] or 0.0,
|
lag_hours=row[4] or 0.0,
|
||||||
pred_type=pred_type,
|
pred_type=pred_type,
|
||||||
)
|
)
|
||||||
result.append({
|
result.append(
|
||||||
"task_id": row[0],
|
{
|
||||||
"task_code": row[1],
|
"task_id": row[0],
|
||||||
"task_name": row[2],
|
"task_code": row[1],
|
||||||
"relationship_type": pred_type,
|
"task_name": row[2],
|
||||||
"lag_hr_cnt": row[4],
|
"relationship_type": pred_type,
|
||||||
"driving": driving,
|
"lag_hr_cnt": row[4],
|
||||||
})
|
"driving": driving,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -56,3 +56,30 @@ class ActivityNotFoundError(XerMcpError):
|
|||||||
"ACTIVITY_NOT_FOUND",
|
"ACTIVITY_NOT_FOUND",
|
||||||
f"Activity not found: {activity_id}",
|
f"Activity not found: {activity_id}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FileNotWritableError(XerMcpError):
|
||||||
|
"""Raised when the database file path is not writable."""
|
||||||
|
|
||||||
|
def __init__(self, path: str, reason: str = "") -> None:
|
||||||
|
msg = f"Cannot write to database file: {path}"
|
||||||
|
if reason:
|
||||||
|
msg = f"{msg} ({reason})"
|
||||||
|
super().__init__("FILE_NOT_WRITABLE", msg)
|
||||||
|
|
||||||
|
|
||||||
|
class DiskFullError(XerMcpError):
|
||||||
|
"""Raised when there is insufficient disk space."""
|
||||||
|
|
||||||
|
def __init__(self, path: str) -> None:
|
||||||
|
super().__init__(
|
||||||
|
"DISK_FULL",
|
||||||
|
f"Insufficient disk space to create database: {path}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseError(XerMcpError):
|
||||||
|
"""Raised for general database errors."""
|
||||||
|
|
||||||
|
def __init__(self, message: str) -> None:
|
||||||
|
super().__init__("DATABASE_ERROR", message)
|
||||||
|
|||||||
@@ -50,6 +50,12 @@ async def list_tools() -> list[Tool]:
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "Project ID to select (required for multi-project files)",
|
"description": "Project ID to select (required for multi-project files)",
|
||||||
},
|
},
|
||||||
|
"db_path": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Path for persistent SQLite database file. "
|
||||||
|
"If omitted, uses in-memory database. "
|
||||||
|
"If empty string, auto-generates path from XER filename (same directory, .sqlite extension).",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"required": ["file_path"],
|
"required": ["file_path"],
|
||||||
},
|
},
|
||||||
@@ -183,6 +189,15 @@ async def list_tools() -> list[Tool]:
|
|||||||
"properties": {},
|
"properties": {},
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
|
Tool(
|
||||||
|
name="get_database_info",
|
||||||
|
description="Get information about the currently loaded database including file path and schema. "
|
||||||
|
"Use this to get connection details for direct SQL access.",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {},
|
||||||
|
},
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -197,6 +212,7 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
|||||||
result = await load_xer(
|
result = await load_xer(
|
||||||
file_path=arguments["file_path"],
|
file_path=arguments["file_path"],
|
||||||
project_id=arguments.get("project_id"),
|
project_id=arguments.get("project_id"),
|
||||||
|
db_path=arguments.get("db_path"),
|
||||||
)
|
)
|
||||||
return [TextContent(type="text", text=json.dumps(result, indent=2))]
|
return [TextContent(type="text", text=json.dumps(result, indent=2))]
|
||||||
|
|
||||||
@@ -258,6 +274,12 @@ async def call_tool(name: str, arguments: dict) -> list[TextContent]:
|
|||||||
result = await get_critical_path()
|
result = await get_critical_path()
|
||||||
return [TextContent(type="text", text=json.dumps(result, indent=2))]
|
return [TextContent(type="text", text=json.dumps(result, indent=2))]
|
||||||
|
|
||||||
|
if name == "get_database_info":
|
||||||
|
from xer_mcp.tools.get_database_info import get_database_info
|
||||||
|
|
||||||
|
result = await get_database_info()
|
||||||
|
return [TextContent(type="text", text=json.dumps(result, indent=2))]
|
||||||
|
|
||||||
raise ValueError(f"Unknown tool: {name}")
|
raise ValueError(f"Unknown tool: {name}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
25
src/xer_mcp/tools/get_database_info.py
Normal file
25
src/xer_mcp/tools/get_database_info.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""get_database_info MCP tool implementation."""
|
||||||
|
|
||||||
|
from xer_mcp.db import db
|
||||||
|
|
||||||
|
|
||||||
|
async def get_database_info() -> dict:
|
||||||
|
"""Get information about the currently loaded database.
|
||||||
|
|
||||||
|
Returns connection details for direct SQL access including
|
||||||
|
database path, schema information, and metadata.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with database info or error if no database loaded
|
||||||
|
"""
|
||||||
|
if not db.is_initialized:
|
||||||
|
return {
|
||||||
|
"error": {
|
||||||
|
"code": "NO_FILE_LOADED",
|
||||||
|
"message": "No XER file is loaded. Use the load_xer tool first.",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"database": db.get_database_info(),
|
||||||
|
}
|
||||||
@@ -1,5 +1,9 @@
|
|||||||
"""load_xer MCP tool implementation."""
|
"""load_xer MCP tool implementation."""
|
||||||
|
|
||||||
|
import errno
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
from xer_mcp.db import db
|
from xer_mcp.db import db
|
||||||
from xer_mcp.db.loader import get_activity_count, get_relationship_count, load_parsed_data
|
from xer_mcp.db.loader import get_activity_count, get_relationship_count, load_parsed_data
|
||||||
from xer_mcp.errors import FileNotFoundError, ParseError
|
from xer_mcp.errors import FileNotFoundError, ParseError
|
||||||
@@ -7,19 +11,55 @@ from xer_mcp.parser.xer_parser import XerParser
|
|||||||
from xer_mcp.server import set_file_loaded
|
from xer_mcp.server import set_file_loaded
|
||||||
|
|
||||||
|
|
||||||
async def load_xer(file_path: str, project_id: str | None = None) -> dict:
|
async def load_xer(
|
||||||
|
file_path: str,
|
||||||
|
project_id: str | None = None,
|
||||||
|
db_path: str | None = None,
|
||||||
|
) -> dict:
|
||||||
"""Load a Primavera P6 XER file and parse its schedule data.
|
"""Load a Primavera P6 XER file and parse its schedule data.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: Absolute path to the XER file
|
file_path: Absolute path to the XER file
|
||||||
project_id: Project ID to select (required for multi-project files)
|
project_id: Project ID to select (required for multi-project files)
|
||||||
|
db_path: Path for persistent database file. If None, uses in-memory.
|
||||||
|
If empty string, auto-generates from XER filename.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dictionary with success status and project info or error details
|
Dictionary with success status and project info or error details
|
||||||
"""
|
"""
|
||||||
# Ensure database is initialized
|
# Initialize database with specified path
|
||||||
if not db.is_initialized:
|
try:
|
||||||
db.initialize()
|
db.initialize(db_path=db_path, source_file=file_path)
|
||||||
|
except PermissionError:
|
||||||
|
target = db_path if db_path else file_path
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": {"code": "FILE_NOT_WRITABLE", "message": f"Cannot write database: {target}"},
|
||||||
|
}
|
||||||
|
except OSError as e:
|
||||||
|
target = db_path if db_path else file_path
|
||||||
|
if e.errno == errno.ENOSPC:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": {"code": "DISK_FULL", "message": f"Insufficient disk space: {target}"},
|
||||||
|
}
|
||||||
|
if e.errno == errno.ENOENT:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": {
|
||||||
|
"code": "FILE_NOT_WRITABLE",
|
||||||
|
"message": f"Directory does not exist: {os.path.dirname(target)}",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": {"code": "DATABASE_ERROR", "message": str(e)},
|
||||||
|
}
|
||||||
|
except sqlite3.Error as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": {"code": "DATABASE_ERROR", "message": str(e)},
|
||||||
|
}
|
||||||
|
|
||||||
parser = XerParser()
|
parser = XerParser()
|
||||||
|
|
||||||
@@ -73,6 +113,9 @@ async def load_xer(file_path: str, project_id: str | None = None) -> dict:
|
|||||||
activity_count = get_activity_count()
|
activity_count = get_activity_count()
|
||||||
relationship_count = get_relationship_count()
|
relationship_count = get_relationship_count()
|
||||||
|
|
||||||
|
# Get database info
|
||||||
|
database_info = db.get_database_info()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"project": {
|
"project": {
|
||||||
@@ -83,4 +126,5 @@ async def load_xer(file_path: str, project_id: str | None = None) -> dict:
|
|||||||
},
|
},
|
||||||
"activity_count": activity_count,
|
"activity_count": activity_count,
|
||||||
"relationship_count": relationship_count,
|
"relationship_count": relationship_count,
|
||||||
|
"database": database_info,
|
||||||
}
|
}
|
||||||
|
|||||||
143
tests/contract/test_get_database_info.py
Normal file
143
tests/contract/test_get_database_info.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""Contract tests for get_database_info MCP tool."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from xer_mcp.db import db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_db():
|
||||||
|
"""Reset database state for each test."""
|
||||||
|
if db.is_initialized:
|
||||||
|
db.close()
|
||||||
|
yield
|
||||||
|
if db.is_initialized:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDatabaseInfoContract:
|
||||||
|
"""Contract tests verifying get_database_info tool interface."""
|
||||||
|
|
||||||
|
async def test_get_database_info_returns_current_database(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""get_database_info returns info about currently loaded database."""
|
||||||
|
from xer_mcp.tools.get_database_info import get_database_info
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await get_database_info()
|
||||||
|
|
||||||
|
assert "database" in result
|
||||||
|
assert result["database"]["db_path"] == str(db_file)
|
||||||
|
assert result["database"]["is_persistent"] is True
|
||||||
|
|
||||||
|
async def test_get_database_info_error_when_no_database(self) -> None:
|
||||||
|
"""get_database_info returns error when no database loaded."""
|
||||||
|
from xer_mcp.tools.get_database_info import get_database_info
|
||||||
|
|
||||||
|
# Ensure database is not initialized
|
||||||
|
if db.is_initialized:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
result = await get_database_info()
|
||||||
|
|
||||||
|
assert "error" in result
|
||||||
|
assert result["error"]["code"] == "NO_FILE_LOADED"
|
||||||
|
|
||||||
|
async def test_get_database_info_includes_schema(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""get_database_info includes schema information."""
|
||||||
|
from xer_mcp.tools.get_database_info import get_database_info
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await get_database_info()
|
||||||
|
|
||||||
|
assert "schema" in result["database"]
|
||||||
|
assert "tables" in result["database"]["schema"]
|
||||||
|
|
||||||
|
async def test_get_database_info_includes_loaded_at(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""get_database_info includes loaded_at timestamp."""
|
||||||
|
from xer_mcp.tools.get_database_info import get_database_info
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await get_database_info()
|
||||||
|
|
||||||
|
assert "loaded_at" in result["database"]
|
||||||
|
# Should be ISO format timestamp
|
||||||
|
assert "T" in result["database"]["loaded_at"]
|
||||||
|
|
||||||
|
async def test_get_database_info_includes_source_file(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""get_database_info includes source XER file path."""
|
||||||
|
from xer_mcp.tools.get_database_info import get_database_info
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await get_database_info()
|
||||||
|
|
||||||
|
assert result["database"]["source_file"] == str(sample_xer_single_project)
|
||||||
|
|
||||||
|
async def test_get_database_info_for_memory_database(
|
||||||
|
self, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""get_database_info works for in-memory database."""
|
||||||
|
from xer_mcp.tools.get_database_info import get_database_info
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
await load_xer(file_path=str(sample_xer_single_project))
|
||||||
|
|
||||||
|
result = await get_database_info()
|
||||||
|
|
||||||
|
assert "database" in result
|
||||||
|
assert result["database"]["db_path"] == ":memory:"
|
||||||
|
assert result["database"]["is_persistent"] is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetDatabaseInfoToolSchema:
|
||||||
|
"""Tests for MCP tool schema."""
|
||||||
|
|
||||||
|
async def test_get_database_info_tool_registered(self) -> None:
|
||||||
|
"""get_database_info tool is registered with MCP server."""
|
||||||
|
from xer_mcp.server import list_tools
|
||||||
|
|
||||||
|
tools = await list_tools()
|
||||||
|
tool_names = [t.name for t in tools]
|
||||||
|
assert "get_database_info" in tool_names
|
||||||
|
|
||||||
|
async def test_get_database_info_tool_has_empty_input_schema(self) -> None:
|
||||||
|
"""get_database_info tool has no required inputs."""
|
||||||
|
from xer_mcp.server import list_tools
|
||||||
|
|
||||||
|
tools = await list_tools()
|
||||||
|
tool = next(t for t in tools if t.name == "get_database_info")
|
||||||
|
# Should have empty or no required properties
|
||||||
|
assert "required" not in tool.inputSchema or len(tool.inputSchema.get("required", [])) == 0
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Contract tests for load_xer MCP tool."""
|
"""Contract tests for load_xer MCP tool."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -9,10 +10,14 @@ from xer_mcp.db import db
|
|||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def setup_db():
|
def setup_db():
|
||||||
"""Initialize and clear database for each test."""
|
"""Reset database state for each test."""
|
||||||
db.initialize()
|
# Close any existing connection
|
||||||
|
if db.is_initialized:
|
||||||
|
db.close()
|
||||||
yield
|
yield
|
||||||
db.clear()
|
# Cleanup after test
|
||||||
|
if db.is_initialized:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
class TestLoadXerContract:
|
class TestLoadXerContract:
|
||||||
@@ -105,3 +110,141 @@ class TestLoadXerContract:
|
|||||||
assert "plan_end_date" in result["project"]
|
assert "plan_end_date" in result["project"]
|
||||||
# Dates should be ISO8601 format
|
# Dates should be ISO8601 format
|
||||||
assert "T" in result["project"]["plan_start_date"]
|
assert "T" in result["project"]["plan_start_date"]
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadXerPersistentDatabase:
|
||||||
|
"""Contract tests for persistent database functionality."""
|
||||||
|
|
||||||
|
async def test_load_xer_with_db_path_creates_file(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""load_xer with db_path creates persistent database file."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert db_file.exists()
|
||||||
|
assert result["database"]["db_path"] == str(db_file)
|
||||||
|
assert result["database"]["is_persistent"] is True
|
||||||
|
|
||||||
|
async def test_load_xer_with_empty_db_path_auto_generates(self, tmp_path: Path) -> None:
|
||||||
|
"""load_xer with empty db_path generates path from XER filename."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
# Create XER file in tmp_path
|
||||||
|
xer_file = tmp_path / "my_schedule.xer"
|
||||||
|
from tests.conftest import SAMPLE_XER_SINGLE_PROJECT
|
||||||
|
|
||||||
|
xer_file.write_text(SAMPLE_XER_SINGLE_PROJECT)
|
||||||
|
|
||||||
|
result = await load_xer(file_path=str(xer_file), db_path="")
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
expected_db = str(tmp_path / "my_schedule.sqlite")
|
||||||
|
assert result["database"]["db_path"] == expected_db
|
||||||
|
assert result["database"]["is_persistent"] is True
|
||||||
|
assert Path(expected_db).exists()
|
||||||
|
|
||||||
|
async def test_load_xer_without_db_path_uses_memory(
|
||||||
|
self, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""load_xer without db_path uses in-memory database (backward compatible)."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
result = await load_xer(file_path=str(sample_xer_single_project))
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["database"]["db_path"] == ":memory:"
|
||||||
|
assert result["database"]["is_persistent"] is False
|
||||||
|
|
||||||
|
async def test_load_xer_database_contains_all_data(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""Persistent database contains all parsed data."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify data via direct SQL
|
||||||
|
conn = sqlite3.connect(str(db_file))
|
||||||
|
cursor = conn.execute("SELECT COUNT(*) FROM activities")
|
||||||
|
count = cursor.fetchone()[0]
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
assert count == result["activity_count"]
|
||||||
|
|
||||||
|
async def test_load_xer_response_includes_database_info(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""load_xer response includes complete database info."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "database" in result
|
||||||
|
db_info = result["database"]
|
||||||
|
assert "db_path" in db_info
|
||||||
|
assert "is_persistent" in db_info
|
||||||
|
assert "source_file" in db_info
|
||||||
|
assert "loaded_at" in db_info
|
||||||
|
assert "schema" in db_info
|
||||||
|
|
||||||
|
async def test_load_xer_response_schema_includes_tables(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""load_xer response schema includes table information."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = result["database"]["schema"]
|
||||||
|
assert "version" in schema
|
||||||
|
assert "tables" in schema
|
||||||
|
table_names = [t["name"] for t in schema["tables"]]
|
||||||
|
assert "activities" in table_names
|
||||||
|
assert "relationships" in table_names
|
||||||
|
|
||||||
|
async def test_load_xer_error_on_invalid_path(self, sample_xer_single_project: Path) -> None:
|
||||||
|
"""load_xer returns error for invalid path."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path="/nonexistent/dir/file.db",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
# Either FILE_NOT_WRITABLE or DATABASE_ERROR is acceptable
|
||||||
|
# depending on how SQLite reports the error
|
||||||
|
assert result["error"]["code"] in ("FILE_NOT_WRITABLE", "DATABASE_ERROR")
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadXerToolSchema:
|
||||||
|
"""Tests for MCP tool schema."""
|
||||||
|
|
||||||
|
async def test_load_xer_tool_schema_includes_db_path(self) -> None:
|
||||||
|
"""MCP tool schema includes db_path parameter."""
|
||||||
|
from xer_mcp.server import list_tools
|
||||||
|
|
||||||
|
tools = await list_tools()
|
||||||
|
load_xer_tool = next(t for t in tools if t.name == "load_xer")
|
||||||
|
props = load_xer_tool.inputSchema["properties"]
|
||||||
|
assert "db_path" in props
|
||||||
|
assert props["db_path"]["type"] == "string"
|
||||||
|
|||||||
185
tests/integration/test_direct_db_access.py
Normal file
185
tests/integration/test_direct_db_access.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""Integration tests for direct database access feature."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from xer_mcp.db import db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def setup_db():
|
||||||
|
"""Reset database state for each test."""
|
||||||
|
if db.is_initialized:
|
||||||
|
db.close()
|
||||||
|
yield
|
||||||
|
if db.is_initialized:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDirectDatabaseAccess:
|
||||||
|
"""Integration tests verifying external script can access database."""
|
||||||
|
|
||||||
|
async def test_external_script_can_query_database(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""External script can query database using returned path."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Simulate external script access (as shown in quickstart.md)
|
||||||
|
db_path = result["database"]["db_path"]
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
|
# Query milestones
|
||||||
|
cursor = conn.execute("""
|
||||||
|
SELECT task_code, task_name, target_start_date, milestone_type
|
||||||
|
FROM activities
|
||||||
|
WHERE task_type IN ('TT_Mile', 'TT_FinMile')
|
||||||
|
ORDER BY target_start_date
|
||||||
|
""")
|
||||||
|
|
||||||
|
milestones = cursor.fetchall()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
assert len(milestones) > 0
|
||||||
|
assert all(row["task_code"] for row in milestones)
|
||||||
|
|
||||||
|
async def test_external_script_can_query_critical_path(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""External script can query critical path activities."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_path = result["database"]["db_path"]
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
cursor = conn.execute("""
|
||||||
|
SELECT task_code, task_name, target_start_date, target_end_date
|
||||||
|
FROM activities
|
||||||
|
WHERE driving_path_flag = 1
|
||||||
|
ORDER BY target_start_date
|
||||||
|
""")
|
||||||
|
|
||||||
|
critical_activities = cursor.fetchall()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
assert len(critical_activities) > 0
|
||||||
|
|
||||||
|
async def test_external_script_can_join_tables(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""External script can join activities with WBS."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
db_path = result["database"]["db_path"]
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
cursor = conn.execute("""
|
||||||
|
SELECT a.task_code, a.task_name, w.wbs_name
|
||||||
|
FROM activities a
|
||||||
|
JOIN wbs w ON a.wbs_id = w.wbs_id
|
||||||
|
LIMIT 10
|
||||||
|
""")
|
||||||
|
|
||||||
|
joined_rows = cursor.fetchall()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
assert len(joined_rows) > 0
|
||||||
|
|
||||||
|
async def test_database_accessible_after_mcp_load(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""Database remains accessible while MCP tools are active."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
loaded_count = result["activity_count"]
|
||||||
|
|
||||||
|
# External script queries database
|
||||||
|
conn = sqlite3.connect(str(db_file))
|
||||||
|
cursor = conn.execute("SELECT COUNT(*) FROM activities")
|
||||||
|
external_count = cursor.fetchone()[0]
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
# Both should match
|
||||||
|
assert external_count == loaded_count
|
||||||
|
|
||||||
|
async def test_schema_info_matches_actual_database(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""Returned schema info matches actual database structure."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = result["database"]["schema"]
|
||||||
|
db_path = result["database"]["db_path"]
|
||||||
|
|
||||||
|
# Verify tables exist in actual database
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
cursor = conn.execute(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'"
|
||||||
|
)
|
||||||
|
actual_tables = {row[0] for row in cursor.fetchall()}
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
schema_tables = {t["name"] for t in schema["tables"]}
|
||||||
|
assert schema_tables == actual_tables
|
||||||
|
|
||||||
|
async def test_row_counts_match_actual_data(
|
||||||
|
self, tmp_path: Path, sample_xer_single_project: Path
|
||||||
|
) -> None:
|
||||||
|
"""Schema row counts match actual database row counts."""
|
||||||
|
from xer_mcp.tools.load_xer import load_xer
|
||||||
|
|
||||||
|
db_file = tmp_path / "schedule.db"
|
||||||
|
result = await load_xer(
|
||||||
|
file_path=str(sample_xer_single_project),
|
||||||
|
db_path=str(db_file),
|
||||||
|
)
|
||||||
|
|
||||||
|
schema = result["database"]["schema"]
|
||||||
|
db_path = result["database"]["db_path"]
|
||||||
|
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
|
||||||
|
for table_info in schema["tables"]:
|
||||||
|
cursor = conn.execute(
|
||||||
|
f"SELECT COUNT(*) FROM {table_info['name']}" # noqa: S608
|
||||||
|
)
|
||||||
|
actual_count = cursor.fetchone()[0]
|
||||||
|
assert table_info["row_count"] == actual_count, (
|
||||||
|
f"Table {table_info['name']}: expected {table_info['row_count']}, "
|
||||||
|
f"got {actual_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
conn.close()
|
||||||
234
tests/unit/test_db_manager.py
Normal file
234
tests/unit/test_db_manager.py
Normal file
@@ -0,0 +1,234 @@
|
|||||||
|
"""Unit tests for DatabaseManager file-based database support."""
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from xer_mcp.db import DatabaseManager
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabaseManagerInitialization:
|
||||||
|
"""Tests for DatabaseManager initialization modes."""
|
||||||
|
|
||||||
|
def test_initialize_with_memory_by_default(self) -> None:
|
||||||
|
"""Default initialization uses in-memory database."""
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize()
|
||||||
|
assert dm.db_path == ":memory:"
|
||||||
|
assert dm.is_persistent is False
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_initialize_with_file_path(self, tmp_path: Path) -> None:
|
||||||
|
"""Can initialize with explicit file path."""
|
||||||
|
db_file = tmp_path / "test.db"
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize(db_path=str(db_file))
|
||||||
|
assert dm.db_path == str(db_file)
|
||||||
|
assert dm.is_persistent is True
|
||||||
|
assert db_file.exists()
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_initialize_with_empty_string_auto_generates_path(self, tmp_path: Path) -> None:
|
||||||
|
"""Empty string db_path with source_file auto-generates path."""
|
||||||
|
xer_file = tmp_path / "schedule.xer"
|
||||||
|
xer_file.write_text("dummy content")
|
||||||
|
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize(db_path="", source_file=str(xer_file))
|
||||||
|
expected_db = str(tmp_path / "schedule.sqlite")
|
||||||
|
assert dm.db_path == expected_db
|
||||||
|
assert dm.is_persistent is True
|
||||||
|
assert Path(expected_db).exists()
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_file_database_persists_after_close(self, tmp_path: Path) -> None:
|
||||||
|
"""File-based database persists after connection close."""
|
||||||
|
db_file = tmp_path / "persist_test.db"
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize(db_path=str(db_file))
|
||||||
|
|
||||||
|
# Insert test data
|
||||||
|
with dm.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO projects (proj_id, proj_short_name, loaded_at) "
|
||||||
|
"VALUES ('P1', 'Test', datetime('now'))"
|
||||||
|
)
|
||||||
|
dm.commit()
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
# Verify file exists and has data
|
||||||
|
assert db_file.exists()
|
||||||
|
conn = sqlite3.connect(str(db_file))
|
||||||
|
cursor = conn.execute("SELECT proj_id FROM projects")
|
||||||
|
rows = cursor.fetchall()
|
||||||
|
conn.close()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert rows[0][0] == "P1"
|
||||||
|
|
||||||
|
def test_source_file_tracked(self, tmp_path: Path) -> None:
|
||||||
|
"""Source file path is tracked when provided."""
|
||||||
|
db_file = tmp_path / "test.db"
|
||||||
|
xer_file = tmp_path / "schedule.xer"
|
||||||
|
xer_file.write_text("dummy")
|
||||||
|
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize(db_path=str(db_file), source_file=str(xer_file))
|
||||||
|
assert dm.source_file == str(xer_file)
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_loaded_at_timestamp(self, tmp_path: Path) -> None:
|
||||||
|
"""Loaded_at timestamp is recorded."""
|
||||||
|
db_file = tmp_path / "test.db"
|
||||||
|
dm = DatabaseManager()
|
||||||
|
before = datetime.now()
|
||||||
|
dm.initialize(db_path=str(db_file))
|
||||||
|
after = datetime.now()
|
||||||
|
|
||||||
|
loaded_at = dm.loaded_at
|
||||||
|
assert loaded_at is not None
|
||||||
|
assert before <= loaded_at <= after
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_memory_database_not_persistent(self) -> None:
|
||||||
|
"""In-memory database is not persistent."""
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize()
|
||||||
|
assert dm.is_persistent is False
|
||||||
|
assert dm.db_path == ":memory:"
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatabaseManagerWalMode:
|
||||||
|
"""Tests for WAL mode in file-based databases."""
|
||||||
|
|
||||||
|
def test_file_database_uses_wal_mode(self, tmp_path: Path) -> None:
|
||||||
|
"""File-based database uses WAL mode for concurrent access."""
|
||||||
|
db_file = tmp_path / "wal_test.db"
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize(db_path=str(db_file))
|
||||||
|
|
||||||
|
with dm.cursor() as cur:
|
||||||
|
cur.execute("PRAGMA journal_mode")
|
||||||
|
mode = cur.fetchone()[0]
|
||||||
|
assert mode.lower() == "wal"
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_memory_database_does_not_use_wal(self) -> None:
|
||||||
|
"""In-memory database doesn't use WAL mode (not applicable)."""
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize()
|
||||||
|
|
||||||
|
with dm.cursor() as cur:
|
||||||
|
cur.execute("PRAGMA journal_mode")
|
||||||
|
mode = cur.fetchone()[0]
|
||||||
|
# Memory databases use 'memory' journal mode
|
||||||
|
assert mode.lower() == "memory"
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
|
||||||
|
class TestAtomicWrite:
|
||||||
|
"""Tests for atomic write pattern."""
|
||||||
|
|
||||||
|
def test_atomic_write_creates_final_file(self, tmp_path: Path) -> None:
|
||||||
|
"""Database is created at final path after initialization."""
|
||||||
|
target = tmp_path / "atomic_test.db"
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize(db_path=str(target))
|
||||||
|
assert target.exists()
|
||||||
|
assert not Path(str(target) + ".tmp").exists()
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_atomic_write_no_temp_file_remains(self, tmp_path: Path) -> None:
|
||||||
|
"""No .tmp file remains after successful initialization."""
|
||||||
|
target = tmp_path / "atomic_clean.db"
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize(db_path=str(target))
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
# Check no temp files remain
|
||||||
|
temp_files = list(tmp_path.glob("*.tmp"))
|
||||||
|
assert len(temp_files) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSchemaIntrospection:
|
||||||
|
"""Tests for database schema introspection."""
|
||||||
|
|
||||||
|
def test_get_schema_info_returns_all_tables(self) -> None:
|
||||||
|
"""Schema info includes all database tables."""
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize()
|
||||||
|
schema = dm.get_schema_info()
|
||||||
|
|
||||||
|
assert schema["version"] == "0.2.0"
|
||||||
|
table_names = [t["name"] for t in schema["tables"]]
|
||||||
|
assert "projects" in table_names
|
||||||
|
assert "activities" in table_names
|
||||||
|
assert "relationships" in table_names
|
||||||
|
assert "wbs" in table_names
|
||||||
|
assert "calendars" in table_names
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_get_schema_info_includes_column_details(self) -> None:
|
||||||
|
"""Schema info includes column names, types, and nullable."""
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize()
|
||||||
|
schema = dm.get_schema_info()
|
||||||
|
|
||||||
|
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
|
||||||
|
column_names = [c["name"] for c in activities_table["columns"]]
|
||||||
|
assert "task_id" in column_names
|
||||||
|
assert "task_name" in column_names
|
||||||
|
|
||||||
|
# Check column details
|
||||||
|
task_id_col = next(c for c in activities_table["columns"] if c["name"] == "task_id")
|
||||||
|
assert task_id_col["type"] == "TEXT"
|
||||||
|
# Note: SQLite reports PRIMARY KEY TEXT columns as nullable
|
||||||
|
# but the PRIMARY KEY constraint still applies
|
||||||
|
assert "nullable" in task_id_col
|
||||||
|
|
||||||
|
# Check a NOT NULL column
|
||||||
|
task_name_col = next(c for c in activities_table["columns"] if c["name"] == "task_name")
|
||||||
|
assert task_name_col["nullable"] is False
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_get_schema_info_includes_row_counts(self) -> None:
|
||||||
|
"""Schema info includes row counts for each table."""
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize()
|
||||||
|
schema = dm.get_schema_info()
|
||||||
|
|
||||||
|
for table in schema["tables"]:
|
||||||
|
assert "row_count" in table
|
||||||
|
assert isinstance(table["row_count"], int)
|
||||||
|
assert table["row_count"] >= 0
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_schema_info_includes_primary_keys(self) -> None:
|
||||||
|
"""Schema info includes primary key for each table."""
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize()
|
||||||
|
schema = dm.get_schema_info()
|
||||||
|
|
||||||
|
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
|
||||||
|
assert "primary_key" in activities_table
|
||||||
|
assert "task_id" in activities_table["primary_key"]
|
||||||
|
dm.close()
|
||||||
|
|
||||||
|
def test_schema_info_includes_foreign_keys(self) -> None:
|
||||||
|
"""Schema info includes foreign key relationships."""
|
||||||
|
dm = DatabaseManager()
|
||||||
|
dm.initialize()
|
||||||
|
schema = dm.get_schema_info()
|
||||||
|
|
||||||
|
activities_table = next(t for t in schema["tables"] if t["name"] == "activities")
|
||||||
|
assert "foreign_keys" in activities_table
|
||||||
|
|
||||||
|
# activities.proj_id -> projects.proj_id
|
||||||
|
fk = next(
|
||||||
|
(fk for fk in activities_table["foreign_keys"] if fk["column"] == "proj_id"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert fk is not None
|
||||||
|
assert fk["references_table"] == "projects"
|
||||||
|
assert fk["references_column"] == "proj_id"
|
||||||
|
dm.close()
|
||||||
Reference in New Issue
Block a user