Skip to content

Commit

Permalink
started working on the tests
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Oct 14, 2024
1 parent 3f84314 commit 9199f99
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
27 changes: 21 additions & 6 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def simulate_path(
model: stormvogel.model.Model,
steps: int = 1,
scheduler: stormvogel.result.Scheduler | None = None,
seed: int | None = None,
) -> Path:
"""
Simulates the model a given number of steps.
Expand All @@ -111,7 +112,10 @@ def get_range_index(stateid: int):

# we initialize the simulator
stormpy_model = stormvogel.mapping.stormvogel_to_stormpy(model)
simulator = stormpy.simulator.create_simulator(stormpy_model)
if seed:
simulator = stormpy.simulator.create_simulator(stormpy_model, seed)
else:
simulator = stormpy.simulator.create_simulator(stormpy_model)
assert simulator is not None

# we start adding states or state action pairs to the path
Expand Down Expand Up @@ -156,6 +160,7 @@ def simulate(
steps: int = 1,
runs: int = 1,
scheduler: stormvogel.result.Scheduler | None = None,
seed: int | None = None,
) -> stormvogel.model.Model | None:
"""
Simulates the model a given number of steps for a given number of runs.
Expand All @@ -174,7 +179,10 @@ def get_range_index(stateid: int):
# we initialize the simulator
stormpy_model = stormvogel.mapping.stormvogel_to_stormpy(model)
assert stormpy_model is not None
simulator = stormpy.simulator.create_simulator(stormpy_model)
if seed:
simulator = stormpy.simulator.create_simulator(stormpy_model, seed)
else:
simulator = stormpy.simulator.create_simulator(stormpy_model)
assert simulator is not None

# we keep track of all discovered states over all runs and add them to the partial model
Expand All @@ -185,20 +193,26 @@ def get_range_index(stateid: int):
assert len(model.rewards) in [0, 1]
if model.rewards:
reward_model = partial_model.add_rewards(model.rewards[0].name)
reward_model.set(
partial_model.get_initial_state(),
model.rewards[0].get(model.get_initial_state()),
)
else:
reward_model = None

discovered_states = {0}
if not partial_model.supports_actions():
for i in range(runs):
simulator.restart()
for j in range(steps):
state, reward, labels = simulator.step()

# we add to the partial model what we discovered (if new)
if state not in partial_model.states.keys():
if state not in discovered_states:
discovered_states.add(state)
partial_model.new_state(list(labels))
if reward_model:
reward_model.set(model.get_state_by_id(state), reward)
reward_model.set(model.get_state_by_id(state), reward[0])

if simulator.is_done():
break
Expand Down Expand Up @@ -233,9 +247,10 @@ def get_range_index(stateid: int):
state
)
state_action_pair = row_group + select_action
reward_model.set_action_state(state_action_pair, reward)
reward_model.set_action_state(state_action_pair, reward[0])
state, labels = discovery[0], discovery[2]
if state not in partial_model.states.keys():
if state not in discovered_states:
discovered_states.add(state)
partial_model.new_state(list(labels))

if simulator.is_done():
Expand Down
55 changes: 55 additions & 0 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import stormvogel.model
import examples.die
import examples.nuclear_fusion_ctmc
import stormvogel.simulator


def test_simulate():
# we make a die dtmc and run the simulator with it
dtmc = examples.die.create_die_dtmc()
# rewardmodel = dtmc.add_rewards("rewardmodel")
# for stateid in dtmc.states.keys():
# rewardmodel.rewards[stateid] = 5
partial_model = stormvogel.simulator.simulate(dtmc, runs=5, steps=1, seed=1)

# we make the partial model that should be created by the simulator
other_dtmc = stormvogel.model.new_dtmc()
other_dtmc.new_state(labels=["rolled5"])
other_dtmc.new_state(labels=["rolled0"])
other_dtmc.new_state(labels=["rolled1"])

# rewardmodel = other_dtmc.add_rewards("rewardmodel")
# for stateid in other_dtmc.states.keys():
# rewardmodel.rewards[stateid] = float(5)

# print(partial_model.rewards, other_dtmc.rewards)

assert partial_model == other_dtmc

# we make a monty hall mdp and run the simulator with it

# we make the partial model that should be created by the simulator


def test_simulate_path():
# we make the nuclear fusion ctmc and run simulate path with it
ctmc = examples.nuclear_fusion_ctmc.create_nuclear_fusion_ctmc()
path = stormvogel.simulator.simulate_path(ctmc, steps=5, seed=1)

other_path = stormvogel.simulator.Path(
{
1: ctmc.get_state_by_id(1),
2: ctmc.get_state_by_id(2),
3: ctmc.get_state_by_id(3),
4: ctmc.get_state_by_id(4),
},
ctmc,
)

# print(path, other_path)

assert str(path) == str(other_path)

# we make the path that the simulate path function should create

# we make the monty hall pomdp and run simulate path with it

0 comments on commit 9199f99

Please sign in to comment.