Skip to content

Commit

Permalink
Merge pull request #104 from moves-rwth/54-refactor
Browse files Browse the repository at this point in the history
54 refactor
  • Loading branch information
YouGuessedMyName authored Sep 28, 2024
2 parents 9c92914 + a85aa6a commit f37171c
Show file tree
Hide file tree
Showing 9 changed files with 197 additions and 38 deletions.
11 changes: 4 additions & 7 deletions stormvogel/communication_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@
enable_server: bool = True
"""Disable if you don't want to use an internal communication server. Some features might break."""

internal_enable_server: bool = enable_server


localhost_address: str = "127.0.0.1"

min_port = 8889
Expand Down Expand Up @@ -105,8 +102,8 @@ def request(self, js: str):
Also waits for server to boot up if it is not finished yet.
Should be thread safe. (I hope).
WHEN SENDING JAVASCRIPT, DO NOT FORGET EXTRA QUOTES AROUND STRINGS."""
global internal_enable_server
if not internal_enable_server:
global server
if server is None:
raise TimeoutError("There is no server running.")

global awaiting, server_running
Expand Down Expand Up @@ -214,8 +211,8 @@ def initialize_server() -> CommunicationServer | None:
"""If server is None, then create a new server and store it in global variable server.
Use the port stored in global variable server_port.
"""
global server, server_port, internal_enable_server
if not internal_enable_server:
global server, server_port, enable_server
if not enable_server:
return None

output = widgets.Output()
Expand Down
5 changes: 3 additions & 2 deletions stormvogel/rdict.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def rget(d: dict, path: list) -> Any:
) # Throws KeyError if key not present.


def rset(d: dict, path: list, value: Any) -> None:
def rset(d: dict, path: list, value: Any) -> dict:
"""Recursively set dict value."""
if len(path) == 0:
return
return d

def __rset(d: dict, path: list, value: Any):
first = path.pop(0)
Expand All @@ -25,6 +25,7 @@ def __rset(d: dict, path: list, value: Any):
__rset(d[first], path, value)

__rset(d, copy.deepcopy(path), value)
return d


def merge_dict(dict1: dict, dict2: dict) -> dict:
Expand Down
9 changes: 5 additions & 4 deletions stormvogel/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,17 @@ class Result:
scheduler: Scheduler | None

def __init__(
self, model: stormvogel.model.Model, values: list[stormvogel.model.Number]
self,
model: stormvogel.model.Model,
values: list[stormvogel.model.Number],
scheduler: Scheduler | None = None,
):
self.model = model

self.scheduler = scheduler
self.values = {}
for index, val in enumerate(values):
self.values[index] = val

self.scheduler = None

def add_scheduler(self, stormpy_scheduler: stormpy.storage.Scheduler):
"""adds a scheduler to the result"""
if self.scheduler is None:
Expand Down
10 changes: 6 additions & 4 deletions stormvogel/show.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import stormvogel.visualization
import stormvogel.layout_editor
import stormvogel.communication_server
import stormvogel.result

import ipywidgets as widgets
import IPython.display as ipd
Expand All @@ -23,6 +24,7 @@ def show(
Args:
model (Model): The stormvogel model to be displayed.
result (Result): A result associatied with the model.
name (str, optional): Internally used name. Will be randomly generated if left as None.
result (Result, optional): Result corresponding to the model.
layout (Layout, optional): Layout used for the visualization.
Expand All @@ -32,10 +34,9 @@ def show(
"""
if layout is None:
layout = stormvogel.layout.DEFAULT()
if not show_editor or not stormvogel.communication_server.enable_server:
stormvogel.communication_server.internal_enable_server = False
else:
stormvogel.communication_server.internal_enable_server = True
do_init_server = False
if show_editor or stormvogel.communication_server.enable_server:
do_init_server = True

do_display = not show_editor
vis = stormvogel.visualization.Visualization(
Expand All @@ -46,6 +47,7 @@ def show(
separate_labels=separate_labels,
do_display=do_display,
debug_output=debug_output,
do_init_server=do_init_server,
)
vis.show()

Expand Down
8 changes: 5 additions & 3 deletions stormvogel/visjs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def __init__(
output: widgets.Output | None = None,
do_display: bool = True,
debug_output: widgets.Output = widgets.Output(),
do_init_server: bool = True,
) -> None:
"""Display a visjs network using IPython. The network can display by itself or you can specify an Output widget in which it should be displayed.
Expand All @@ -46,9 +47,10 @@ def __init__(
self.edges_js: str = ""
self.options_js: str = "{}"
self.new_nodes_hidden: bool = False
self.server: stormvogel.communication_server.CommunicationServer = (
stormvogel.communication_server.initialize_server()
)
if do_init_server:
self.server: stormvogel.communication_server.CommunicationServer = (
stormvogel.communication_server.initialize_server()
)
# Note that this refers to the same server as the global variable in stormvogel.communication_server.

def enable_exploration_mode(self, initial_node_id: int):
Expand Down
18 changes: 15 additions & 3 deletions stormvogel/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
output: widgets.Output | None = None,
do_display: bool = True,
debug_output: widgets.Output = widgets.Output(),
do_init_server: bool = True,
) -> None:
"""Create visualization of a Model using a pyvis Network
Expand All @@ -67,6 +68,7 @@ def __init__(
self.layout: stormvogel.layout.Layout = layout
self.separate_labels = list(map(und, separate_labels))
self.nt: stormvogel.visjs.Network | None = None
self.do_init_server = do_init_server

def show(self) -> None:
"""(Re-)load the Network and display if self.do_display is True."""
Expand All @@ -81,6 +83,7 @@ def show(self) -> None:
output=self.output,
debug_output=self.debug_output,
do_display=False,
do_init_server=self.do_init_server,
)
if self.layout.layout["misc"]["explore"]:
self.nt.enable_exploration_mode(self.model.get_initial_state().id)
Expand All @@ -89,7 +92,8 @@ def show(self) -> None:
self.__add_transitions()
self.__update_physics_enabled()
self.nt.set_options(str(self.layout))
self.nt.show()
if self.nt is not None:
self.nt.show()
self.maybe_display_output()

def update(self) -> None:
Expand Down Expand Up @@ -131,7 +135,7 @@ def __add_transitions(self) -> None:
if self.nt is None:
return
action_id = self.ACTION_ID_OFFSET
# scheduler = self.result.scheduler if self.result is not None else None
scheduler = self.result.scheduler if self.result is not None else None
# In the visualization, both actions and states are nodes, so we need to keep track of how many actions we already have.
for state_id, transition in self.model.transitions.items():
for action, branch in transition.transition.items():
Expand All @@ -144,11 +148,19 @@ def __add_transitions(self) -> None:
label=self.__format_probability(prob),
)
else:
# Put the action in the group scheduled_actions if appropriate.
group = "actions"
if scheduler is not None:
choice = scheduler.get_choice_of_state(
state=self.model.get_state_by_id(state_id)
)
if choice == action:
group = "scheduled_actions"
# Add the action's node
self.nt.add_node(
id=action_id,
label=action.name,
group="actions",
group=group,
position_dict=self.layout.layout["positions"],
)
# Add transition from this state TO the action.
Expand Down
15 changes: 0 additions & 15 deletions tests/test_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,6 @@
from stormvogel.rdict import merge_dict


def test_layout_merge_dict():
# Test priority for second dict.
d1 = {"a": 1, "b": 1, "c": 1}
d2 = {"b": 2}
assert {"a": 1, "b": 2, "c": 1} == merge_dict(d1, d2)
# Test conservation of elements in both dicts
d1 = {"a": 1, "b": 1}
d2 = {"c": 2, "d": 2}
assert {"a": 1, "b": 1, "c": 2, "d": 2} == merge_dict(d1, d2)
# Test nested
d1 = {"a": {"b": {"c": 1, "d": 1}}, "e": 1}
d2 = {"a": {"b": {"c": 2}}}
assert {"a": {"b": {"c": 2, "d": 1}}, "e": 1} == merge_dict(d1, d2)


def test_layout_loading():
"""Tests if str(Layout) returns the correctly loaded json string."""
with open(os.path.join(os.getcwd(), "stormvogel/layouts/default.json")) as f:
Expand Down
43 changes: 43 additions & 0 deletions tests/test_rdict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from stormvogel.rdict import rget, rset, merge_dict


def test_rget():
# Empty
d = {}
assert rget(d, []) == {}
# Simple
d = {"a": 1}
assert rget(d, ["a"]) == 1
# Nested
d = {"a": {"b": {"c": 3, "b": 5}, "c": 2}, "c": 1}
assert rget(d, ["a", "b", "c"]) == 3


def test_rset():
# Empty path
d = {}
assert rset(d, [], 1) == {}
# Simple
d = {}
assert rset(d, ["a"], 1) == {"a": 1}
# Existing value
d = {"a": 0}
assert rset(d, ["a"], 1) == {"a": 1}
# Nested
d = {"a": {"b": 8}}
assert rset(d, ["a", "b"], {"c": 3}) == {"a": {"b": {"c": 3}}}


def test_merge_dict():
# Test priority for second dict.
d1 = {"a": 1, "b": 1, "c": 1}
d2 = {"b": 2}
assert {"a": 1, "b": 2, "c": 1} == merge_dict(d1, d2)
# Test conservation of elements in both dicts
d1 = {"a": 1, "b": 1}
d2 = {"c": 2, "d": 2}
assert {"a": 1, "b": 1, "c": 2, "d": 2} == merge_dict(d1, d2)
# Test nested
d1 = {"a": {"b": {"c": 1, "d": 1}}, "e": 1}
d2 = {"a": {"b": {"c": 2}}}
assert {"a": {"b": {"c": 2, "d": 1}}, "e": 1} == merge_dict(d1, d2)
116 changes: 116 additions & 0 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
from stormvogel.visualization import Visualization
from stormvogel.model import Model, ModelType
from stormvogel.result import Result, Scheduler


def boilerplate(mocker):
class MockNetwork:
def __init__(self, *args, **kwargs):
self.init(*args, **kwargs)

init = mocker.stub(name="init_stub")
add_node = mocker.stub(name="add_node_stub")
add_edge = mocker.stub(name="add_edge_stub")
set_options = mocker.stub(name="set_options_stub")
show = mocker.stub(name="show_stub")

mocker.patch("stormvogel.visjs.Network", MockNetwork)
return MockNetwork


def simple_model():
model = Model("simple", ModelType.DTMC)
one = model.new_state("one")
init = model.get_initial_state()
model.set_transitions(init, [(1, one)])
return model, one, init


def test_show(mocker):
MockNetwork = boilerplate(mocker)
model, one, init = simple_model()
vis = Visualization(model)
vis.show()
MockNetwork.init.assert_called_once_with(
name=vis.name,
width=vis.layout.layout["misc"]["width"],
height=vis.layout.layout["misc"]["height"],
output=vis.output,
debug_output=vis.debug_output,
do_display=False,
do_init_server=vis.do_init_server,
)
MockNetwork.add_node.assert_any_call(
0, label="init", group="states", position_dict={}
) # type: ignore
MockNetwork.add_node.assert_any_call(
1, label="one", group="states", position_dict={}
) # type: ignore
assert MockNetwork.add_node.call_count == 2
MockNetwork.add_edge.assert_any_call(0, 1, label="1")
assert MockNetwork.add_edge.call_count == 1


def test_rewards(mocker):
MockNetwork = boilerplate(mocker)
model, one, init = simple_model()
model.set_transitions(init, [(1, one)])
model.add_rewards("LOL")
model.get_rewards("LOL").set(one, 37)
model.add_rewards("HIHI")
model.get_rewards("HIHI").set(one, 42)
vis = Visualization(model=model)
vis.show()
MockNetwork.add_node.assert_any_call(
0, label="init", group="states", position_dict={}
) # type: ignore
MockNetwork.add_node.assert_any_call(
1, label="one\nLOL: 37\nHIHI: 42", group="states", position_dict={}
) # type: ignore
assert MockNetwork.add_node.call_count == 2
MockNetwork.add_edge.assert_any_call(0, 1, label="1")
assert MockNetwork.add_edge.call_count == 1


def test_results_count(mocker):
MockNetwork = boilerplate(mocker)
model, one, init = simple_model()
result = Result(model, [69, 12])

vis = Visualization(model=model, result=result)
vis.show()
RES_SYM = vis.layout.layout["results_and_rewards"]["resultSymbol"]
MockNetwork.add_node.assert_any_call(
0, label=f"init\n{RES_SYM} 69", group="states", position_dict={}
) # type: ignore
MockNetwork.add_node.assert_any_call(
1, label=f"one\n{RES_SYM} 12", group="states", position_dict={}
) # type: ignore

assert result.values == {0: 69, 1: 12}
assert MockNetwork.add_node.call_count == 2
MockNetwork.add_edge.assert_any_call(0, 1, label="1")
assert MockNetwork.add_edge.call_count == 1


def test_results_scheduler(mocker):
MockNetwork = boilerplate(mocker)
model = Model("mdp", model_type=ModelType.MDP)
init = model.get_initial_state()
good = model.new_action("good", frozenset(["GOOD"]))
bad = model.new_action("bad", frozenset(["BAD"]))
end = model.new_state("end")
model.set_transitions(init, [(good, end), (bad, end)])
scheduler = Scheduler(model, {0: good})
result = Result(model, [1, 2], scheduler)
vis = Visualization(model=model, result=result)
vis.show()
MockNetwork.add_node.assert_any_call(
id=10000000001, label="bad", group="actions", position_dict={}
)
MockNetwork.add_node.assert_any_call(
id=10000000000, label="good", group="scheduled_actions", position_dict={}
)

assert MockNetwork.add_node.call_count == 4
assert MockNetwork.add_edge.call_count == 4

0 comments on commit f37171c

Please sign in to comment.