Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a test flag to visualize a test #212

Merged
merged 9 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- bump: patch
changes:
added: Visualization option when running tests
6 changes: 6 additions & 0 deletions policyengine_core/scripts/policyengine_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def build_test_parser(parser):
default=False,
help="increase output verbosity. If specified, output the entire calculation trace.",
)
parser.add_argument(
"--visualize",
action="store_true",
default=False,
help="output a relationship graph of the variables being tested",
)
parser.add_argument(
"-a",
"--aggregate",
Expand Down
1 change: 1 addition & 0 deletions policyengine_core/scripts/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def main(parser):
"name_filter": args.name_filter,
"only_variables": args.only_variables,
"ignore_variables": args.ignore_variables,
"visualize": args.visualize,
}

paths = [os.path.abspath(path) for path in args.path]
Expand Down
23 changes: 22 additions & 1 deletion policyengine_core/tools/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ def apply(self):
verbose = self.options.get("verbose")
performance_graph = self.options.get("performance_graph")
performance_tables = self.options.get("performance_tables")
visualize = self.options.get("visualize")

try:
builder.set_default_period(period)
Expand All @@ -256,7 +257,7 @@ def apply(self):

try:
self.simulation.trace = (
verbose or performance_graph or performance_tables
verbose or performance_graph or performance_tables or visualize
)
self.check_output()
finally:
Expand All @@ -267,6 +268,8 @@ def apply(self):
self.generate_performance_graph(tracer)
if performance_tables:
self.generate_performance_tables(tracer)
if visualize:
self.generate_variable_graph(tracer)

def print_computation_log(self, tracer):
print("Computation log:") # noqa T001
Expand All @@ -278,6 +281,24 @@ def generate_performance_graph(self, tracer):
def generate_performance_tables(self, tracer):
tracer.generate_performance_tables(".")

def generate_variable_graph(self, tracer):
tracer.generate_variable_graph(
self.test.get("name"), self._all_output_vars()
)

def _all_output_vars(self):
return self._get_leaf_keys(self.test["output"])

def _get_leaf_keys(self, dictionary: dict):
keys = []
for key, value in dictionary.items():
if type(value) is dict:
keys.extend(self._get_leaf_keys(value))
else:
keys.append(key)

return keys

def check_output(self):
output = self.test.get("output")

Expand Down
1 change: 1 addition & 0 deletions policyengine_core/tracers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .flat_trace import FlatTrace
from .full_tracer import FullTracer
from .performance_log import PerformanceLog
from .variable_graph import VariableGraph
from .simple_tracer import SimpleTracer
from .trace_node import TraceNode
from .tracing_parameter_node_at_instant import TracingParameterNodeAtInstant
11 changes: 11 additions & 0 deletions policyengine_core/tracers/full_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def computation_log(self) -> tracers.ComputationLog:
def performance_log(self) -> tracers.PerformanceLog:
return tracers.PerformanceLog(self)

@property
def variable_graph(self) -> tracers.VariableGraph:
return tracers.VariableGraph(self)

@property
def flat_trace(self) -> tracers.FlatTrace:
return tracers.FlatTrace(self)
Expand All @@ -139,6 +143,13 @@ def generate_performance_graph(self, dir_path: str) -> None:
def generate_performance_tables(self, dir_path: str) -> None:
self.performance_log.generate_performance_tables(dir_path)

def generate_variable_graph(
self, name: str, output_vars: list[str]
) -> None:
self.variable_graph.visualize(
name, aggregate=False, max_depth=None, output_vars=output_vars
)

def _get_nb_requests(self, tree: tracers.TraceNode, variable: str) -> int:
tree_call = tree.name == variable
children_calls = sum(
Expand Down
232 changes: 232 additions & 0 deletions policyengine_core/tracers/variable_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
from __future__ import annotations

import os
import sys
import typing
from typing import Optional, Union

import numpy
from pyvis.node import Node

from policyengine_core.enums import EnumArray

from .. import tracers

from pyvis.network import Network

if typing.TYPE_CHECKING:
from numpy.typing import ArrayLike

Array = Union[EnumArray, ArrayLike]


class VariableGraph:
_full_tracer: tracers.FullTracer

NETWORK_OPTIONS = """
const options = {
"physics": {
"repulsion": {
"theta": 1,
"centralGravity": 0,
"springLength": 255,
"springConstant": 0.06,
"damping": 1,
"avoidOverlap": 1
},
"minVelocity": 0.75,
"solver": "repulsion"
}
}
"""

def __init__(self, full_tracer: tracers.FullTracer) -> None:
self._full_tracer = full_tracer

def tree(
self,
aggregate: bool = False,
max_depth: Optional[int] = None,
output_vars: list[str] = [],
) -> VisualizeNode:
depth = 1
is_root = True

node_by_tree = [
self._get_node(node, depth, aggregate, max_depth, output_vars)
for node in self._full_tracer.trees
]

return node_by_tree

def visualize(
self,
name: str,
output_vars: list[str] = [],
aggregate=False,
max_depth: Optional[int] = None,
dir="_variable_graphs",
) -> None:
"""
Visualize the computation log of a simulation as a relationship graph in the web browser.
"""

try:
os.mkdir(dir)
except FileExistsError:
pass
os.chdir(dir)

try:
net = self._network()

for root_node in self.tree(aggregate, max_depth, output_vars):
self._add_nodes_and_edges(net, root_node)

file_name = f"{self._to_snake_case(name)}.html"

# redirect stdout to prevent net.show from printing the file name
old_stdout = sys.stdout # backup current stdout
sys.stdout = open(os.devnull, "w")

net.show(file_name, notebook=False)

sys.stdout = old_stdout
finally:
os.chdir("..")

def _network(self) -> Network:
net = Network(
height="100vh",
directed=True,
select_menu=True,
neighborhood_highlight=True,
)
Network.set_options(net, self.NETWORK_OPTIONS)

return net

def _add_nodes_and_edges(self, net: Network, root_node: VisualizeNode):
stack = [root_node]
edges: set[tuple[str, str]] = set()

while len(stack) > 0:
node = stack.pop()
id = node.name

net_node = self._get_network_node(net, id)
if net_node is not None:
if node.value not in net_node["title"]:
net_node["title"] += "\n" + node.value

continue

net.add_node(
id, color=node.color(), title=node.value, label=node.name
)

for child in node.children:
edge = (id, child.name)
edges.add(edge)
stack.append(child)

for parent, child in edges:
net.add_edge(child, parent)

def _get_network_node(self, net: Network, id: str):
try:
return net.get_node(id)
except KeyError:
return None

def _get_node(
self,
node: tracers.TraceNode,
depth: int,
aggregate: bool,
max_depth: Optional[int],
output_vars: list[str],
) -> VisualizeNode:
if max_depth is not None and depth > max_depth:
return []

children = [
self._get_node(child, depth + 1, aggregate, max_depth, output_vars)
for child in node.children
]

is_leaf = len(node.children) == 0
visualization_node = VisualizeNode(
node,
children,
is_leaf=is_leaf,
aggregate=aggregate,
output_vars=output_vars,
)

return visualization_node

def _to_snake_case(self, string: str):
return string.replace(" ", "_").lower()


class VisualizeNode:
DEFAULT_COLOR = "#BFD0DF"
LEAF_COLOR = "#0099FF"
ROOT_COLOR = "#7B61FF"

def __init__(
self,
node: tracers.TraceNode,
children: list[VisualizeNode],
is_leaf=False,
aggregate=False,
output_vars: list[str] = [],
):
self.node = node
self.name = node.name
self.children = children
self.is_leaf = is_leaf
self.is_root = self.name in output_vars
self.value = self._value(aggregate)

def _display(
self,
value: Optional[Array],
) -> str:
if isinstance(value, EnumArray):
value = value.decode_to_str()

return numpy.array2string(value, max_line_width=float("inf"))

def _value(self, aggregate: bool) -> str:
value = self.node.value

if value is None:
formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}"

elif aggregate:
try:
formatted_value = str(
{
"avg": numpy.mean(value),
"max": numpy.max(value),
"min": numpy.min(value),
}
)

except TypeError:
formatted_value = "{'avg': '?', 'max': '?', 'min': '?'}"

else:
formatted_value = self._display(value)

return f"{self.node.period}: {formatted_value}"

def color(self) -> str:
if self.is_root:
return self.ROOT_COLOR
if self.is_leaf:
return self.LEAF_COLOR

return self.DEFAULT_COLOR
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"pandas>=1",
"plotly>=5.6.0,<6",
"ipython>=7.17.0,<8",
"pyvis>=0.3.2",
]

dev_requirements = [
Expand Down
Loading