From 002c365994a62d28d0be2f08cb6c768c6fd0e999 Mon Sep 17 00:00:00 2001 From: Caleb Date: Tue, 28 May 2024 13:26:29 -0600 Subject: [PATCH 1/9] add a visualize flag --- policyengine_core/scripts/policyengine_command.py | 6 ++++++ policyengine_core/scripts/run_test.py | 1 + 2 files changed, 7 insertions(+) diff --git a/policyengine_core/scripts/policyengine_command.py b/policyengine_core/scripts/policyengine_command.py index e18fa3888..2319ce951 100644 --- a/policyengine_core/scripts/policyengine_command.py +++ b/policyengine_core/scripts/policyengine_command.py @@ -59,6 +59,12 @@ def build_test_parser(parser): default=False, help="increase output verbosity. If specified, output the entire calculation trace.", ) + parser.add_argument( + "--visualize", + action="store_true", + default=False, + help="output a relationship graph of the variables being tested", + ) parser.add_argument( "-a", "--aggregate", diff --git a/policyengine_core/scripts/run_test.py b/policyengine_core/scripts/run_test.py index 6b39ebcac..5eb4c7c7e 100644 --- a/policyengine_core/scripts/run_test.py +++ b/policyengine_core/scripts/run_test.py @@ -29,6 +29,7 @@ def main(parser): "name_filter": args.name_filter, "only_variables": args.only_variables, "ignore_variables": args.ignore_variables, + "visualize": args.visualize, } paths = [os.path.abspath(path) for path in args.path] From 2438fe76c754afb805f8b4e38dcee6e6105f3b78 Mon Sep 17 00:00:00 2001 From: Caleb Date: Tue, 28 May 2024 14:18:31 -0600 Subject: [PATCH 2/9] create tree data structure for variable graph --- policyengine_core/tools/test_runner.py | 8 +- policyengine_core/tracers/__init__.py | 1 + policyengine_core/tracers/full_tracer.py | 15 ++- policyengine_core/tracers/variable_graph.py | 125 ++++++++++++++++++++ 4 files changed, 142 insertions(+), 7 deletions(-) create mode 100644 policyengine_core/tracers/variable_graph.py diff --git a/policyengine_core/tools/test_runner.py b/policyengine_core/tools/test_runner.py index 5400cddf7..0e9ddeb90 100644 --- a/policyengine_core/tools/test_runner.py +++ b/policyengine_core/tools/test_runner.py @@ -234,6 +234,7 @@ def apply(self): verbose = self.options.get("verbose") performance_graph = self.options.get("performance_graph") performance_tables = self.options.get("performance_tables") + visualize = self.options.get("visualize") try: builder.set_default_period(period) @@ -256,7 +257,7 @@ def apply(self): try: self.simulation.trace = ( - verbose or performance_graph or performance_tables + verbose or performance_graph or performance_tables or visualize ) self.check_output() finally: @@ -267,6 +268,8 @@ def apply(self): self.generate_performance_graph(tracer) if performance_tables: self.generate_performance_tables(tracer) + if visualize: + self.generate_variable_graph(tracer) def print_computation_log(self, tracer): print("Computation log:") # noqa T001 @@ -278,6 +281,9 @@ def generate_performance_graph(self, tracer): def generate_performance_tables(self, tracer): tracer.generate_performance_tables(".") + def generate_variable_graph(self, tracer): + tracer.generate_variable_graph(".") + def check_output(self): output = self.test.get("output") diff --git a/policyengine_core/tracers/__init__.py b/policyengine_core/tracers/__init__.py index 3fc28b4f4..e6b07cbd5 100644 --- a/policyengine_core/tracers/__init__.py +++ b/policyengine_core/tracers/__init__.py @@ -2,6 +2,7 @@ from .flat_trace import FlatTrace from .full_tracer import FullTracer from .performance_log import PerformanceLog +from .variable_graph import VariableGraph from .simple_tracer import SimpleTracer from .trace_node import TraceNode from .tracing_parameter_node_at_instant import TracingParameterNodeAtInstant diff --git a/policyengine_core/tracers/full_tracer.py b/policyengine_core/tracers/full_tracer.py index a4cb71bbd..2f1d7b416 100644 --- a/policyengine_core/tracers/full_tracer.py +++ b/policyengine_core/tracers/full_tracer.py @@ -30,9 +30,7 @@ def record_calculation_start( period: str, branch_name: str = "default", ) -> None: - self._simple_tracer.record_calculation_start( - variable, period, branch_name - ) + self._simple_tracer.record_calculation_start(variable, period, branch_name) self._enter_calculation(variable, period, branch_name) self._record_start_time() @@ -123,6 +121,10 @@ def computation_log(self) -> tracers.ComputationLog: def performance_log(self) -> tracers.PerformanceLog: return tracers.PerformanceLog(self) + @property + def variable_graph(self) -> tracers.VariableGraph: + return tracers.VariableGraph(self) + @property def flat_trace(self) -> tracers.FlatTrace: return tracers.FlatTrace(self) @@ -139,6 +141,9 @@ def generate_performance_graph(self, dir_path: str) -> None: def generate_performance_tables(self, dir_path: str) -> None: self.performance_log.generate_performance_tables(dir_path) + def generate_variable_graph(self, dir_path: str) -> None: + self.variable_graph.visualize(False, max_depth=None) + def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int: tree_call = tree.name == variable children_calls = sum( @@ -148,9 +153,7 @@ def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int: return tree_call + children_calls def get_nb_requests(self, variable: str) -> int: - return sum( - self._get_nb_requests(tree, variable) for tree in self.trees - ) + return sum(self._get_nb_requests(tree, variable) for tree in self.trees) def get_flat_trace(self) -> dict: return self.flat_trace.get_trace() diff --git a/policyengine_core/tracers/variable_graph.py b/policyengine_core/tracers/variable_graph.py new file mode 100644 index 000000000..e1e93e551 --- /dev/null +++ b/policyengine_core/tracers/variable_graph.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import typing +from typing import Optional, Union + +import numpy + +from policyengine_core.enums import EnumArray + +from .. import tracers + +if typing.TYPE_CHECKING: + from numpy.typing import ArrayLike + + Array = Union[EnumArray, ArrayLike] + + +class VariableGraph: + _full_tracer: tracers.FullTracer + + def __init__(self, full_tracer: tracers.FullTracer) -> None: + self._full_tracer = full_tracer + + def tree( + self, + aggregate: bool = False, + max_depth: Optional[int] = None, + ) -> VisualizeNode: + depth = 1 + + node_by_tree = [ + self._get_node(node, depth, aggregate, max_depth) + for node in self._full_tracer.trees + ] + + return node_by_tree + + def visualize(self, aggregate=False, max_depth: Optional[int] = None) -> None: + """ + Visualize the computation log of a simulation as a relationship graph in the web browser. + + If ``aggregate`` is ``False`` (default), visualize the value of each + computed vector. + + If ``aggregate`` is ``True``, only the minimum, maximum, and + average value will be used of each computed vector. + + This mode is more suited for simulations on a large population. + + If ``max_depth`` is ``None`` (default), visualize the entire computation. + + If ``max_depth`` is set, for example to ``3``, only visualize computed + vectors up to a depth of ``max_depth``. + """ + for tree in self.tree(aggregate, max_depth): + print(tree.value) + + def _get_node( + self, + node: tracers.TraceNode, + depth: int, + aggregate: bool, + max_depth: Optional[int], + ) -> VisualizeNode: + if max_depth is not None and depth > max_depth: + return [] + + children = [ + self._get_node(child, depth + 1, aggregate, max_depth) + for child in node.children + ] + + is_leaf = len(node.children) == 0 + visualization_node = VisualizeNode( + node, children, is_leaf=is_leaf, aggregate=aggregate + ) + + return visualization_node + + +class VisualizeNode: + def __init__( + self, + node: tracers.TraceNode, + children: list[VisualizeNode], + is_leaf=False, + aggregate=False, + ): + self.node = node + self.children = children + self.is_leaf = is_leaf + self.value = self._value(aggregate) + + def _display( + self, + value: Optional[Array], + ) -> str: + if isinstance(value, EnumArray): + value = value.decode_to_str() + + return numpy.array2string(value, max_line_width=float("inf")) + + def _value(self, aggregate: bool) -> str: + value = self.node.value + + if value is None: + formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + + elif aggregate: + try: + formatted_value = str( + { + "avg": numpy.mean(value), + "max": numpy.max(value), + "min": numpy.min(value), + } + ) + + except TypeError: + formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}" + + else: + formatted_value = self._display(value) + + return f"{self.node.name}<{self.node.period}, ({self.node.branch_name})> = {formatted_value}" From fa8225b46ea9fd1a1ab48b9bd22d1dbda4c28b9b Mon Sep 17 00:00:00 2001 From: Caleb Date: Tue, 28 May 2024 15:26:27 -0600 Subject: [PATCH 3/9] graph the variable trees --- policyengine_core/tracers/variable_graph.py | 86 +++++++++++++++++++-- setup.py | 1 + 2 files changed, 81 insertions(+), 6 deletions(-) diff --git a/policyengine_core/tracers/variable_graph.py b/policyengine_core/tracers/variable_graph.py index e1e93e551..a0ace9999 100644 --- a/policyengine_core/tracers/variable_graph.py +++ b/policyengine_core/tracers/variable_graph.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os +import webbrowser import typing from typing import Optional, Union @@ -9,6 +11,8 @@ from .. import tracers +from pyvis.network import Network + if typing.TYPE_CHECKING: from numpy.typing import ArrayLike @@ -18,6 +22,23 @@ class VariableGraph: _full_tracer: tracers.FullTracer + NETWORK_OPTIONS = """ + const options = { + "physics": { + "repulsion": { + "theta": 1, + "centralGravity": 0, + "springLength": 255, + "springConstant": 0.06, + "damping": 1, + "avoidOverlap": 1 + }, + "minVelocity": 0.75, + "solver": "repulsion" + } + } + """ + def __init__(self, full_tracer: tracers.FullTracer) -> None: self._full_tracer = full_tracer @@ -27,15 +48,21 @@ def tree( max_depth: Optional[int] = None, ) -> VisualizeNode: depth = 1 + is_root = True node_by_tree = [ - self._get_node(node, depth, aggregate, max_depth) + self._get_node(node, depth, aggregate, max_depth, is_root) for node in self._full_tracer.trees ] return node_by_tree - def visualize(self, aggregate=False, max_depth: Optional[int] = None) -> None: + def visualize( + self, + aggregate=False, + max_depth: Optional[int] = None, + dir="_variable_graphs", + ) -> None: """ Visualize the computation log of a simulation as a relationship graph in the web browser. @@ -52,8 +79,39 @@ def visualize(self, aggregate=False, max_depth: Optional[int] = None) -> None: If ``max_depth`` is set, for example to ``3``, only visualize computed vectors up to a depth of ``max_depth``. """ - for tree in self.tree(aggregate, max_depth): - print(tree.value) + + for root_node in self.tree(aggregate, max_depth): + net = self._network() + self._add_nodes_and_edges(net, root_node) + + file_name = "nx.html" + + net.show(file_name, notebook=False) + + def _network(self) -> Network: + net = Network( + height="100vh", directed=True, select_menu=True, neighborhood_highlight=True + ) + Network.set_options(net, self.NETWORK_OPTIONS) + + return net + + def _add_nodes_and_edges(self, net: Network, root_node: VisualizeNode): + stack = [root_node] + edges: set[tuple[str, str]] = set() + + while len(stack) > 0: + node = stack.pop() + + net.add_node(node.name, color=node.color(), title=node.value) + + for child in node.children: + edge = (node.name, child.name) + edges.add(edge) + stack.append(child) + + for parent, child in edges: + net.add_edge(child, parent) def _get_node( self, @@ -61,34 +119,42 @@ def _get_node( depth: int, aggregate: bool, max_depth: Optional[int], + is_root: bool, ) -> VisualizeNode: if max_depth is not None and depth > max_depth: return [] children = [ - self._get_node(child, depth + 1, aggregate, max_depth) + self._get_node(child, depth + 1, aggregate, max_depth, False) for child in node.children ] is_leaf = len(node.children) == 0 visualization_node = VisualizeNode( - node, children, is_leaf=is_leaf, aggregate=aggregate + node, children, is_leaf=is_leaf, aggregate=aggregate, is_root=is_root ) return visualization_node class VisualizeNode: + DEFAULT_COLOR = "#BFD0DF" + LEAF_COLOR = "#0099FF" + ROOT_COLOR = "#7B61FF" + def __init__( self, node: tracers.TraceNode, children: list[VisualizeNode], is_leaf=False, aggregate=False, + is_root=False, ): self.node = node + self.name = node.name self.children = children self.is_leaf = is_leaf + self.is_root = is_root self.value = self._value(aggregate) def _display( @@ -123,3 +189,11 @@ def _value(self, aggregate: bool) -> str: formatted_value = self._display(value) return f"{self.node.name}<{self.node.period}, ({self.node.branch_name})> = {formatted_value}" + + def color(self) -> str: + if self.is_root: + return self.ROOT_COLOR + if self.is_leaf: + return self.LEAF_COLOR + + return self.DEFAULT_COLOR diff --git a/setup.py b/setup.py index 2834e8a02..3dd47b05c 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ "pandas>=1", "plotly>=5.6.0,<6", "ipython>=7.17.0,<8", + "pyvis>=0.3.2", ] dev_requirements = [ From 7fccc2d278662011d53121c754e9e539906667ee Mon Sep 17 00:00:00 2001 From: Caleb Date: Tue, 28 May 2024 16:11:48 -0600 Subject: [PATCH 4/9] make output variables a different color --- policyengine_core/tools/test_runner.py | 64 +++++++++------------ policyengine_core/tracers/full_tracer.py | 6 +- policyengine_core/tracers/variable_graph.py | 37 ++++++++---- 3 files changed, 59 insertions(+), 48 deletions(-) diff --git a/policyengine_core/tools/test_runner.py b/policyengine_core/tools/test_runner.py index 0e9ddeb90..406a67170 100644 --- a/policyengine_core/tools/test_runner.py +++ b/policyengine_core/tools/test_runner.py @@ -156,9 +156,7 @@ class YamlItem(pytest.Item): Terminal nodes of the test collection tree. """ - def __init__( - self, *, baseline_tax_benefit_system, test, options, **kwargs - ): + def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs): super(YamlItem, self).__init__(**kwargs) self.baseline_tax_benefit_system = baseline_tax_benefit_system self.options = options @@ -238,9 +236,7 @@ def apply(self): try: builder.set_default_period(period) - self.simulation = builder.build_from_dict( - self.tax_benefit_system, input - ) + self.simulation = builder.build_from_dict(self.tax_benefit_system, input) except (VariableNotFoundError, SituationParsingError): raise except Exception as e: @@ -282,7 +278,20 @@ def generate_performance_tables(self, tracer): tracer.generate_performance_tables(".") def generate_variable_graph(self, tracer): - tracer.generate_variable_graph(".") + tracer.generate_variable_graph(self.test.get("name"), self._all_output_vars()) + + def _all_output_vars(self): + return self._get_leaf_keys(self.test["output"]) + + def _get_leaf_keys(self, dictionary: dict): + keys = [] + for key, value in dictionary.items(): + if type(value) is dict: + keys.extend(self._get_leaf_keys(value)) + else: + keys.append(key) + + return keys def check_output(self): output = self.test.get("output") @@ -290,19 +299,11 @@ def check_output(self): if output is None: return for key, expected_value in output.items(): - if self.tax_benefit_system.get_variable( - key - ): # If key is a variable - self.check_variable( - key, expected_value, self.test.get("period") - ) - elif self.simulation.populations.get( - key - ): # If key is an entity singular + if self.tax_benefit_system.get_variable(key): # If key is a variable + self.check_variable(key, expected_value, self.test.get("period")) + elif self.simulation.populations.get(key): # If key is an entity singular for variable_name, value in expected_value.items(): - self.check_variable( - variable_name, value, self.test.get("period") - ) + self.check_variable(variable_name, value, self.test.get("period")) else: population = self.simulation.get_population(plural=key) if population is not None: # If key is an entity plural @@ -318,9 +319,7 @@ def check_output(self): else: raise VariableNotFoundError(key, self.tax_benefit_system) - def check_variable( - self, variable_name, expected_value, period, entity_index=None - ): + def check_variable(self, variable_name, expected_value, period, entity_index=None): if self.should_ignore_variable(variable_name): return if isinstance(expected_value, dict): @@ -414,12 +413,7 @@ def _get_tax_benefit_system( key = hash( ( id(baseline), - ":".join( - [ - reform if isinstance(reform, str) else "" - for reform in reforms - ] - ), + ":".join([reform if isinstance(reform, str) else "" for reform in reforms]), reform_key, frozenset(extensions), ) @@ -431,13 +425,11 @@ def _get_tax_benefit_system( for reform_path in reforms: if isinstance(reform_path, str): - current_tax_benefit_system = ( - current_tax_benefit_system.apply_reform(reform_path) + current_tax_benefit_system = current_tax_benefit_system.apply_reform( + reform_path ) else: - current_tax_benefit_system = reform_path( - current_tax_benefit_system - ) + current_tax_benefit_system = reform_path(current_tax_benefit_system) current_tax_benefit_system._parameters_at_instant_cache = {} for extension in extensions: @@ -487,9 +479,9 @@ def assert_near( value = np.array(value).astype(np.float32) except ValueError: # Data type not translatable to floating point, assert complete equality - assert np.array(value) == np.array( - target_value - ), "{}{} differs from {}".format(message, value, target_value) + assert np.array(value) == np.array(target_value), "{}{} differs from {}".format( + message, value, target_value + ) return diff = abs(target_value - value) diff --git a/policyengine_core/tracers/full_tracer.py b/policyengine_core/tracers/full_tracer.py index 2f1d7b416..3cab802ec 100644 --- a/policyengine_core/tracers/full_tracer.py +++ b/policyengine_core/tracers/full_tracer.py @@ -141,8 +141,10 @@ def generate_performance_graph(self, dir_path: str) -> None: def generate_performance_tables(self, dir_path: str) -> None: self.performance_log.generate_performance_tables(dir_path) - def generate_variable_graph(self, dir_path: str) -> None: - self.variable_graph.visualize(False, max_depth=None) + def generate_variable_graph(self, name: str, output_vars: list[str]) -> None: + self.variable_graph.visualize( + name, aggregate=False, max_depth=None, output_vars=output_vars + ) def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int: tree_call = tree.name == variable diff --git a/policyengine_core/tracers/variable_graph.py b/policyengine_core/tracers/variable_graph.py index a0ace9999..0ea82edef 100644 --- a/policyengine_core/tracers/variable_graph.py +++ b/policyengine_core/tracers/variable_graph.py @@ -1,7 +1,7 @@ from __future__ import annotations import os -import webbrowser +import sys import typing from typing import Optional, Union @@ -46,12 +46,13 @@ def tree( self, aggregate: bool = False, max_depth: Optional[int] = None, + output_vars: list[str] = [], ) -> VisualizeNode: depth = 1 is_root = True node_by_tree = [ - self._get_node(node, depth, aggregate, max_depth, is_root) + self._get_node(node, depth, aggregate, max_depth, output_vars) for node in self._full_tracer.trees ] @@ -59,9 +60,10 @@ def tree( def visualize( self, + name: str, + output_vars: list[str] = [], aggregate=False, max_depth: Optional[int] = None, - dir="_variable_graphs", ) -> None: """ Visualize the computation log of a simulation as a relationship graph in the web browser. @@ -80,14 +82,22 @@ def visualize( vectors up to a depth of ``max_depth``. """ - for root_node in self.tree(aggregate, max_depth): + i = 0 + for root_node in self.tree(aggregate, max_depth, output_vars): net = self._network() self._add_nodes_and_edges(net, root_node) - file_name = "nx.html" + file_name = f"{self._to_snake_case(name)}.{i}.html" + i += 1 + + # redirect stdout to prevent net.show from printing the file name + old_stdout = sys.stdout # backup current stdout + sys.stdout = open(os.devnull, "w") net.show(file_name, notebook=False) + sys.stdout = old_stdout + def _network(self) -> Network: net = Network( height="100vh", directed=True, select_menu=True, neighborhood_highlight=True @@ -119,23 +129,30 @@ def _get_node( depth: int, aggregate: bool, max_depth: Optional[int], - is_root: bool, + output_vars: list[str], ) -> VisualizeNode: if max_depth is not None and depth > max_depth: return [] children = [ - self._get_node(child, depth + 1, aggregate, max_depth, False) + self._get_node(child, depth + 1, aggregate, max_depth, output_vars) for child in node.children ] is_leaf = len(node.children) == 0 visualization_node = VisualizeNode( - node, children, is_leaf=is_leaf, aggregate=aggregate, is_root=is_root + node, + children, + is_leaf=is_leaf, + aggregate=aggregate, + output_vars=output_vars, ) return visualization_node + def _to_snake_case(self, string: str): + return string.replace(" ", "_").lower() + class VisualizeNode: DEFAULT_COLOR = "#BFD0DF" @@ -148,13 +165,13 @@ def __init__( children: list[VisualizeNode], is_leaf=False, aggregate=False, - is_root=False, + output_vars: list[str] = [], ): self.node = node self.name = node.name self.children = children self.is_leaf = is_leaf - self.is_root = is_root + self.is_root = self.name in output_vars self.value = self._value(aggregate) def _display( From 207a57c5352e20eea917b5f211e77960cc73f6b8 Mon Sep 17 00:00:00 2001 From: Caleb Date: Tue, 28 May 2024 16:17:56 -0600 Subject: [PATCH 5/9] put all graphs in a file --- policyengine_core/tracers/variable_graph.py | 31 ++++++++++++++------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/policyengine_core/tracers/variable_graph.py b/policyengine_core/tracers/variable_graph.py index 0ea82edef..07dd99620 100644 --- a/policyengine_core/tracers/variable_graph.py +++ b/policyengine_core/tracers/variable_graph.py @@ -64,6 +64,7 @@ def visualize( output_vars: list[str] = [], aggregate=False, max_depth: Optional[int] = None, + dir="_variable_graphs", ) -> None: """ Visualize the computation log of a simulation as a relationship graph in the web browser. @@ -81,22 +82,32 @@ def visualize( If ``max_depth`` is set, for example to ``3``, only visualize computed vectors up to a depth of ``max_depth``. """ + + try: + os.mkdir(dir) + except FileExistsError: + pass + os.chdir(dir) i = 0 - for root_node in self.tree(aggregate, max_depth, output_vars): - net = self._network() - self._add_nodes_and_edges(net, root_node) - file_name = f"{self._to_snake_case(name)}.{i}.html" - i += 1 + try: + for root_node in self.tree(aggregate, max_depth, output_vars): + net = self._network() + self._add_nodes_and_edges(net, root_node) - # redirect stdout to prevent net.show from printing the file name - old_stdout = sys.stdout # backup current stdout - sys.stdout = open(os.devnull, "w") + file_name = f"{self._to_snake_case(name)}.{i}.html" + i += 1 - net.show(file_name, notebook=False) + # redirect stdout to prevent net.show from printing the file name + old_stdout = sys.stdout # backup current stdout + sys.stdout = open(os.devnull, "w") - sys.stdout = old_stdout + net.show(file_name, notebook=False) + + sys.stdout = old_stdout + finally: + os.chdir('..') def _network(self) -> Network: net = Network( From c2758ccb7a31cae63b8d03ee6e41e4eecb418abf Mon Sep 17 00:00:00 2001 From: Caleb Date: Tue, 28 May 2024 16:20:32 -0600 Subject: [PATCH 6/9] format --- policyengine_core/tools/test_runner.py | 53 +++++++++++++++------ policyengine_core/tracers/full_tracer.py | 12 +++-- policyengine_core/tracers/variable_graph.py | 13 +++-- 3 files changed, 55 insertions(+), 23 deletions(-) diff --git a/policyengine_core/tools/test_runner.py b/policyengine_core/tools/test_runner.py index 406a67170..102e7cd50 100644 --- a/policyengine_core/tools/test_runner.py +++ b/policyengine_core/tools/test_runner.py @@ -156,7 +156,9 @@ class YamlItem(pytest.Item): Terminal nodes of the test collection tree. """ - def __init__(self, *, baseline_tax_benefit_system, test, options, **kwargs): + def __init__( + self, *, baseline_tax_benefit_system, test, options, **kwargs + ): super(YamlItem, self).__init__(**kwargs) self.baseline_tax_benefit_system = baseline_tax_benefit_system self.options = options @@ -236,7 +238,9 @@ def apply(self): try: builder.set_default_period(period) - self.simulation = builder.build_from_dict(self.tax_benefit_system, input) + self.simulation = builder.build_from_dict( + self.tax_benefit_system, input + ) except (VariableNotFoundError, SituationParsingError): raise except Exception as e: @@ -278,7 +282,9 @@ def generate_performance_tables(self, tracer): tracer.generate_performance_tables(".") def generate_variable_graph(self, tracer): - tracer.generate_variable_graph(self.test.get("name"), self._all_output_vars()) + tracer.generate_variable_graph( + self.test.get("name"), self._all_output_vars() + ) def _all_output_vars(self): return self._get_leaf_keys(self.test["output"]) @@ -299,11 +305,19 @@ def check_output(self): if output is None: return for key, expected_value in output.items(): - if self.tax_benefit_system.get_variable(key): # If key is a variable - self.check_variable(key, expected_value, self.test.get("period")) - elif self.simulation.populations.get(key): # If key is an entity singular + if self.tax_benefit_system.get_variable( + key + ): # If key is a variable + self.check_variable( + key, expected_value, self.test.get("period") + ) + elif self.simulation.populations.get( + key + ): # If key is an entity singular for variable_name, value in expected_value.items(): - self.check_variable(variable_name, value, self.test.get("period")) + self.check_variable( + variable_name, value, self.test.get("period") + ) else: population = self.simulation.get_population(plural=key) if population is not None: # If key is an entity plural @@ -319,7 +333,9 @@ def check_output(self): else: raise VariableNotFoundError(key, self.tax_benefit_system) - def check_variable(self, variable_name, expected_value, period, entity_index=None): + def check_variable( + self, variable_name, expected_value, period, entity_index=None + ): if self.should_ignore_variable(variable_name): return if isinstance(expected_value, dict): @@ -413,7 +429,12 @@ def _get_tax_benefit_system( key = hash( ( id(baseline), - ":".join([reform if isinstance(reform, str) else "" for reform in reforms]), + ":".join( + [ + reform if isinstance(reform, str) else "" + for reform in reforms + ] + ), reform_key, frozenset(extensions), ) @@ -425,11 +446,13 @@ def _get_tax_benefit_system( for reform_path in reforms: if isinstance(reform_path, str): - current_tax_benefit_system = current_tax_benefit_system.apply_reform( - reform_path + current_tax_benefit_system = ( + current_tax_benefit_system.apply_reform(reform_path) ) else: - current_tax_benefit_system = reform_path(current_tax_benefit_system) + current_tax_benefit_system = reform_path( + current_tax_benefit_system + ) current_tax_benefit_system._parameters_at_instant_cache = {} for extension in extensions: @@ -479,9 +502,9 @@ def assert_near( value = np.array(value).astype(np.float32) except ValueError: # Data type not translatable to floating point, assert complete equality - assert np.array(value) == np.array(target_value), "{}{} differs from {}".format( - message, value, target_value - ) + assert np.array(value) == np.array( + target_value + ), "{}{} differs from {}".format(message, value, target_value) return diff = abs(target_value - value) diff --git a/policyengine_core/tracers/full_tracer.py b/policyengine_core/tracers/full_tracer.py index 3cab802ec..052d5c2b5 100644 --- a/policyengine_core/tracers/full_tracer.py +++ b/policyengine_core/tracers/full_tracer.py @@ -30,7 +30,9 @@ def record_calculation_start( period: str, branch_name: str = "default", ) -> None: - self._simple_tracer.record_calculation_start(variable, period, branch_name) + self._simple_tracer.record_calculation_start( + variable, period, branch_name + ) self._enter_calculation(variable, period, branch_name) self._record_start_time() @@ -141,7 +143,9 @@ def generate_performance_graph(self, dir_path: str) -> None: def generate_performance_tables(self, dir_path: str) -> None: self.performance_log.generate_performance_tables(dir_path) - def generate_variable_graph(self, name: str, output_vars: list[str]) -> None: + def generate_variable_graph( + self, name: str, output_vars: list[str] + ) -> None: self.variable_graph.visualize( name, aggregate=False, max_depth=None, output_vars=output_vars ) @@ -155,7 +159,9 @@ def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int: return tree_call + children_calls def get_nb_requests(self, variable: str) -> int: - return sum(self._get_nb_requests(tree, variable) for tree in self.trees) + return sum( + self._get_nb_requests(tree, variable) for tree in self.trees + ) def get_flat_trace(self) -> dict: return self.flat_trace.get_trace() diff --git a/policyengine_core/tracers/variable_graph.py b/policyengine_core/tracers/variable_graph.py index 07dd99620..313b48cca 100644 --- a/policyengine_core/tracers/variable_graph.py +++ b/policyengine_core/tracers/variable_graph.py @@ -82,16 +82,16 @@ def visualize( If ``max_depth`` is set, for example to ``3``, only visualize computed vectors up to a depth of ``max_depth``. """ - + try: os.mkdir(dir) except FileExistsError: pass os.chdir(dir) - i = 0 - try: + i = 0 + for root_node in self.tree(aggregate, max_depth, output_vars): net = self._network() self._add_nodes_and_edges(net, root_node) @@ -107,11 +107,14 @@ def visualize( sys.stdout = old_stdout finally: - os.chdir('..') + os.chdir("..") def _network(self) -> Network: net = Network( - height="100vh", directed=True, select_menu=True, neighborhood_highlight=True + height="100vh", + directed=True, + select_menu=True, + neighborhood_highlight=True, ) Network.set_options(net, self.NETWORK_OPTIONS) From 042fb982ab4a7398dae1b6afbe261d45632046b9 Mon Sep 17 00:00:00 2001 From: Caleb Date: Tue, 28 May 2024 16:30:49 -0600 Subject: [PATCH 7/9] remove incorrect documentation --- policyengine_core/tracers/variable_graph.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/policyengine_core/tracers/variable_graph.py b/policyengine_core/tracers/variable_graph.py index 313b48cca..bccf228ec 100644 --- a/policyengine_core/tracers/variable_graph.py +++ b/policyengine_core/tracers/variable_graph.py @@ -68,19 +68,6 @@ def visualize( ) -> None: """ Visualize the computation log of a simulation as a relationship graph in the web browser. - - If ``aggregate`` is ``False`` (default), visualize the value of each - computed vector. - - If ``aggregate`` is ``True``, only the minimum, maximum, and - average value will be used of each computed vector. - - This mode is more suited for simulations on a large population. - - If ``max_depth`` is ``None`` (default), visualize the entire computation. - - If ``max_depth`` is set, for example to ``3``, only visualize computed - vectors up to a depth of ``max_depth``. """ try: From 63cb3e1b9fe87e1501172a259fc036a3b1c0fec6 Mon Sep 17 00:00:00 2001 From: Caleb Date: Wed, 29 May 2024 09:57:51 -0600 Subject: [PATCH 8/9] combine different periods in the same graph --- policyengine_core/tracers/variable_graph.py | 39 ++++++++++++++------- 1 file changed, 27 insertions(+), 12 deletions(-) diff --git a/policyengine_core/tracers/variable_graph.py b/policyengine_core/tracers/variable_graph.py index bccf228ec..677c65f86 100644 --- a/policyengine_core/tracers/variable_graph.py +++ b/policyengine_core/tracers/variable_graph.py @@ -6,6 +6,7 @@ from typing import Optional, Union import numpy +from pyvis.node import Node from policyengine_core.enums import EnumArray @@ -77,22 +78,20 @@ def visualize( os.chdir(dir) try: - i = 0 + net = self._network() for root_node in self.tree(aggregate, max_depth, output_vars): - net = self._network() self._add_nodes_and_edges(net, root_node) - file_name = f"{self._to_snake_case(name)}.{i}.html" - i += 1 + file_name = f"{self._to_snake_case(name)}.html" - # redirect stdout to prevent net.show from printing the file name - old_stdout = sys.stdout # backup current stdout - sys.stdout = open(os.devnull, "w") + # redirect stdout to prevent net.show from printing the file name + old_stdout = sys.stdout # backup current stdout + sys.stdout = open(os.devnull, "w") - net.show(file_name, notebook=False) + net.show(file_name, notebook=False) - sys.stdout = old_stdout + sys.stdout = old_stdout finally: os.chdir("..") @@ -113,17 +112,33 @@ def _add_nodes_and_edges(self, net: Network, root_node: VisualizeNode): while len(stack) > 0: node = stack.pop() + id = node.name - net.add_node(node.name, color=node.color(), title=node.value) + net_node = self._get_network_node(net, id) + if net_node is not None: + if node.value not in net_node["title"]: + net_node["title"] += "\n" + node.value + + continue + + net.add_node( + id, color=node.color(), title=node.value, label=node.name + ) for child in node.children: - edge = (node.name, child.name) + edge = (id, child.name) edges.add(edge) stack.append(child) for parent, child in edges: net.add_edge(child, parent) + def _get_network_node(self, net: Network, id: str): + try: + return net.get_node(id) + except KeyError: + return None + def _get_node( self, node: tracers.TraceNode, @@ -206,7 +221,7 @@ def _value(self, aggregate: bool) -> str: else: formatted_value = self._display(value) - return f"{self.node.name}<{self.node.period}, ({self.node.branch_name})> = {formatted_value}" + return f"{self.node.period}: {formatted_value}" def color(self) -> str: if self.is_root: From 3a7798ddeff29c64352d84884ce2925bb30eafa4 Mon Sep 17 00:00:00 2001 From: Caleb Date: Wed, 29 May 2024 13:47:08 -0600 Subject: [PATCH 9/9] add changelog entry --- changelog_entry.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29bb..02b79f1a4 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,3 @@ +- bump: patch + changes: + added: Visualization option when running tests