diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..02b79f1a 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,3 @@ +- bump: patch + changes: + added: Visualization option when running tests diff --git a/policyengine_core/scripts/policyengine_command.py b/policyengine_core/scripts/policyengine_command.py index e18fa388..2319ce95 100644 --- a/policyengine_core/scripts/policyengine_command.py +++ b/policyengine_core/scripts/policyengine_command.py @@ -59,6 +59,12 @@ def build_test_parser(parser): default=False, help="increase output verbosity. If specified, output the entire calculation trace.", ) + parser.add_argument( + "--visualize", + action="store_true", + default=False, + help="output a relationship graph of the variables being tested", + ) parser.add_argument( "-a", "--aggregate", diff --git a/policyengine_core/scripts/run_test.py b/policyengine_core/scripts/run_test.py index 6b39ebca..5eb4c7c7 100644 --- a/policyengine_core/scripts/run_test.py +++ b/policyengine_core/scripts/run_test.py @@ -29,6 +29,7 @@ def main(parser): "name_filter": args.name_filter, "only_variables": args.only_variables, "ignore_variables": args.ignore_variables, + "visualize": args.visualize, } paths = [os.path.abspath(path) for path in args.path] diff --git a/policyengine_core/tools/test_runner.py b/policyengine_core/tools/test_runner.py index 5400cddf..102e7cd5 100644 --- a/policyengine_core/tools/test_runner.py +++ b/policyengine_core/tools/test_runner.py @@ -234,6 +234,7 @@ def apply(self): verbose = self.options.get("verbose") performance_graph = self.options.get("performance_graph") performance_tables = self.options.get("performance_tables") + visualize = self.options.get("visualize") try: builder.set_default_period(period) @@ -256,7 +257,7 @@ def apply(self): try: self.simulation.trace = ( - verbose or performance_graph or performance_tables + verbose or performance_graph or performance_tables or visualize ) self.check_output() finally: @@ -267,6 +268,8 @@ def apply(self): self.generate_performance_graph(tracer) if performance_tables: self.generate_performance_tables(tracer) + if visualize: + self.generate_variable_graph(tracer) def print_computation_log(self, tracer): print("Computation log:") # noqa T001 @@ -278,6 +281,24 @@ def generate_performance_graph(self, tracer): def generate_performance_tables(self, tracer): tracer.generate_performance_tables(".") + def generate_variable_graph(self, tracer): + tracer.generate_variable_graph( + self.test.get("name"), self._all_output_vars() + ) + + def _all_output_vars(self): + return self._get_leaf_keys(self.test["output"]) + + def _get_leaf_keys(self, dictionary: dict): + keys = [] + for key, value in dictionary.items(): + if type(value) is dict: + keys.extend(self._get_leaf_keys(value)) + else: + keys.append(key) + + return keys + def check_output(self): output = self.test.get("output") diff --git a/policyengine_core/tracers/__init__.py b/policyengine_core/tracers/__init__.py index 3fc28b4f..e6b07cbd 100644 --- a/policyengine_core/tracers/__init__.py +++ b/policyengine_core/tracers/__init__.py @@ -2,6 +2,7 @@ from .flat_trace import FlatTrace from .full_tracer import FullTracer from .performance_log import PerformanceLog +from .variable_graph import VariableGraph from .simple_tracer import SimpleTracer from .trace_node import TraceNode from .tracing_parameter_node_at_instant import TracingParameterNodeAtInstant diff --git a/policyengine_core/tracers/full_tracer.py b/policyengine_core/tracers/full_tracer.py index a4cb71bb..052d5c2b 100644 --- a/policyengine_core/tracers/full_tracer.py +++ b/policyengine_core/tracers/full_tracer.py @@ -123,6 +123,10 @@ def computation_log(self) -> tracers.ComputationLog: def performance_log(self) -> tracers.PerformanceLog: return tracers.PerformanceLog(self) + @property + def variable_graph(self) -> tracers.VariableGraph: + return tracers.VariableGraph(self) + @property def flat_trace(self) -> tracers.FlatTrace: return tracers.FlatTrace(self) @@ -139,6 +143,13 @@ def generate_performance_graph(self, dir_path: str) -> None: def generate_performance_tables(self, dir_path: str) -> None: self.performance_log.generate_performance_tables(dir_path) + def generate_variable_graph( + self, name: str, output_vars: list[str] + ) -> None: + self.variable_graph.visualize( + name, aggregate=False, max_depth=None, output_vars=output_vars + ) + def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int: tree_call = tree.name == variable children_calls = sum( diff --git a/policyengine_core/tracers/variable_graph.py b/policyengine_core/tracers/variable_graph.py new file mode 100644 index 00000000..677c65f8 --- /dev/null +++ b/policyengine_core/tracers/variable_graph.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import os +import sys +import typing +from typing import Optional, Union + +import numpy +from pyvis.node import Node + +from policyengine_core.enums import EnumArray + +from .. import tracers + +from pyvis.network import Network + +if typing.TYPE_CHECKING: + from numpy.typing import ArrayLike + + Array = Union[EnumArray, ArrayLike] + + +class VariableGraph: + _full_tracer: tracers.FullTracer + + NETWORK_OPTIONS = """ + const options = { + "physics": { + "repulsion": { + "theta": 1, + "centralGravity": 0, + "springLength": 255, + "springConstant": 0.06, + "damping": 1, + "avoidOverlap": 1 + }, + "minVelocity": 0.75, + "solver": "repulsion" + } + } + """ + + def __init__(self, full_tracer: tracers.FullTracer) -> None: + self._full_tracer = full_tracer + + def tree( + self, + aggregate: bool = False, + max_depth: Optional[int] = None, + output_vars: list[str] = [], + ) -> VisualizeNode: + depth = 1 + is_root = True + + node_by_tree = [ + self._get_node(node, depth, aggregate, max_depth, output_vars) + for node in self._full_tracer.trees + ] + + return node_by_tree + + def visualize( + self, + name: str, + output_vars: list[str] = [], + aggregate=False, + max_depth: Optional[int] = None, + dir="_variable_graphs", + ) -> None: + """ + Visualize the computation log of a simulation as a relationship graph in the web browser. + """ + + try: + os.mkdir(dir) + except FileExistsError: + pass + os.chdir(dir) + + try: + net = self._network() + + for root_node in self.tree(aggregate, max_depth, output_vars): + self._add_nodes_and_edges(net, root_node) + + file_name = f"{self._to_snake_case(name)}.html" + + # redirect stdout to prevent net.show from printing the file name + old_stdout = sys.stdout # backup current stdout + sys.stdout = open(os.devnull, "w") + + net.show(file_name, notebook=False) + + sys.stdout = old_stdout + finally: + os.chdir("..") + + def _network(self) -> Network: + net = Network( + height="100vh", + directed=True, + select_menu=True, + neighborhood_highlight=True, + ) + Network.set_options(net, self.NETWORK_OPTIONS) + + return net + + def _add_nodes_and_edges(self, net: Network, root_node: VisualizeNode): + stack = [root_node] + edges: set[tuple[str, str]] = set() + + while len(stack) > 0: + node = stack.pop() + id = node.name + + net_node = self._get_network_node(net, id) + if net_node is not None: + if node.value not in net_node["title"]: + net_node["title"] += "\n" + node.value + + continue + + net.add_node( + id, color=node.color(), title=node.value, label=node.name + ) + + for child in node.children: + edge = (id, child.name) + edges.add(edge) + stack.append(child) + + for parent, child in edges: + net.add_edge(child, parent) + + def _get_network_node(self, net: Network, id: str): + try: + return net.get_node(id) + except KeyError: + return None + + def _get_node( + self, + node: tracers.TraceNode, + depth: int, + aggregate: bool, + max_depth: Optional[int], + output_vars: list[str], + ) -> VisualizeNode: + if max_depth is not None and depth > max_depth: + return [] + + children = [ + self._get_node(child, depth + 1, aggregate, max_depth, output_vars) + for child in node.children + ] + + is_leaf = len(node.children) == 0 + visualization_node = VisualizeNode( + node, + children, + is_leaf=is_leaf, + aggregate=aggregate, + output_vars=output_vars, + ) + + return visualization_node + + def _to_snake_case(self, string: str): + return string.replace(" ", "_").lower() + + +class VisualizeNode: + DEFAULT_COLOR = "#BFD0DF" + LEAF_COLOR = "#0099FF" + ROOT_COLOR = "#7B61FF" + + def __init__( + self, + node: tracers.TraceNode, + children: list[VisualizeNode], + is_leaf=False, + aggregate=False, + output_vars: list[str] = [], + ): + self.node = node + self.name = node.name + self.children = children + self.is_leaf = is_leaf + self.is_root = self.name in output_vars + self.value = self._value(aggregate) + + def _display( + self, + value: Optional[Array], + ) -> str: + if isinstance(value, EnumArray): + value = value.decode_to_str() + + return numpy.array2string(value, max_line_width=float("inf")) + + def _value(self, aggregate: bool) -> str: + value = self.node.value + + if value is None: + formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + + elif aggregate: + try: + formatted_value = str( + { + "avg": numpy.mean(value), + "max": numpy.max(value), + "min": numpy.min(value), + } + ) + + except TypeError: + formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + + else: + formatted_value = self._display(value) + + return f"{self.node.period}: {formatted_value}" + + def color(self) -> str: + if self.is_root: + return self.ROOT_COLOR + if self.is_leaf: + return self.LEAF_COLOR + + return self.DEFAULT_COLOR diff --git a/setup.py b/setup.py index 2834e8a0..3dd47b05 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "pandas>=1", "plotly>=5.6.0,<6", "ipython>=7.17.0,<8", + "pyvis>=0.3.2", ] dev_requirements = [