"""Database query functions for XER data."""

from datetime import datetime, timedelta

from xer_mcp.db import db

# Largest allowed gap between the expected successor start (predecessor early
# end + lag) and the actual successor early start for a relationship to still
# count as driving. Absorbs overnight gaps and calendar differences.
_DRIVING_TOLERANCE = timedelta(hours=24)


def _format_pred_type(pred_type: str | None) -> str | None:
    """Convert a relationship type from internal format (PR_FS) to API format (FS).

    Args:
        pred_type: Relationship type as stored in the database, or None

    Returns:
        The type without its "PR_" prefix; values without the prefix are
        returned unchanged and None is passed through as None
    """
    if pred_type is None:
        return None
    return pred_type.removeprefix("PR_")


def is_driving_relationship(
    pred_early_end: str | None,
    succ_early_start: str | None,
    lag_hours: float,
    pred_type: str | None,
) -> bool:
    """Determine if a relationship is driving the successor's early start.

    A relationship is "driving" when the predecessor's completion (plus lag)
    determines the successor's early start date. This is computed by comparing
    dates with a tolerance for overnight gaps and calendar differences.

    Only FS (Finish-to-Start) relationships are evaluated:
    - SS (Start-to-Start) would need pred_early_start, not implemented
    - FF (Finish-to-Finish) would need succ_early_end, not implemented
    - SF (Start-to-Finish) is a complex case, not implemented

    Args:
        pred_early_end: Predecessor's early end date (ISO format)
        succ_early_start: Successor's early start date (ISO format)
        lag_hours: Lag duration in hours (can be negative)
        pred_type: Relationship type (FS, SS, FF, SF)

    Returns:
        True if the relationship is driving, False otherwise. Missing dates,
        unparseable dates, dates that cannot be compared with each other
        (one timezone-aware, one naive) and unusable lag values all yield
        False instead of raising.
    """
    if pred_early_end is None or succ_early_start is None:
        return False

    # Default to False for non-FS relationships for now
    if pred_type != "FS":
        return False

    try:
        pred_end = datetime.fromisoformat(pred_early_end)
        succ_start = datetime.fromisoformat(succ_early_start)
        # For FS: pred_end + lag should equal succ_start.
        # The arithmetic stays inside the try block on purpose: subtracting an
        # aware datetime from a naive one raises TypeError, and a NaN/infinite
        # or out-of-range lag raises ValueError/OverflowError.
        expected_start = pred_end + timedelta(hours=lag_hours)
        gap = abs(succ_start - expected_start)
    except (ValueError, TypeError, OverflowError):
        return False

    return gap <= _DRIVING_TOLERANCE


def _linked_activity_dict(
    row,
    *,
    pred_early_end: str | None,
    succ_early_start: str | None,
) -> dict:
    """Build the API dict for one predecessor/successor row.

    Args:
        row: Result row whose first five columns are task_id, task_code,
            task_name, pred_type and lag_hr_cnt
        pred_early_end: Early end date of the predecessor side of the link
        succ_early_start: Early start date of the successor side of the link

    Returns:
        Activity dict with relationship type, lag and driving flag
    """
    pred_type = _format_pred_type(row[3])
    return {
        "task_id": row[0],
        "task_code": row[1],
        "task_name": row[2],
        "relationship_type": pred_type,
        "lag_hr_cnt": row[4],
        "driving": is_driving_relationship(
            pred_early_end=pred_early_end,
            succ_early_start=succ_early_start,
            # A NULL lag is treated as no lag
            lag_hours=row[4] or 0.0,
            pred_type=pred_type,
        ),
    }


def query_activities(
    limit: int = 100,
    offset: int = 0,
    start_date: str | None = None,
    end_date: str | None = None,
    wbs_id: str | None = None,
    activity_type: str | None = None,
) -> tuple[list[dict], int]:
    """Query activities with pagination and filtering.

    Args:
        limit: Maximum number of results to return
        offset: Number of results to skip
        start_date: Filter activities starting on or after this date (YYYY-MM-DD)
        end_date: Filter activities ending on or before this date (YYYY-MM-DD)
        wbs_id: Filter by WBS ID
        activity_type: Filter by task type (TT_Task, TT_Mile, etc.)

    Returns:
        Tuple of (list of activity dicts, total count matching filters)
    """
    # Build WHERE clause. Only fixed condition strings are interpolated into
    # the SQL text; every user-supplied value is bound as a parameter.
    conditions = []
    params: list = []

    if start_date:
        conditions.append("target_start_date >= ?")
        params.append(f"{start_date}T00:00:00")

    if end_date:
        conditions.append("target_end_date <= ?")
        params.append(f"{end_date}T23:59:59")

    if wbs_id:
        conditions.append("wbs_id = ?")
        params.append(wbs_id)

    if activity_type:
        conditions.append("task_type = ?")
        params.append(activity_type)

    where_clause = " AND ".join(conditions) if conditions else "1=1"

    # Get total count (ignores limit/offset so callers can paginate)
    with db.cursor() as cur:
        cur.execute(f"SELECT COUNT(*) FROM activities WHERE {where_clause}", params)  # noqa: S608
        total = cur.fetchone()[0]

    # Get paginated results
    # NOTE(review): wbs_id and total_float_hr_cnt are selected but not included
    # in the returned dicts - confirm whether they should be exposed or dropped.
    query = f"""
        SELECT task_id, task_code, task_name, task_type,
               target_start_date, target_end_date, status_code,
               driving_path_flag, wbs_id, total_float_hr_cnt
        FROM activities
        WHERE {where_clause}
        ORDER BY target_start_date, task_code
        LIMIT ? OFFSET ?
    """  # noqa: S608

    with db.cursor() as cur:
        cur.execute(query, [*params, limit, offset])
        rows = cur.fetchall()

    activities = [
        {
            "task_id": row[0],
            "task_code": row[1],
            "task_name": row[2],
            "task_type": row[3],
            "target_start_date": row[4],
            "target_end_date": row[5],
            "status_code": row[6],
            "driving_path_flag": bool(row[7]),
        }
        for row in rows
    ]

    return activities, total


def get_activity_by_id(activity_id: str) -> dict | None:
    """Get a single activity by ID with full details.

    Args:
        activity_id: The task_id to look up

    Returns:
        Activity dict with all fields plus predecessor and successor counts,
        or None if not found
    """
    query = """
        SELECT a.task_id, a.task_code, a.task_name, a.task_type,
               a.wbs_id, w.wbs_name,
               a.target_start_date, a.target_end_date,
               a.act_start_date, a.act_end_date,
               a.total_float_hr_cnt, a.status_code, a.driving_path_flag
        FROM activities a
        LEFT JOIN wbs w ON a.wbs_id = w.wbs_id
        WHERE a.task_id = ?
    """

    with db.cursor() as cur:
        cur.execute(query, (activity_id,))
        row = cur.fetchone()

    if row is None:
        return None

    # Count predecessors and successors: in the relationships table task_id is
    # the successor side of a link and pred_task_id the predecessor side.
    with db.cursor() as cur:
        cur.execute(
            "SELECT COUNT(*) FROM relationships WHERE task_id = ?",
            (activity_id,),
        )
        predecessor_count = cur.fetchone()[0]

        cur.execute(
            "SELECT COUNT(*) FROM relationships WHERE pred_task_id = ?",
            (activity_id,),
        )
        successor_count = cur.fetchone()[0]

    return {
        "task_id": row[0],
        "task_code": row[1],
        "task_name": row[2],
        "task_type": row[3],
        "wbs_id": row[4],
        "wbs_name": row[5],
        "target_start_date": row[6],
        "target_end_date": row[7],
        "act_start_date": row[8],
        "act_end_date": row[9],
        "total_float_hr_cnt": row[10],
        "status_code": row[11],
        "driving_path_flag": bool(row[12]),
        "predecessor_count": predecessor_count,
        "successor_count": successor_count,
    }


def query_relationships(
    limit: int = 100,
    offset: int = 0,
) -> tuple[list[dict], int]:
    """Query relationships with pagination.

    Args:
        limit: Maximum number of results to return
        offset: Number of results to skip

    Returns:
        Tuple of (list of relationship dicts, total count)
    """
    # Get total count
    with db.cursor() as cur:
        cur.execute("SELECT COUNT(*) FROM relationships")
        total = cur.fetchone()[0]

    # Get paginated results with activity names and early dates for driving
    # computation. LEFT JOINs keep relationships whose activities are missing;
    # their names and dates then come back as NULL.
    query = """
        SELECT r.task_pred_id, r.task_id, succ.task_name,
               r.pred_task_id, pred.task_name,
               r.pred_type, r.lag_hr_cnt,
               pred.early_end_date, succ.early_start_date
        FROM relationships r
        LEFT JOIN activities succ ON r.task_id = succ.task_id
        LEFT JOIN activities pred ON r.pred_task_id = pred.task_id
        ORDER BY r.task_pred_id
        LIMIT ? OFFSET ?
    """

    with db.cursor() as cur:
        cur.execute(query, (limit, offset))
        rows = cur.fetchall()

    relationships = []
    for row in rows:
        pred_type = _format_pred_type(row[5])
        driving = is_driving_relationship(
            pred_early_end=row[7],
            succ_early_start=row[8],
            # A NULL lag is treated as no lag
            lag_hours=row[6] or 0.0,
            pred_type=pred_type,
        )
        relationships.append(
            {
                "task_pred_id": row[0],
                "task_id": row[1],
                "task_name": row[2],
                "pred_task_id": row[3],
                "pred_task_name": row[4],
                "pred_type": pred_type,
                "lag_hr_cnt": row[6],
                "driving": driving,
            }
        )

    return relationships, total


def get_predecessors(activity_id: str) -> list[dict]:
    """Get predecessor activities for a given activity.

    Args:
        activity_id: The task_id to find predecessors for

    Returns:
        List of predecessor activity dicts with relationship info and driving flag
    """
    # Get successor's early start for driving calculation
    with db.cursor() as cur:
        cur.execute(
            "SELECT early_start_date FROM activities WHERE task_id = ?",
            (activity_id,),
        )
        succ_row = cur.fetchone()

    succ_early_start = succ_row[0] if succ_row else None

    query = """
        SELECT pred.task_id, pred.task_code, pred.task_name,
               r.pred_type, r.lag_hr_cnt, pred.early_end_date
        FROM relationships r
        JOIN activities pred ON r.pred_task_id = pred.task_id
        WHERE r.task_id = ?
        ORDER BY pred.task_code
    """

    with db.cursor() as cur:
        cur.execute(query, (activity_id,))
        rows = cur.fetchall()

    # Column 5 of each row is the predecessor's early end date
    return [
        _linked_activity_dict(
            row,
            pred_early_end=row[5],
            succ_early_start=succ_early_start,
        )
        for row in rows
    ]


def get_successors(activity_id: str) -> list[dict]:
    """Get successor activities for a given activity.

    Args:
        activity_id: The task_id to find successors for

    Returns:
        List of successor activity dicts with relationship info and driving flag
    """
    # Get predecessor's early end for driving calculation
    with db.cursor() as cur:
        cur.execute(
            "SELECT early_end_date FROM activities WHERE task_id = ?",
            (activity_id,),
        )
        pred_row = cur.fetchone()

    pred_early_end = pred_row[0] if pred_row else None

    query = """
        SELECT succ.task_id, succ.task_code, succ.task_name,
               r.pred_type, r.lag_hr_cnt, succ.early_start_date
        FROM relationships r
        JOIN activities succ ON r.task_id = succ.task_id
        WHERE r.pred_task_id = ?
        ORDER BY succ.task_code
    """

    with db.cursor() as cur:
        cur.execute(query, (activity_id,))
        rows = cur.fetchall()

    # Column 5 of each row is the successor's early start date
    return [
        _linked_activity_dict(
            row,
            pred_early_end=pred_early_end,
            succ_early_start=row[5],
        )
        for row in rows
    ]


def get_project_summary(project_id: str) -> dict | None:
    """Get project summary information.

    Args:
        project_id: The project ID to get summary for

    Returns:
        Dictionary with project summary or None if not found
    """
    # Get project info
    with db.cursor() as cur:
        cur.execute(
            """
            SELECT proj_id, proj_short_name, plan_start_date, plan_end_date, last_recalc_date
            FROM projects
            WHERE proj_id = ?
            """,
            (project_id,),
        )
        project_row = cur.fetchone()

    if project_row is None:
        return None

    # NOTE(review): the three counts below run over the whole activities table
    # and are not restricted to project_id - confirm the database only ever
    # holds a single project.
    with db.cursor() as cur:
        # Get activity count
        cur.execute("SELECT COUNT(*) FROM activities")
        activity_count = cur.fetchone()[0]

        # Get milestone count (both start and finish milestones)
        cur.execute("SELECT COUNT(*) FROM activities WHERE task_type IN ('TT_Mile', 'TT_FinMile')")
        milestone_count = cur.fetchone()[0]

        # Get critical activity count
        cur.execute("SELECT COUNT(*) FROM activities WHERE driving_path_flag = 1")
        critical_count = cur.fetchone()[0]

    return {
        "project_id": project_row[0],
        "project_name": project_row[1],
        # The project's last_recalc_date column is reported as the data date
        "data_date": project_row[4],
        "plan_start_date": project_row[2],
        "plan_end_date": project_row[3],
        "activity_count": activity_count,
        "milestone_count": milestone_count,
        "critical_activity_count": critical_count,
    }


def query_milestones() -> list[dict]:
    """Query all milestone activities (both start and finish milestones).

    Returns:
        List of milestone activity dicts with milestone_type (start/finish)
    """
    query = """
        SELECT task_id, task_code, task_name,
               target_start_date, target_end_date, status_code, milestone_type
        FROM activities
        WHERE task_type IN ('TT_Mile', 'TT_FinMile')
        ORDER BY target_start_date, task_code
    """

    with db.cursor() as cur:
        cur.execute(query)
        rows = cur.fetchall()

    return [
        {
            "task_id": row[0],
            "task_code": row[1],
            "task_name": row[2],
            "target_start_date": row[3],
            "target_end_date": row[4],
            "status_code": row[5],
            "milestone_type": row[6],
        }
        for row in rows
    ]


def query_critical_path() -> list[dict]:
    """Query all activities on the critical path.

    Activities are selected by driving_path_flag = 1.

    Returns:
        List of critical path activity dicts ordered by start date
    """
    query = """
        SELECT task_id, task_code, task_name, task_type,
               target_start_date, target_end_date,
               total_float_hr_cnt, status_code
        FROM activities
        WHERE driving_path_flag = 1
        ORDER BY target_start_date, task_code
    """

    with db.cursor() as cur:
        cur.execute(query)
        rows = cur.fetchall()

    return [
        {
            "task_id": row[0],
            "task_code": row[1],
            "task_name": row[2],
            "task_type": row[3],
            "target_start_date": row[4],
            "target_end_date": row[5],
            "total_float_hr_cnt": row[6],
            "status_code": row[7],
        }
        for row in rows
    ]