Skip to content

Commit

Permalink
all simulator tests work
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Oct 18, 2024
1 parent 4dbe498 commit 7f2a0ba
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 7 deletions.
1 change: 1 addition & 0 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,7 @@ def __eq__(self, other):
and sorted(self.rewards) == sorted(other.rewards)
and self.exit_rates == other.exit_rates
and self.markovian_states == other.markovian_states
# TODO: and self.actions == other.actions
)
return False

Expand Down
10 changes: 6 additions & 4 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,11 @@ def get_range_index(stateid: int):
if simulator.is_done():
break
else:
state = 0
if model.get_type() == stormvogel.model.ModelType.POMDP:
simulator.set_full_observability(True)
path = {}
state, reward, labels = simulator.restart()
simulator.restart()
for i in range(steps):
# we first choose an action (randomly or according to scheduler)
actions = simulator.available_actions()
Expand All @@ -143,7 +144,7 @@ def get_range_index(stateid: int):
)

# we add the state action pair to the path
stormvogel_action = stormvogel.model.EmptyAction
stormvogel_action = model.states[state].available_actions()[select_action]
next_step = simulator.step(actions[select_action])
state, reward, labels = next_step
path[i + 1] = (stormvogel_action, model.states[state])
Expand Down Expand Up @@ -219,11 +220,12 @@ def get_range_index(stateid: int):
if simulator.is_done():
break
else:
state = 0
for i in range(runs):
state, reward, labels = simulator.restart()
# state, reward, labels = simulator.restart()
simulator.restart()
for j in range(steps):
# we first choose an action

actions = simulator.available_actions()
select_action = (
random.randint(0, len(actions) - 1)
Expand Down
53 changes: 50 additions & 3 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,43 @@ def test_simulate():
assert partial_model == other_dtmc

# we make a monty hall mdp and run the simulator with it
mdp = examples.monty_hall.create_monty_hall_mdp()
rewardmodel = mdp.add_rewards("rewardmodel")
for i in range(67):
rewardmodel.rewards[i] = i
rewardmodel2 = mdp.add_rewards("rewardmodel2")
for i in range(67):
rewardmodel2.rewards[i] = i

taken_actions = {}
for id, state in mdp.states.items():
taken_actions[id] = state.available_actions()[0]
scheduler = stormvogel.result.Scheduler(mdp, taken_actions)

partial_model = stormvogel.simulator.simulate(
mdp, runs=1, steps=3, seed=1, scheduler=scheduler
)

# we make the partial model that should be created by the simulator
other_mdp = stormvogel.model.new_mdp()
other_mdp.new_state(labels=["carchosen"])
other_mdp.new_state(labels=["open"])
other_mdp.new_state(labels=["goatrevealed"])

rewardmodel = other_mdp.add_rewards("rewardmodel")
rewardmodel.rewards = {0: 0, 7: 7, 16: 16}
rewardmodel2 = other_mdp.add_rewards("rewardmodel2")
rewardmodel2.rewards = {0: 0, 7: 7, 16: 16}

assert partial_model == other_mdp


def test_simulate_path():
# we make the nuclear fusion ctmc and run simulate path with it
ctmc = examples.nuclear_fusion_ctmc.create_nuclear_fusion_ctmc()
path = stormvogel.simulator.simulate_path(ctmc, steps=5, seed=1)

# we make the path that the simulate path function should create
other_path = stormvogel.simulator.Path(
{
1: ctmc.get_state_by_id(1),
Expand All @@ -56,10 +84,29 @@ def test_simulate_path():
ctmc,
)

# print(path, other_path)

assert str(path) == str(other_path)

# we make the monty hall pomdp and run simulate path with it
pomdp = examples.monty_hall_pomdp.create_monty_hall_pomdp()
taken_actions = {}
for id, state in pomdp.states.items():
taken_actions[id] = state.available_actions()[
len(state.available_actions()) - 1
]
scheduler = stormvogel.result.Scheduler(pomdp, taken_actions)
path = stormvogel.simulator.simulate_path(
pomdp, steps=4, seed=1, scheduler=scheduler
)

# we make the path that the simulate path function should create
other_path = stormvogel.simulator.Path(
{
1: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(3)),
2: (pomdp.actions["open2"], pomdp.get_state_by_id(12)),
3: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(23)),
4: (pomdp.actions["switch"], pomdp.get_state_by_id(46)),
},
pomdp,
)

# the simulated path should match the expected path constructed above
assert str(path) == str(other_path)

0 comments on commit 7f2a0ba

Please sign in to comment.