From f9e9f387654d16da53e67dcacd01e22919364a5d Mon Sep 17 00:00:00 2001 From: YouGuessedMyName Date: Thu, 26 Sep 2024 21:28:16 +0200 Subject: [PATCH 1/3] Added visualization test --- stormvogel/rdict.py | 5 +++-- stormvogel/show.py | 2 ++ stormvogel/visualization.py | 35 +++++++++++++++++++++--------- tests/test_layout.py | 15 ------------- tests/test_rdict.py | 43 +++++++++++++++++++++++++++++++++++++ tests/test_visualization.py | 22 +++++++++++++++++++ 6 files changed, 95 insertions(+), 27 deletions(-) create mode 100644 tests/test_rdict.py create mode 100644 tests/test_visualization.py diff --git a/stormvogel/rdict.py b/stormvogel/rdict.py index 86bed53..3bc2a4a 100644 --- a/stormvogel/rdict.py +++ b/stormvogel/rdict.py @@ -12,10 +12,10 @@ def rget(d: dict, path: list) -> Any: ) # Throws KeyError if key not present. -def rset(d: dict, path: list, value: Any) -> None: +def rset(d: dict, path: list, value: Any) -> dict: """Recursively set dict value.""" if len(path) == 0: - return + return d def __rset(d: dict, path: list, value: Any): first = path.pop(0) @@ -25,6 +25,7 @@ def __rset(d: dict, path: list, value: Any): __rset(d[first], path, value) __rset(d, copy.deepcopy(path), value) + return d def merge_dict(dict1: dict, dict2: dict) -> dict: diff --git a/stormvogel/show.py b/stormvogel/show.py index da6cb47..77c9adf 100644 --- a/stormvogel/show.py +++ b/stormvogel/show.py @@ -5,6 +5,7 @@ import stormvogel.visualization import stormvogel.layout_editor import stormvogel.communication_server +import stormvogel.result import ipywidgets as widgets import IPython.display as ipd @@ -23,6 +24,7 @@ def show( Args: model (Model): The stormvogel model to be displayed. + result (Result): A result associatied with the model. name (str, optional): Internally used name. Will be randomly generated if left as None. result (Result, optional): Result corresponding to the model. layout (Layout, optional): Layout used for the visualization. diff --git a/stormvogel/visualization.py b/stormvogel/visualization.py index 1d52d9a..24b5858 100644 --- a/stormvogel/visualization.py +++ b/stormvogel/visualization.py @@ -68,6 +68,18 @@ def __init__( self.separate_labels = list(map(und, separate_labels)) self.nt: stormvogel.visjs.Network | None = None + def prepare(self) -> None: + """Prepare to show the network. Don't call this method yourself, use show instead.""" + if self.nt is None: + return + if self.layout.layout["misc"]["explore"]: + self.nt.enable_exploration_mode(self.model.get_initial_state().id) + self.layout.set_groups(self.separate_labels) + self.__add_states() + self.__add_transitions() + self.__update_physics_enabled() + self.nt.set_options(str(self.layout)) + def show(self) -> None: """(Re-)load the Network and display if self.do_display is True.""" with self.debug_output: @@ -82,14 +94,9 @@ def show(self) -> None: debug_output=self.debug_output, do_display=False, ) - if self.layout.layout["misc"]["explore"]: - self.nt.enable_exploration_mode(self.model.get_initial_state().id) - self.layout.set_groups(self.separate_labels) - self.__add_states() - self.__add_transitions() - self.__update_physics_enabled() - self.nt.set_options(str(self.layout)) - self.nt.show() + self.prepare() + if self.nt is not None: + self.nt.show() self.maybe_display_output() def update(self) -> None: @@ -131,7 +138,7 @@ def __add_transitions(self) -> None: if self.nt is None: return action_id = self.ACTION_ID_OFFSET - # scheduler = self.result.scheduler if self.result is not None else None + scheduler = self.result.scheduler if self.result is not None else None # In the visualization, both actions and states are nodes, so we need to keep track of how many actions we already have. for state_id, transition in self.model.transitions.items(): for action, branch in transition.transition.items(): @@ -144,11 +151,19 @@ def __add_transitions(self) -> None: label=self.__format_probability(prob), ) else: + # Put the action in the group scheduled_actions if appropriate. + group = "actions" + if scheduler is not None: + choice = scheduler.get_choice_of_state( + state=self.model.get_state_by_id(state_id) + ) + if choice == action: + group = "scheduled_actions" # Add the action's node self.nt.add_node( id=action_id, label=action.name, - group="actions", + group=group, position_dict=self.layout.layout["positions"], ) # Add transition from this state TO the action. diff --git a/tests/test_layout.py b/tests/test_layout.py index b806e72..5f7f14c 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -5,21 +5,6 @@ from stormvogel.rdict import merge_dict -def test_layout_merge_dict(): - # Test priority for second dict. - d1 = {"a": 1, "b": 1, "c": 1} - d2 = {"b": 2} - assert {"a": 1, "b": 2, "c": 1} == merge_dict(d1, d2) - # Test conservation of elements in both dicts - d1 = {"a": 1, "b": 1} - d2 = {"c": 2, "d": 2} - assert {"a": 1, "b": 1, "c": 2, "d": 2} == merge_dict(d1, d2) - # Test nested - d1 = {"a": {"b": {"c": 1, "d": 1}}, "e": 1} - d2 = {"a": {"b": {"c": 2}}} - assert {"a": {"b": {"c": 2, "d": 1}}, "e": 1} == merge_dict(d1, d2) - - def test_layout_loading(): """Tests if str(Layout) returns the correctly loaded json string.""" with open(os.path.join(os.getcwd(), "stormvogel/layouts/default.json")) as f: diff --git a/tests/test_rdict.py b/tests/test_rdict.py new file mode 100644 index 0000000..3c133bd --- /dev/null +++ b/tests/test_rdict.py @@ -0,0 +1,43 @@ +from stormvogel.rdict import rget, rset, merge_dict + + +def test_rget(): + # Empty + d = {} + assert rget(d, []) == {} + # Simple + d = {"a": 1} + assert rget(d, ["a"]) == 1 + # Nested + d = {"a": {"b": {"c": 3, "b": 5}, "c": 2}, "c": 1} + assert rget(d, ["a", "b", "c"]) == 3 + + +def test_rset(): + # Empty path + d = {} + assert rset(d, [], 1) == {} + # Simple + d = {} + assert rset(d, ["a"], 1) == {"a": 1} + # Existing value + d = {"a": 0} + assert rset(d, ["a"], 1) == {"a": 1} + # Nested + d = {"a": {"b": 8}} + assert rset(d, ["a", "b"], {"c": 3}) == {"a": {"b": {"c": 3}}} + + +def test_merge_dict(): + # Test priority for second dict. + d1 = {"a": 1, "b": 1, "c": 1} + d2 = {"b": 2} + assert {"a": 1, "b": 2, "c": 1} == merge_dict(d1, d2) + # Test conservation of elements in both dicts + d1 = {"a": 1, "b": 1} + d2 = {"c": 2, "d": 2} + assert {"a": 1, "b": 1, "c": 2, "d": 2} == merge_dict(d1, d2) + # Test nested + d1 = {"a": {"b": {"c": 1, "d": 1}}, "e": 1} + d2 = {"a": {"b": {"c": 2}}} + assert {"a": {"b": {"c": 2, "d": 1}}, "e": 1} == merge_dict(d1, d2) diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100644 index 0000000..fb35187 --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,22 @@ +from stormvogel.visualization import Visualization +from stormvogel.model import Model, ModelType + + +def test_visualization(mocker): + class MockNetwork: + add_node = mocker.stub(name="add_node_stub") + add_edge = mocker.stub(name="add_edge_stub") + set_options = mocker.stub(name="set_options_stub") + + model = Model("simple", ModelType.MDP) + model.new_state("one") + vis = Visualization(model) + vis.nt = MockNetwork + vis.prepare() + MockNetwork.add_node.assert_any_call( + 0, label="init", group="states", position_dict={} + ) + MockNetwork.add_node.assert_any_call( + 1, label="one", group="states", position_dict={} + ) + assert MockNetwork.add_node.call_count == 2 From da366d3a179c72901ff2a36ebd173fff44aa88ea Mon Sep 17 00:00:00 2001 From: YouGuessedMyName Date: Fri, 27 Sep 2024 10:31:30 +0200 Subject: [PATCH 2/3] Progress with visualization test --- stormvogel/communication_server.py | 11 +++---- stormvogel/show.py | 8 ++--- stormvogel/visjs.py | 8 +++-- stormvogel/visualization.py | 23 ++++++------- tests/test_visualization.py | 52 ++++++++++++++++++++++++++---- 5 files changed, 69 insertions(+), 33 deletions(-) diff --git a/stormvogel/communication_server.py b/stormvogel/communication_server.py index 76b73c6..69185b7 100644 --- a/stormvogel/communication_server.py +++ b/stormvogel/communication_server.py @@ -17,9 +17,6 @@ enable_server: bool = True """Disable if you don't want to use an internal communication server. Some features might break.""" -internal_enable_server: bool = enable_server - - localhost_address: str = "127.0.0.1" min_port = 8889 @@ -105,8 +102,8 @@ def request(self, js: str): Also waits for server to boot up if it is not finished yet. Should be thread safe. (I hope). WHEN SENDING JAVASCRIPT, DO NOT FORGET EXTRA QUOTES AROUND STRINGS.""" - global internal_enable_server - if not internal_enable_server: + global server + if server is None: raise TimeoutError("There is no server running.") global awaiting, server_running @@ -214,8 +211,8 @@ def initialize_server() -> CommunicationServer | None: """If server is None, then create a new server and store it in global variable server. Use the port stored in global variable server_port. """ - global server, server_port, internal_enable_server - if not internal_enable_server: + global server, server_port, enable_server + if not enable_server: return None output = widgets.Output() diff --git a/stormvogel/show.py b/stormvogel/show.py index 77c9adf..5a7d0a1 100644 --- a/stormvogel/show.py +++ b/stormvogel/show.py @@ -34,10 +34,9 @@ def show( """ if layout is None: layout = stormvogel.layout.DEFAULT() - if not show_editor or not stormvogel.communication_server.enable_server: - stormvogel.communication_server.internal_enable_server = False - else: - stormvogel.communication_server.internal_enable_server = True + do_init_server = False + if show_editor or stormvogel.communication_server.enable_server: + do_init_server = True do_display = not show_editor vis = stormvogel.visualization.Visualization( @@ -48,6 +47,7 @@ def show( separate_labels=separate_labels, do_display=do_display, debug_output=debug_output, + do_init_server=do_init_server, ) vis.show() diff --git a/stormvogel/visjs.py b/stormvogel/visjs.py index b910238..e71d56e 100644 --- a/stormvogel/visjs.py +++ b/stormvogel/visjs.py @@ -25,6 +25,7 @@ def __init__( output: widgets.Output | None = None, do_display: bool = True, debug_output: widgets.Output = widgets.Output(), + do_init_server: bool = True, ) -> None: """Display a visjs network using IPython. The network can display by itself or you can specify an Output widget in which it should be displayed. @@ -46,9 +47,10 @@ def __init__( self.edges_js: str = "" self.options_js: str = "{}" self.new_nodes_hidden: bool = False - self.server: stormvogel.communication_server.CommunicationServer = ( - stormvogel.communication_server.initialize_server() - ) + if do_init_server: + self.server: stormvogel.communication_server.CommunicationServer = ( + stormvogel.communication_server.initialize_server() + ) # Note that this refers to the same server as the global variable in stormvogel.communication_server. def enable_exploration_mode(self, initial_node_id: int): diff --git a/stormvogel/visualization.py b/stormvogel/visualization.py index 24b5858..ed1a6b1 100644 --- a/stormvogel/visualization.py +++ b/stormvogel/visualization.py @@ -42,6 +42,7 @@ def __init__( output: widgets.Output | None = None, do_display: bool = True, debug_output: widgets.Output = widgets.Output(), + do_init_server: bool = True, ) -> None: """Create visualization of a Model using a pyvis Network @@ -67,18 +68,7 @@ def __init__( self.layout: stormvogel.layout.Layout = layout self.separate_labels = list(map(und, separate_labels)) self.nt: stormvogel.visjs.Network | None = None - - def prepare(self) -> None: - """Prepare to show the network. Don't call this method yourself, use show instead.""" - if self.nt is None: - return - if self.layout.layout["misc"]["explore"]: - self.nt.enable_exploration_mode(self.model.get_initial_state().id) - self.layout.set_groups(self.separate_labels) - self.__add_states() - self.__add_transitions() - self.__update_physics_enabled() - self.nt.set_options(str(self.layout)) + self.do_init_server = do_init_server def show(self) -> None: """(Re-)load the Network and display if self.do_display is True.""" @@ -93,8 +83,15 @@ def show(self) -> None: output=self.output, debug_output=self.debug_output, do_display=False, + do_init_server=self.do_init_server, ) - self.prepare() + if self.layout.layout["misc"]["explore"]: + self.nt.enable_exploration_mode(self.model.get_initial_state().id) + self.layout.set_groups(self.separate_labels) + self.__add_states() + self.__add_transitions() + self.__update_physics_enabled() + self.nt.set_options(str(self.layout)) if self.nt is not None: self.nt.show() self.maybe_display_output() diff --git a/tests/test_visualization.py b/tests/test_visualization.py index fb35187..e890361 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -2,21 +2,61 @@ from stormvogel.model import Model, ModelType -def test_visualization(mocker): +def test_show(mocker): class MockNetwork: + def __init__(self, *args, **kwargs): + self.init(*args, **kwargs) + + init = mocker.stub(name="init_stub") add_node = mocker.stub(name="add_node_stub") add_edge = mocker.stub(name="add_edge_stub") set_options = mocker.stub(name="set_options_stub") + show = mocker.stub(name="show_stub") + mocker.patch("stormvogel.visjs.Network", MockNetwork) model = Model("simple", ModelType.MDP) - model.new_state("one") + one = model.new_state("one") + init = model.get_initial_state() + model.set_transitions(init, [(1, one)]) vis = Visualization(model) - vis.nt = MockNetwork - vis.prepare() + vis.show() MockNetwork.add_node.assert_any_call( 0, label="init", group="states", position_dict={} - ) + ) # type: ignore MockNetwork.add_node.assert_any_call( 1, label="one", group="states", position_dict={} - ) + ) # type: ignore + assert MockNetwork.add_node.call_count == 2 + MockNetwork.add_edge.assert_any_call(0, 1, label="1") + assert MockNetwork.add_edge.call_count == 1 + + +def test_rewards(mocker): + class MockNetwork: + def __init__(self, *args, **kwargs): + self.init(*args, **kwargs) + + init = mocker.stub(name="init_stub") + add_node = mocker.stub(name="add_node_stub") + add_edge = mocker.stub(name="add_edge_stub") + set_options = mocker.stub(name="set_options_stub") + show = mocker.stub(name="show_stub") + + mocker.patch("stormvogel.visjs.Network", MockNetwork) + model = Model("simple", ModelType.MDP) + one = model.new_state("one") + init = model.get_initial_state() + model.set_transitions(init, [(1, one)]) + model.add_rewards("LOL") + model.get_rewards("LOL").set(one, 37) + vis = Visualization(model=model) + vis.show() + MockNetwork.add_node.assert_any_call( + 0, label="init", group="states", position_dict={} + ) # type: ignore + MockNetwork.add_node.assert_any_call( + 1, label="one\nLOL: 37", group="states", position_dict={} + ) # type: ignore assert MockNetwork.add_node.call_count == 2 + MockNetwork.add_edge.assert_any_call(0, 1, label="1") + assert MockNetwork.add_edge.call_count == 1 From a85aa6af98c2413c312cec508e922ae5b86ac911 Mon Sep 17 00:00:00 2001 From: YouGuessedMyName Date: Sat, 28 Sep 2024 17:07:17 +0200 Subject: [PATCH 3/3] Added more visualization tests --- stormvogel/result.py | 9 ++-- tests/test_visualization.py | 88 ++++++++++++++++++++++++++++++------- 2 files changed, 76 insertions(+), 21 deletions(-) diff --git a/stormvogel/result.py b/stormvogel/result.py index 04de6be..c99e728 100644 --- a/stormvogel/result.py +++ b/stormvogel/result.py @@ -58,16 +58,17 @@ class Result: scheduler: Scheduler | None def __init__( - self, model: stormvogel.model.Model, values: list[stormvogel.model.Number] + self, + model: stormvogel.model.Model, + values: list[stormvogel.model.Number], + scheduler: Scheduler | None = None, ): self.model = model - + self.scheduler = scheduler self.values = {} for index, val in enumerate(values): self.values[index] = val - self.scheduler = None - def add_scheduler(self, stormpy_scheduler: stormpy.storage.Scheduler): """adds a scheduler to the result""" if self.scheduler is None: diff --git a/tests/test_visualization.py b/tests/test_visualization.py index e890361..2a671fc 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -1,8 +1,9 @@ from stormvogel.visualization import Visualization from stormvogel.model import Model, ModelType +from stormvogel.result import Result, Scheduler -def test_show(mocker): +def boilerplate(mocker): class MockNetwork: def __init__(self, *args, **kwargs): self.init(*args, **kwargs) @@ -14,12 +15,31 @@ def __init__(self, *args, **kwargs): show = mocker.stub(name="show_stub") mocker.patch("stormvogel.visjs.Network", MockNetwork) - model = Model("simple", ModelType.MDP) + return MockNetwork + + +def simple_model(): + model = Model("simple", ModelType.DTMC) one = model.new_state("one") init = model.get_initial_state() model.set_transitions(init, [(1, one)]) + return model, one, init + + +def test_show(mocker): + MockNetwork = boilerplate(mocker) + model, one, init = simple_model() vis = Visualization(model) vis.show() + MockNetwork.init.assert_called_once_with( + name=vis.name, + width=vis.layout.layout["misc"]["width"], + height=vis.layout.layout["misc"]["height"], + output=vis.output, + debug_output=vis.debug_output, + do_display=False, + do_init_server=vis.do_init_server, + ) MockNetwork.add_node.assert_any_call( 0, label="init", group="states", position_dict={} ) # type: ignore @@ -32,31 +52,65 @@ def __init__(self, *args, **kwargs): def test_rewards(mocker): - class MockNetwork: - def __init__(self, *args, **kwargs): - self.init(*args, **kwargs) - - init = mocker.stub(name="init_stub") - add_node = mocker.stub(name="add_node_stub") - add_edge = mocker.stub(name="add_edge_stub") - set_options = mocker.stub(name="set_options_stub") - show = mocker.stub(name="show_stub") - - mocker.patch("stormvogel.visjs.Network", MockNetwork) - model = Model("simple", ModelType.MDP) - one = model.new_state("one") - init = model.get_initial_state() + MockNetwork = boilerplate(mocker) + model, one, init = simple_model() model.set_transitions(init, [(1, one)]) model.add_rewards("LOL") model.get_rewards("LOL").set(one, 37) + model.add_rewards("HIHI") + model.get_rewards("HIHI").set(one, 42) vis = Visualization(model=model) vis.show() MockNetwork.add_node.assert_any_call( 0, label="init", group="states", position_dict={} ) # type: ignore MockNetwork.add_node.assert_any_call( - 1, label="one\nLOL: 37", group="states", position_dict={} + 1, label="one\nLOL: 37\nHIHI: 42", group="states", position_dict={} ) # type: ignore assert MockNetwork.add_node.call_count == 2 MockNetwork.add_edge.assert_any_call(0, 1, label="1") assert MockNetwork.add_edge.call_count == 1 + + +def test_results_count(mocker): + MockNetwork = boilerplate(mocker) + model, one, init = simple_model() + result = Result(model, [69, 12]) + + vis = Visualization(model=model, result=result) + vis.show() + RES_SYM = vis.layout.layout["results_and_rewards"]["resultSymbol"] + MockNetwork.add_node.assert_any_call( + 0, label=f"init\n{RES_SYM} 69", group="states", position_dict={} + ) # type: ignore + MockNetwork.add_node.assert_any_call( + 1, label=f"one\n{RES_SYM} 12", group="states", position_dict={} + ) # type: ignore + + assert result.values == {0: 69, 1: 12} + assert MockNetwork.add_node.call_count == 2 + MockNetwork.add_edge.assert_any_call(0, 1, label="1") + assert MockNetwork.add_edge.call_count == 1 + + +def test_results_scheduler(mocker): + MockNetwork = boilerplate(mocker) + model = Model("mdp", model_type=ModelType.MDP) + init = model.get_initial_state() + good = model.new_action("good", frozenset(["GOOD"])) + bad = model.new_action("bad", frozenset(["BAD"])) + end = model.new_state("end") + model.set_transitions(init, [(good, end), (bad, end)]) + scheduler = Scheduler(model, {0: good}) + result = Result(model, [1, 2], scheduler) + vis = Visualization(model=model, result=result) + vis.show() + MockNetwork.add_node.assert_any_call( + id=10000000001, label="bad", group="actions", position_dict={} + ) + MockNetwork.add_node.assert_any_call( + id=10000000000, label="good", group="scheduled_actions", position_dict={} + ) + + assert MockNetwork.add_node.call_count == 4 + assert MockNetwork.add_edge.call_count == 4