From 6c292a392d9b6a9a1f8db76aa54ce8f6cd9734f8 Mon Sep 17 00:00:00 2001
From: Erfan Nourbakhsh
Date: Wed, 18 Dec 2024 07:38:02 -0500
Subject: [PATCH 1/2] Initial prototype of Mermaid code generator for pipelines and quantum graphs

---
 python/lsst/pipe/base/mermaid_tools.py | 431 ++++++++++++++++++
 .../pipeline_graph/visualization/_mermaid.py | 428 +++++++++++++++++
 2 files changed, 859 insertions(+)
 create mode 100644 python/lsst/pipe/base/mermaid_tools.py
 create mode 100644 python/lsst/pipe/base/pipeline_graph/visualization/_mermaid.py

diff --git a/python/lsst/pipe/base/mermaid_tools.py b/python/lsst/pipe/base/mermaid_tools.py
new file mode 100644
index 000000000..1fb7efaf2
--- /dev/null
+++ b/python/lsst/pipe/base/mermaid_tools.py
@@ -0,0 +1,431 @@
# This file is part of pipe_base.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (http://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This software is dual licensed under the GNU General Public License and also
# under a 3-clause BSD license. Recipients may choose which of these licenses
# to use; please see the files gpl-3.0.txt and/or bsd_license.txt,
# respectively. If you choose the GPL option then the following text applies
# (but note that there is still no warranty even if you opt for BSD instead):
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Module defining a few methods to generate Mermaid charts from pipelines or
quantum graphs.
"""

from __future__ import annotations

__all__ = ["graph2mermaid", "pipeline2mermaid"]

import html
import re
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

from lsst.daf.butler import DatasetType, DimensionUniverse

from . import connectionTypes
from .connections import iterConnections
from .pipeline import Pipeline

if TYPE_CHECKING:
    from lsst.daf.butler import DatasetRef
    from lsst.pipe.base import QuantumGraph, TaskDef


def _datasetRefId(dsRef: DatasetRef) -> str:
    """Make a unique identifier string for a dataset ref based on its name and
    dataId."""
    dsIdParts = [dsRef.datasetType.name]
    dsIdParts.extend(f"{key}_{dsRef.dataId[key]}" for key in sorted(dsRef.dataId.required.keys()))
    return "_".join(dsIdParts)


def _makeDatasetNode(dsRef: DatasetRef, allDatasetRefs: dict[str, str], file: Any) -> str:
    """Create a Mermaid node for a dataset if it doesn't exist, and return its
    node ID."""
    dsId = _datasetRefId(dsRef)
    nodeName = allDatasetRefs.get(dsId)
    if nodeName is None:
        nodeName = f"DATASET_{len(allDatasetRefs)}"
        allDatasetRefs[dsId] = nodeName
        # Simple label: datasetType name and run.
        label_lines = [f"**{dsRef.datasetType.name}**", f"run: {dsRef.run}"]
        # Add dataId info.
        for k in sorted(dsRef.dataId.required.keys()):
            label_lines.append(f"{k}={dsRef.dataId[k]}")
        label = "<br>".join(label_lines)
        print(f'{nodeName}["{label}"]', file=file)
    return nodeName


def graph2mermaid(qgraph: QuantumGraph, file: Any) -> None:
    """Convert QuantumGraph into a Mermaid flowchart (top-down).

    Parameters
    ----------
    qgraph : `lsst.pipe.base.QuantumGraph`
        QuantumGraph instance.
    file : `str` or file object
        File where the Mermaid flowchart is written; can be a file name or a
        file object.

    Raises
    ------
    OSError
        Raised if the output file cannot be opened.
    ValueError
        Raised if a quantum in the graph has no data ID.
    """
    # Open a file if needed.
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # Start Mermaid code block and flowchart.
    print("```mermaid", file=file)
    print("flowchart TD", file=file)

    # To avoid duplicating dataset nodes, we track them.
    allDatasetRefs: dict[str, str] = {}

    # Process each task/quantum.
    for taskId, taskDef in enumerate(qgraph.taskGraph):
        quanta = qgraph.getNodesForTask(taskDef)
        for qId, quantumNode in enumerate(quanta):
            # Create quantum node.
            taskNodeName = f"TASK_{taskId}_{qId}"
            taskLabelLines = [f"**{taskDef.label}**", f"Node ID: {quantumNode.nodeId}"]
            dataId = quantumNode.quantum.dataId
            if dataId is not None:
                for k in sorted(dataId.required.keys()):
                    taskLabelLines.append(f"{k}={dataId[k]}")
            else:
                raise ValueError("Quantum DataId cannot be None")
            taskLabel = "<br>".join(taskLabelLines)
            print(f'{taskNodeName}["{taskLabel}"]', file=file)

            # Quantum inputs: datasets --> tasks.
            for dsRefs in quantumNode.quantum.inputs.values():
                for dsRef in dsRefs:
                    dsNode = _makeDatasetNode(dsRef, allDatasetRefs, file)
                    print(f"{dsNode} --> {taskNodeName}", file=file)

            # Quantum outputs: tasks --> datasets.
            for dsRefs in quantumNode.quantum.outputs.values():
                for dsRef in dsRefs:
                    dsNode = _makeDatasetNode(dsRef, allDatasetRefs, file)
                    print(f"{taskNodeName} --> {dsNode}", file=file)

    # End Mermaid code block.
    print("```", file=file)

    if close:
        file.close()


def _expand_dimensions(dimension_list: list[str], universe: DimensionUniverse) -> list[str]:
    """Return expanded list of dimensions, with special skypix treatment.

    Parameters
    ----------
    dimension_list : `list` [`str`]
        The original list of dimension names.
    universe : DimensionUniverse
        Used to conform the dimension set according to a known schema.

    Returns
    -------
    dimensions : `list` [`str`]
        Expanded list of dimensions.
    """
    dimension_set = set(dimension_list)
    skypix_dim = []
    if "skypix" in dimension_set:
        dimension_set.remove("skypix")
        skypix_dim = ["skypix"]
    dimensions = universe.conform(dimension_set)
    return list(dimensions.names) + skypix_dim


def _format_dimensions(dims: list[str]) -> str:
    """Format and sort dimension names as a comma-separated list inside curly
    braces.

    For example, if dims=["detector", "visit"], returns "{detector, visit}".

    Parameters
    ----------
    dims : list of str
        The dimension names to format and sort.

    Returns
    -------
    str
        The formatted dimension string, or an empty string if no dimensions.
    """
    if not dims:
        return ""
    sorted_dims = sorted(dims)
    return "{" + ", ".join(sorted_dims) + "}"


def _render_task_node(
    task_id: str, taskDef: TaskDef, universe: DimensionUniverse, file: Any, show_dimensions: bool
) -> None:
    """Render a single task node in the Mermaid diagram.

    Parameters
    ----------
    task_id : str
        Unique Mermaid node identifier for this task.
    taskDef : TaskDef
        The pipeline task definition, which includes the task label, task
        name, and connections.
    universe : DimensionUniverse
        Used to conform and sort the task's dimensions.
    file : file-like
        The output file-like object to write the Mermaid node definition.
    show_dimensions : bool
        If True, display the task's dimensions after conforming them.
    """
    # Basic info: bold label, then task name.
    lines = [
        f"<b>{html.escape(taskDef.label)}</b>",
        html.escape(taskDef.taskName),
    ]

    # If requested, display the task's conformed dimensions.
    if show_dimensions and taskDef.connections and taskDef.connections.dimensions:
        task_dims = _expand_dimensions(taskDef.connections.dimensions, universe)
        if task_dims:
            dim_str = _format_dimensions(task_dims)
            lines.append(f"dimensions: {html.escape(dim_str)}")

    # Join with <br> for line breaks and define the node with the label.
    label = "<br>".join(lines)
    print(f'{task_id}["{label}"]', file=file)
    print(f"class {task_id} task;", file=file)


def _render_dataset_node(
    ds_id: str,
    ds_name: str,
    connection: connectionTypes.BaseConnection,
    universe: DimensionUniverse,
    file: Any,
    show_dimensions: bool,
    show_storage: bool,
) -> None:
    """Render a dataset-type node in the Mermaid diagram.

    Parameters
    ----------
    ds_id : str
        Unique Mermaid node identifier for this dataset.
    ds_name : str
        The dataset type name.
    connection : BaseConnection
        The dataset connection object, potentially dimensioned and having a
        storage class.
    universe : DimensionUniverse
        Used to conform and sort the dataset's dimensions if it is dimensioned.
    file : file-like
        The output file-like object to write the Mermaid node definition.
    show_dimensions : bool
        If True, display the dataset's conformed dimensions.
    show_storage : bool
        If True, display the dataset's storage class if available.
    """
    # Start with the dataset name in bold.
    lines = [f"<b>{html.escape(ds_name)}</b>"]

    # If dimensioned and requested, show conformed dimensions.
    ds_dims: list[str] = []
    if show_dimensions and isinstance(connection, connectionTypes.DimensionedConnection):
        ds_dims = _expand_dimensions(connection.dimensions, universe)

    if ds_dims:
        dim_str = _format_dimensions(ds_dims)
        lines.append(f"dimensions: {html.escape(dim_str)}")

    # If storage class is available and requested, display it.
    if show_storage and getattr(connection, "storageClass", None) is not None:
        lines.append(f"storage class: {html.escape(str(connection.storageClass))}")

    label = "<br>".join(lines)
    print(f'{ds_id}["{label}"]', file=file)
    print(f"class {ds_id} ds;", file=file)


def pipeline2mermaid(
    pipeline: Pipeline | Iterable[TaskDef], file: Any, show_dimensions: bool = True, show_storage: bool = True
) -> None:
    """Convert a Pipeline into a Mermaid flowchart diagram.

    This function produces a Mermaid flowchart, representing tasks and their
    inputs/outputs as dataset nodes. It uses a top-down layout.

    Parameters
    ----------
    pipeline : Pipeline or Iterable[TaskDef]
        The pipeline or collection of tasks to represent.
    file : str or file-like
        The output file or file-like object into which the Mermaid code is
        written.
    show_dimensions : bool, optional
        If True, display dimension information for tasks and datasets.
        Default is True.
    show_storage : bool, optional
        If True, display storage class information for datasets. Default is
        True.

    Raises
    ------
    OSError
        Raised if the output file cannot be opened.
    ImportError
        Raised if the task class cannot be imported.
    """
    universe = DimensionUniverse()

    # Ensure that pipeline is an iterable of task definitions.
    if isinstance(pipeline, Pipeline):
        pipeline = pipeline.to_graph()._iter_task_defs()

    # Open file if needed.
    close = False
    if not hasattr(file, "write"):
        file = open(file, "w")
        close = True

    # Begin the Mermaid code block with top-down layout.
    print("```mermaid", file=file)
    print("flowchart TD", file=file)

    # Define classes for tasks and datasets.
    print(
        "classDef task fill:#B1F2EF,color:#000,stroke:#000,stroke-width:3px,font-family:Monospace,font-size:14px,text-align:left;",
        file=file,
    )
    print(
        "classDef ds fill:#F5F5F5,color:#000,stroke:#00BABC,stroke-width:3px,font-family:Monospace,font-size:14px,text-align:left,rx:10,ry:10;",
        file=file,
    )

    # Track which datasets have been rendered to avoid duplicates.
    allDatasets: set[str | tuple[str, str]] = set()

    # Used for linking metadata datasets after tasks are processed.
    labelToTaskName: dict[str, str] = {}
    metadataNodesToLink: set[tuple[str, str]] = set()

    # We'll store edges as (from_node, to_node, is_prerequisite) tuples.
    edges: list[tuple[str, str, bool]] = []

    def get_task_id(idx: int) -> str:
        """Generate a safe Mermaid node ID for a task."""
        return f"TASK_{idx}"

    def get_dataset_id(name: str) -> str:
        """Generate a safe Mermaid node ID for a dataset."""
        # Replace non-alphanumerics with underscores.
        return "DATASET_" + re.sub(r"[^0-9A-Za-z_]", "_", name)

    metadata_pattern = re.compile(r"^(.*)_metadata$")

    # Sort tasks by label for consistent diagram ordering.
    pipeline_tasks = sorted(pipeline, key=lambda x: x.label)

    # Process each task and its connections.
    for idx, taskDef in enumerate(pipeline_tasks):
        task_id = get_task_id(idx)
        labelToTaskName[taskDef.label] = task_id

        # Render the task node.
        _render_task_node(task_id, taskDef, universe, file, show_dimensions)

        # Handle standard inputs (non-prerequisite).
        for attr in sorted(iterConnections(taskDef.connections, "inputs"), key=lambda x: x.name):
            ds_id = get_dataset_id(attr.name)
            if attr.name not in allDatasets:
                _render_dataset_node(ds_id, attr.name, attr, universe, file, show_dimensions, show_storage)
                allDatasets.add(attr.name)
            edges.append((ds_id, task_id, False))

            # Handle component datasets (composite -> component).
+ nodeName, component = DatasetType.splitDatasetTypeName(attr.name) + if component is not None and (nodeName, attr.name) not in allDatasets: + ds_id_parent = get_dataset_id(nodeName) + if nodeName not in allDatasets: + _render_dataset_node( + ds_id_parent, nodeName, attr, universe, file, show_dimensions, show_storage + ) + allDatasets.add(nodeName) + edges.append((ds_id_parent, ds_id, False)) + allDatasets.add((nodeName, attr.name)) + + # If this is a metadata dataset, record it for linking later. + if (match := metadata_pattern.match(attr.name)) is not None: + matchTaskLabel = match.group(1) + metadataNodesToLink.add((matchTaskLabel, attr.name)) + + # Handle prerequisite inputs (to be drawn with a dashed line). + for attr in sorted(iterConnections(taskDef.connections, "prerequisiteInputs"), key=lambda x: x.name): + ds_id = get_dataset_id(attr.name) + if attr.name not in allDatasets: + _render_dataset_node(ds_id, attr.name, attr, universe, file, show_dimensions, show_storage) + allDatasets.add(attr.name) + edges.append((ds_id, task_id, True)) + + # Handle outputs (task -> dataset). + for attr in sorted(iterConnections(taskDef.connections, "outputs"), key=lambda x: x.name): + ds_id = get_dataset_id(attr.name) + if attr.name not in allDatasets: + _render_dataset_node(ds_id, attr.name, attr, universe, file, show_dimensions, show_storage) + allDatasets.add(attr.name) + edges.append((task_id, ds_id, False)) + + # Link metadata datasets after all tasks processed. + for matchLabel, dsTypeName in metadataNodesToLink: + if (result := labelToTaskName.get(matchLabel)) is not None: + ds_id = get_dataset_id(dsTypeName) + edges.append((result, ds_id, False)) + + # Print all edges and track which are prerequisite. + prereq_indices = [] + for i, (f, t, p) in enumerate(edges): + print(f"{f} --> {t}", file=file) + if p: + prereq_indices.append(i) + + # Apply default edge style + print("linkStyle default stroke:#000,stroke-width:1.5px,font-family:Monospace,font-size:14px;", file=file) + + # Apply dashed style for all prerequisite edges in one line. + if prereq_indices: + prereq_str = ",".join(str(i) for i in prereq_indices) + print(f"linkStyle {prereq_str} stroke-dasharray:5;", file=file) + + # End code block + print("```", file=file) + + if close: + file.close() diff --git a/python/lsst/pipe/base/pipeline_graph/visualization/_mermaid.py b/python/lsst/pipe/base/pipeline_graph/visualization/_mermaid.py new file mode 100644 index 000000000..f42ab482a --- /dev/null +++ b/python/lsst/pipe/base/pipeline_graph/visualization/_mermaid.py @@ -0,0 +1,428 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (http://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +from __future__ import annotations + +__all__ = ("show_mermaid",) + +import html +import os +import sys +from collections.abc import Mapping +from typing import Any, TextIO + +from .._nodes import NodeType +from .._pipeline_graph import PipelineGraph +from ._formatting import NodeKey, format_dimensions, format_task_class +from ._options import NodeAttributeOptions +from ._show import parse_display_args + +# Configuration constants for label formatting and overflow handling. +_LABEL_PX_SIZE = 18 +_LABEL_MAX_LINES_SOFT = 10 +_LABEL_MAX_LINES_HARD = 15 +_OVERFLOW_MAX_LINES = 20 + + +def show_mermaid( + pipeline_graph: PipelineGraph, + stream: TextIO = sys.stdout, + **kwargs: Any, +) -> None: + """Write a Mermaid flowchart representation of the pipeline graph to a + stream. + + This function converts a given `PipelineGraph` into a Mermaid-based + flowchart. Nodes represent tasks (and possibly task-init nodes) and dataset + types, and edges represent connections between them. Dimensions and storage + classes can be included as additional metadata on nodes. Prerequisite edges + are rendered as dashed lines. + + Parameters + ---------- + pipeline_graph : `PipelineGraph` + The pipeline graph to visualize. + stream : `TextIO`, optional + The output stream where Mermaid code is written. Defaults to + `sys.stdout`. + **kwargs : Any + Additional arguments passed to `parse_display_args` to control aspects + such as displaying dimensions, storage classes, or full task class + names. + + Notes + ----- + - The diagram uses a top-down layout (`flowchart TD`). + - Three Mermaid classes are defined: + - `task` for normal tasks, + - `dsType` for dataset-type nodes, + - `taskInit` for task-init nodes. + - Edges that represent prerequisite relationships are rendered as dashed + lines using `linkStyle`. + - If a node's label is too long, overflow nodes are created to hold extra + lines. + """ + # Parse display arguments to determine what to show. + xgraph, options = parse_display_args(pipeline_graph, **kwargs) + + # Begin the Mermaid code block. + print("```mermaid", file=stream) + print("flowchart TD", file=stream) + + # Define Mermaid classes for node styling. + print( + f"classDef task fill:#B1F2EF,color:#000,stroke:#000,stroke-width:3px," + f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;", + file=stream, + ) + print( + f"classDef dsType fill:#F5F5F5,color:#000,stroke:#00BABC,stroke-width:3px," + f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left,rx:8,ry:8;", + file=stream, + ) + print( + f"classDef taskInit fill:#F4DEFA,color:#000,stroke:#000,stroke-width:3px," + f"font-family:Monospace,font-size:{_LABEL_PX_SIZE}px,text-align:left;", + file=stream, + ) + + # `overflow_ref` tracks the reference numbers for overflow nodes. + overflow_ref = 1 + overflow_ids = [] + + # Render nodes. + for node_key, node_data in xgraph.nodes.items(): + if node_key.node_type in (NodeType.TASK, NodeType.TASK_INIT): + # Render a task or task-init node. + _render_task_node(node_key, node_data, options, stream) + elif node_key.node_type == NodeType.DATASET_TYPE: + # Render a dataset-type node with possible overflow handling. 
            overflow_ref, node_overflow_ids = _render_dataset_type_node(
                node_key, node_data, options, stream, overflow_ref
            )
            if node_overflow_ids:
                overflow_ids.extend(node_overflow_ids)
        else:
            raise AssertionError(f"Unexpected node type: {node_key.node_type}")

    # Collect edges for printing and track which ones are prerequisite
    # so we can apply dashed styling after printing them.
    edges: list[tuple[str, str, bool]] = []
    for from_node, to_node, *_rest in xgraph.edges:
        is_prereq = xgraph.nodes[from_node].get("is_prerequisite", False)
        edges.append((from_node.node_id, to_node.node_id, is_prereq))

    # Print all edges.
    for f, t, p in edges:
        _render_edge(f, t, p, stream)

    # After printing all edges, apply linkStyle to prerequisite edges to make
    # them dashed.

    # First, gather indices of prerequisite edges.
    prereq_indices = [str(i) for i, (_, _, p) in enumerate(edges) if p]

    # Then apply dashed styling to all prerequisite edges in one line.
    if prereq_indices:
        print(f"linkStyle {','.join(prereq_indices)} stroke-dasharray:5;", file=stream)

    # End code block.
    print("```", file=stream)


def _render_task_node(
    node_key: NodeKey,
    node_data: Mapping[str, Any],
    options: NodeAttributeOptions,
    stream: TextIO,
) -> None:
    """Render a Mermaid node for a task or task-init node.

    Parameters
    ----------
    node_key : NodeKey
        Identifies the node. The node type determines styling and whether
        dimensions apply.
    node_data : Mapping[str, Any]
        Node attributes, possibly including 'task_class_name' and
        'dimensions'.
    options : NodeAttributeOptions
        Rendering options controlling whether to show dimensions, storage
        classes, etc.
    stream : TextIO
        The output stream for Mermaid syntax.
    """
    # Convert node_key into a label, handling line splitting and prefix
    # extraction.
    lines, _, _ = _format_label(str(node_key))

    # If requested, show the fully qualified task class name beneath the task
    # label.
    if options.task_classes and node_key.node_type in (NodeType.TASK, NodeType.TASK_INIT):
        lines.append(html.escape(format_task_class(options, node_data["task_class_name"])))

    # Show dimensions if requested and if this is not a task-init node.
    if options.dimensions and node_key.node_type != NodeType.TASK_INIT:
        dims_str = html.escape(format_dimensions(options, node_data["dimensions"]))
        lines.append(f"dimensions: {dims_str}")

    # Join lines with <br> for a multi-line label.
    label = "<br>".join(lines)

    # Print Mermaid node.
    node_id = node_key.node_id
    print(f'{node_id}["{label}"]', file=stream)

    # Assign class based on node type.
    if node_key.node_type == NodeType.TASK:
        print(f"class {node_id} task;", file=stream)
    else:
        # For NodeType.TASK_INIT.
        print(f"class {node_id} taskInit;", file=stream)


def _render_dataset_type_node(
    node_key: NodeKey,
    node_data: Mapping[str, Any],
    options: NodeAttributeOptions,
    stream: TextIO,
    overflow_ref: int,
) -> tuple[int, list[str]]:
    """Render a Mermaid node for a dataset-type node, handling overflow lines
    if needed.

    Dataset-type nodes can have many lines of label text. If the label exceeds
    a certain threshold, we create separate "overflow" nodes.

    Parameters
    ----------
    node_key : NodeKey
        Identifies this dataset-type node.
    node_data : Mapping[str, Any]
        Node attributes, possibly including dimensions and storage class.
    options : NodeAttributeOptions
        Rendering options controlling whether to show dimensions and storage
        classes.
    stream : TextIO
        The output stream for Mermaid syntax.
    overflow_ref : int
        The current reference number for overflow nodes. If overflow occurs,
        this is incremented.

    Returns
    -------
    overflow_ref : int
        Possibly incremented overflow reference number.
    overflow_ids : list[str]
        IDs of overflow nodes created, if any.
    """
    # Format the node label, respecting soft/hard line limits.
    labels, label_extras, _ = _format_label(str(node_key), _LABEL_MAX_LINES_SOFT)

    overflow_ids: list[str] = []
    total_lines = len(labels) + len(label_extras)
    if total_lines > _LABEL_MAX_LINES_HARD:
        # Too many lines; we must handle overflow by splitting extras.
        allowed_extras = max(_LABEL_MAX_LINES_HARD - len(labels), 0)
        extras_for_overflow = label_extras[allowed_extras:]
        label_extras = label_extras[:allowed_extras]

        if extras_for_overflow:
            # Introduce an overflow anchor.
            overflow_anchor = f"[{overflow_ref}]"
            labels.append(f"...more details in {overflow_anchor}")

            # Create overflow nodes in chunks.
            for i in range(0, len(extras_for_overflow), _OVERFLOW_MAX_LINES):
                overflow_id = f"{node_key.node_id}_overflow_{overflow_ref}_{i}"
                chunk = extras_for_overflow[i : i + _OVERFLOW_MAX_LINES]
                chunk.insert(0, f"<b>{html.escape(overflow_anchor)}</b>")
                _render_simple_node(overflow_id, chunk, "dsType", stream)
                overflow_ids.append(overflow_id)

            overflow_ref += 1

    # Combine final lines after overflow handling.
    final_lines = labels + label_extras

    # Append dimensions if requested.
    if options.dimensions:
        dims_str = html.escape(format_dimensions(options, node_data["dimensions"]))
        final_lines.append(f"dimensions: {dims_str}")

    # Append storage class if requested.
    if options.storage_classes:
        final_lines.append(f"storage class: {html.escape(node_data['storage_class_name'])}")

    # Render the main dataset-type node.
    _render_simple_node(node_key.node_id, final_lines, "dsType", stream)

    return overflow_ref, overflow_ids


def _render_simple_node(node_id: str, lines: list[str], node_class: str, stream: TextIO) -> None:
    """Render a simple Mermaid node with given lines and a class.

    This helper function is used for both primary nodes and overflow nodes
    once the split has been decided.

    Parameters
    ----------
    node_id : str
        Mermaid node ID.
    lines : list[str]
        Lines of HTML-formatted text to display in the node.
    node_class : str
        Mermaid class name to style the node (e.g., 'dsType', 'task',
        'taskInit').
    stream : TextIO
        The output stream.
    """
    label = "<br>".join(lines)
    print(f'{node_id}["{label}"]', file=stream)
    print(f"class {node_id} {node_class};", file=stream)


def _render_edge(from_node_id: str, to_node_id: str, is_prerequisite: bool, stream: TextIO) -> None:
    """Render a Mermaid edge from one node to another.

    Edges in Mermaid are normally specified as `A --> B`. Prerequisite edges
    will later be styled as dashed lines using linkStyle after all edges have
    been printed.

    Parameters
    ----------
    from_node_id : str
        The ID of the 'from' node in the edge.
    to_node_id : str
        The ID of the 'to' node in the edge.
    is_prerequisite : bool
        If True, this edge represents a prerequisite connection and will be
        styled as dashed.
    stream : TextIO
        The output stream for Mermaid syntax.
    """
    # At this stage, we simply print the edge. The styling (dashed) for
    # prerequisite edges is applied afterwards via linkStyle lines.
    print(f"{from_node_id} --> {to_node_id}", file=stream)


def _format_label(
    label: str,
    max_lines: int = 10,
    min_common_prefix_len: int = 1000,
) -> tuple[list[str], list[str], str]:
    """Parse and format a label into multiple lines with optional overflow
    handling.

    This function attempts to cleanly format long labels by:

    - Splitting the label by ", ".
    - Identifying a common prefix to factor out if sufficiently long.
    - Limiting the number of lines to 'max_lines', storing extras for
      potential overflow.

    Parameters
    ----------
    label : str
        The raw label text, often derived from a NodeKey.
    max_lines : int, optional
        Maximum lines before overflow is triggered.
    min_common_prefix_len : int, optional
        Minimum length for considering a common prefix extraction.

    Returns
    -------
    labels : list[str]
        Main label lines as HTML-formatted text.
    label_extras : list[str]
        Overflow lines if the label is too long.
    common_prefix : str
        Extracted common prefix, if any.
    """
    parsed_labels, parsed_label_extras, common_prefix = _parse_label(label, max_lines, min_common_prefix_len)

    # If there's a common prefix, present it bolded.
    if common_prefix:
        common_prefix = f"<b>{html.escape(common_prefix)}:</b>"

    indent = "&nbsp;&nbsp;" if common_prefix else ""
    labels = [f"{indent}{html.escape(el)}" for el in parsed_labels]
    label_extras = [f"{indent}{html.escape(el)}" for el in parsed_label_extras]

    if common_prefix:
        labels.insert(0, common_prefix)

    return labels, label_extras, common_prefix or ""


def _parse_label(
    label: str,
    max_lines: int,
    min_common_prefix_len: int,
) -> tuple[list[str], list[str], str]:
    """Split and process label text for overflow and common prefix extraction.

    Parameters
    ----------
    label : str
        The raw label text.
    max_lines : int
        Maximum number of lines before overflow.
    min_common_prefix_len : int
        Minimum length for a common prefix to be considered.

    Returns
    -------
    labels : list[str]
        The primary label lines.
    label_extras : list[str]
        Any overflow lines that exceed max_lines.
    common_prefix : str
        The extracted common prefix, if applicable.
    """
    labels = label.split(", ")
    common_prefix = os.path.commonprefix(labels)

    # If there's a long common prefix for multiple labels, factor it out at
    # the nearest underscore.
+ if len(labels) > 3 and len(common_prefix) > min_common_prefix_len: + final_underscore_index = common_prefix.rfind("_") + if final_underscore_index > 0: + common_prefix = common_prefix[: final_underscore_index + 1] + labels = [element[len(common_prefix) :] for element in labels] + else: + common_prefix = "" + else: + common_prefix = "" + + # Handle overflow if needed. + if (len(labels) + bool(common_prefix)) > max_lines: + label_extras = labels[max_lines - bool(common_prefix) :] + labels = labels[: max_lines - bool(common_prefix)] + else: + label_extras = [] + + return labels, label_extras, common_prefix From 4defcf6f13607d1b3fd48dd456e6505b80204e71 Mon Sep 17 00:00:00 2001 From: Erfan Nourbakhsh Date: Wed, 18 Dec 2024 07:39:51 -0500 Subject: [PATCH 2/2] Add unit test for generating Mermaid graphs --- tests/test_mermaid_tools.py | 188 ++++++++++++++++++++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 tests/test_mermaid_tools.py diff --git a/tests/test_mermaid_tools.py b/tests/test_mermaid_tools.py new file mode 100644 index 000000000..2e8428361 --- /dev/null +++ b/tests/test_mermaid_tools.py @@ -0,0 +1,188 @@ +# This file is part of pipe_base. +# +# Developed for the LSST Data Management System. +# This product includes software developed by the LSST Project +# (https://www.lsst.org). +# See the COPYRIGHT file at the top-level directory of this distribution +# for details of code ownership. +# +# This software is dual licensed under the GNU General Public License and also +# under a 3-clause BSD license. Recipients may choose which of these licenses +# to use; please see the files gpl-3.0.txt and/or bsd_license.txt, +# respectively. If you choose the GPL option then the following text applies +# (but note that there is still no warranty even if you opt for BSD instead): +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +"""Simple unit test for Pipeline visualization. +""" + +import io +import unittest + +import lsst.pipe.base.connectionTypes as cT +import lsst.utils.tests +from lsst.pipe.base import Pipeline, PipelineTask, PipelineTaskConfig, PipelineTaskConnections +from lsst.pipe.base.mermaid_tools import pipeline2mermaid + + +class ExamplePipelineTaskConnections(PipelineTaskConnections, dimensions=()): + """Connections class used for testing. + + Parameters + ---------- + config : `PipelineTaskConfig` + The config to use for this connections class. 
+ """ + + input1 = cT.Input( + name="", dimensions=["visit", "detector"], storageClass="example", doc="Input for this task" + ) + input2 = cT.Input( + name="", dimensions=["visit", "detector"], storageClass="example", doc="Input for this task" + ) + output1 = cT.Output( + name="", dimensions=["visit", "detector"], storageClass="example", doc="Output for this task" + ) + output2 = cT.Output( + name="", dimensions=["visit", "detector"], storageClass="example", doc="Output for this task" + ) + + def __init__(self, *, config=None): + super().__init__(config=config) + if not config.connections.input2: + self.inputs.remove("input2") + if not config.connections.output2: + self.outputs.remove("output2") + + +class ExamplePipelineTaskConfig(PipelineTaskConfig, pipelineConnections=ExamplePipelineTaskConnections): + """Example config used for testing.""" + + +def _makeConfig(inputName, outputName, pipeline, label): + """Add config overrides. + + Factory method for config instances. + + inputName and outputName can be either string or tuple of strings + with two items max. + """ + if isinstance(inputName, tuple): + pipeline.addConfigOverride(label, "connections.input1", inputName[0]) + pipeline.addConfigOverride(label, "connections.input2", inputName[1] if len(inputName) > 1 else "") + else: + pipeline.addConfigOverride(label, "connections.input1", inputName) + + if isinstance(outputName, tuple): + pipeline.addConfigOverride(label, "connections.output1", outputName[0]) + pipeline.addConfigOverride(label, "connections.output2", outputName[1] if len(outputName) > 1 else "") + else: + pipeline.addConfigOverride(label, "connections.output1", outputName) + + +class ExamplePipelineTask(PipelineTask): + """Example pipeline task used for testing.""" + + ConfigClass = ExamplePipelineTaskConfig + + +def _makePipeline(tasks): + """Generate Pipeline instance. + + Parameters + ---------- + tasks : list of tuples + Each tuple in the list has 3 or 4 items: + - input DatasetType name(s), string or tuple of strings + - output DatasetType name(s), string or tuple of strings + - task label, string + - optional task class object, can be None + + Returns + ------- + Pipeline instance + """ + pipe = Pipeline("test pipeline") + for task in tasks: + inputs = task[0] + outputs = task[1] + label = task[2] + klass = task[3] if len(task) > 3 else ExamplePipelineTask + pipe.addTask(klass, label) + _makeConfig(inputs, outputs, pipe, label) + return list(pipe.to_graph()._iter_task_defs()) + + +class MermaidToolsTestCase(unittest.TestCase): + """A test case for Mermaid tools.""" + + def test_pipeline2mermaid(self): + """Tests for mermaidt_tools.pipeline2mermaid method.""" + pipeline = _makePipeline( + [ + ("A", ("B", "C"), "task0"), + ("C", "E", "task1"), + ("B", "D", "task2"), + (("D", "E"), "F", "task3"), + ("D.C", "G", "task4"), + ("task3_metadata", "H", "task5"), + ] + ) + file = io.StringIO() + pipeline2mermaid(pipeline, file) + + # It's hard to validate complete output, just checking few basic + # things, even that is not terribly stable. + lines = file.getvalue().strip().split("\n") + nClassDefs = 2 + nTasks = 6 + nTaskClass = 6 + nDatasets = 10 + nDatasetClass = 10 + nEdges = 16 + nLinkStyle = 1 + nExtra = 3 # Opening, flowchart line, closing + + self.assertEqual( + len(lines), + nClassDefs + nTasks + nTaskClass + nDatasets + nDatasetClass + nEdges + nLinkStyle + nExtra, + ) + + # Make sure components are connected appropriately. 
+ self.assertIn("DATASET_D --> DATASET_D_C", file.getvalue()) + + # Make sure there is a connection created for metadata if someone tries + # to read it in. + self.assertIn("TASK_3 --> DATASET_task3_metadata", file.getvalue()) + + +class MyMemoryTestCase(lsst.utils.tests.MemoryTestCase): + """Generic file handle leak check.""" + + +def setup_module(module): + """Set up the module for pytest. + + Parameters + ---------- + module : `~types.ModuleType` + Module to set up. + """ + lsst.utils.tests.init() + + +if __name__ == "__main__": + lsst.utils.tests.init() + unittest.main()
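
Usage sketch for reviewers (not part of the patch itself): a minimal example of driving the new generator from Python, mirroring what the unit test above does with a `StringIO` buffer. The pipeline YAML path is a hypothetical placeholder.

```python
import io

from lsst.pipe.base import Pipeline
from lsst.pipe.base.mermaid_tools import pipeline2mermaid

# Load any pipeline definition; "my_pipeline.yaml" is a placeholder path.
pipeline = Pipeline.from_uri("my_pipeline.yaml")

# Render the pipeline as a fenced ```mermaid block; a file name or an open
# file object would work equally well as the second argument.
buffer = io.StringIO()
pipeline2mermaid(pipeline, buffer, show_dimensions=True, show_storage=True)

# The result is ready to paste into any Markdown renderer that supports
# Mermaid (e.g., a GitHub comment).
print(buffer.getvalue())
```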