Skip to content

Commit

Permalink
Added more visualization tests
Browse files Browse the repository at this point in the history
  • Loading branch information
YouGuessedMyName committed Sep 28, 2024
1 parent da366d3 commit a85aa6a
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 21 deletions.
9 changes: 5 additions & 4 deletions stormvogel/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,17 @@ class Result:
scheduler: Scheduler | None

def __init__(
self, model: stormvogel.model.Model, values: list[stormvogel.model.Number]
self,
model: stormvogel.model.Model,
values: list[stormvogel.model.Number],
scheduler: Scheduler | None = None,
):
self.model = model

self.scheduler = scheduler
self.values = {}
for index, val in enumerate(values):
self.values[index] = val

self.scheduler = None

def add_scheduler(self, stormpy_scheduler: stormpy.storage.Scheduler):
"""adds a scheduler to the result"""
if self.scheduler is None:
Expand Down
88 changes: 71 additions & 17 deletions tests/test_visualization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from stormvogel.visualization import Visualization
from stormvogel.model import Model, ModelType
from stormvogel.result import Result, Scheduler


def test_show(mocker):
def boilerplate(mocker):
class MockNetwork:
def __init__(self, *args, **kwargs):
self.init(*args, **kwargs)
Expand All @@ -14,12 +15,31 @@ def __init__(self, *args, **kwargs):
show = mocker.stub(name="show_stub")

mocker.patch("stormvogel.visjs.Network", MockNetwork)
model = Model("simple", ModelType.MDP)
return MockNetwork


def simple_model():
model = Model("simple", ModelType.DTMC)
one = model.new_state("one")
init = model.get_initial_state()
model.set_transitions(init, [(1, one)])
return model, one, init


def test_show(mocker):
MockNetwork = boilerplate(mocker)
model, one, init = simple_model()
vis = Visualization(model)
vis.show()
MockNetwork.init.assert_called_once_with(
name=vis.name,
width=vis.layout.layout["misc"]["width"],
height=vis.layout.layout["misc"]["height"],
output=vis.output,
debug_output=vis.debug_output,
do_display=False,
do_init_server=vis.do_init_server,
)
MockNetwork.add_node.assert_any_call(
0, label="init", group="states", position_dict={}
) # type: ignore
Expand All @@ -32,31 +52,65 @@ def __init__(self, *args, **kwargs):


def test_rewards(mocker):
class MockNetwork:
def __init__(self, *args, **kwargs):
self.init(*args, **kwargs)

init = mocker.stub(name="init_stub")
add_node = mocker.stub(name="add_node_stub")
add_edge = mocker.stub(name="add_edge_stub")
set_options = mocker.stub(name="set_options_stub")
show = mocker.stub(name="show_stub")

mocker.patch("stormvogel.visjs.Network", MockNetwork)
model = Model("simple", ModelType.MDP)
one = model.new_state("one")
init = model.get_initial_state()
MockNetwork = boilerplate(mocker)
model, one, init = simple_model()
model.set_transitions(init, [(1, one)])
model.add_rewards("LOL")
model.get_rewards("LOL").set(one, 37)
model.add_rewards("HIHI")
model.get_rewards("HIHI").set(one, 42)
vis = Visualization(model=model)
vis.show()
MockNetwork.add_node.assert_any_call(
0, label="init", group="states", position_dict={}
) # type: ignore
MockNetwork.add_node.assert_any_call(
1, label="one\nLOL: 37", group="states", position_dict={}
1, label="one\nLOL: 37\nHIHI: 42", group="states", position_dict={}
) # type: ignore
assert MockNetwork.add_node.call_count == 2
MockNetwork.add_edge.assert_any_call(0, 1, label="1")
assert MockNetwork.add_edge.call_count == 1


def test_results_count(mocker):
MockNetwork = boilerplate(mocker)
model, one, init = simple_model()
result = Result(model, [69, 12])

vis = Visualization(model=model, result=result)
vis.show()
RES_SYM = vis.layout.layout["results_and_rewards"]["resultSymbol"]
MockNetwork.add_node.assert_any_call(
0, label=f"init\n{RES_SYM} 69", group="states", position_dict={}
) # type: ignore
MockNetwork.add_node.assert_any_call(
1, label=f"one\n{RES_SYM} 12", group="states", position_dict={}
) # type: ignore

assert result.values == {0: 69, 1: 12}
assert MockNetwork.add_node.call_count == 2
MockNetwork.add_edge.assert_any_call(0, 1, label="1")
assert MockNetwork.add_edge.call_count == 1


def test_results_scheduler(mocker):
MockNetwork = boilerplate(mocker)
model = Model("mdp", model_type=ModelType.MDP)
init = model.get_initial_state()
good = model.new_action("good", frozenset(["GOOD"]))
bad = model.new_action("bad", frozenset(["BAD"]))
end = model.new_state("end")
model.set_transitions(init, [(good, end), (bad, end)])
scheduler = Scheduler(model, {0: good})
result = Result(model, [1, 2], scheduler)
vis = Visualization(model=model, result=result)
vis.show()
MockNetwork.add_node.assert_any_call(
id=10000000001, label="bad", group="actions", position_dict={}
)
MockNetwork.add_node.assert_any_call(
id=10000000000, label="good", group="scheduled_actions", position_dict={}
)

assert MockNetwork.add_node.call_count == 4
assert MockNetwork.add_edge.call_count == 4

0 comments on commit a85aa6a

Please sign in to comment.