Skip to content

Commit

Permalink
comparing rewards now works properly
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Oct 18, 2024
1 parent f7bb747 commit 4dbe498
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 9 deletions.
4 changes: 2 additions & 2 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def transition_from_shorthand(shorthand: TransitionShorthand) -> Transition:
)


@dataclass
@dataclass(order=True)
class RewardModel:
"""Represents a state-exit reward model.
dtmc.delete_state(dtmc.get_state_by_id(1), True, True)
Expand Down Expand Up @@ -762,7 +762,7 @@ def __eq__(self, other):
self.type == other.type
and self.states == other.states
and self.transitions == other.transitions
and self.rewards == other.rewards
and sorted(self.rewards) == sorted(other.rewards)
and self.exit_rates == other.exit_rates
and self.markovian_states == other.markovian_states
)
Expand Down
7 changes: 4 additions & 3 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,13 +207,14 @@ def get_range_index(stateid: int):
simulator.restart()
for j in range(steps):
state, reward, labels = simulator.step()
reward.reverse()

# we add to the partial model what we discovered (if new)
if state not in discovered_states:
discovered_states.add(state)
partial_model.new_state(list(labels))
for index, rewardmodel in enumerate(partial_model.rewards):
rewardmodel.set(model.get_state_by_id(state), reward[index])
new_state = partial_model.new_state(list(labels))
for index, rewardmodel in enumerate(partial_model.rewards):
rewardmodel.set(new_state, reward[index])

if simulator.is_done():
break
Expand Down
18 changes: 18 additions & 0 deletions tests/test_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,15 @@ def test_stormpy_to_stormvogel_and_back_dtmc():
def test_stormvogel_to_stormpy_and_back_dtmc():
# we test it for the die dtmc
stormvogel_dtmc = examples.die.create_die_dtmc()

# we test if rewardmodels work:
rewardmodel = stormvogel_dtmc.add_rewards("rewardmodel")
for stateid in stormvogel_dtmc.states.keys():
rewardmodel.rewards[stateid] = 1
rewardmodel2 = stormvogel_dtmc.add_rewards("rewardmodel2")
for stateid in stormvogel_dtmc.states.keys():
rewardmodel2.rewards[stateid] = 2

# print(stormvogel_dtmc)
stormpy_dtmc = stormvogel.mapping.stormvogel_to_stormpy(stormvogel_dtmc)
# print(stormpy_dtmc)
Expand All @@ -136,6 +145,15 @@ def test_stormpy_to_stormvogel_and_back_mdp():
def test_stormvogel_to_stormpy_and_back_mdp():
# we test it for monty hall mdp
stormvogel_mdp = examples.monty_hall.create_monty_hall_mdp()

# we additionally test if reward models work
rewardmodel = stormvogel_mdp.add_rewards("rewardmodel")
for i in range(67):
rewardmodel.rewards[i] = i
rewardmodel2 = stormvogel_mdp.add_rewards("rewardmodel2")
for i in range(67):
rewardmodel2.rewards[i] = i

# print(stormvogel_mdp)
stormpy_mdp = stormvogel.mapping.stormvogel_to_stormpy(stormvogel_mdp)
# print(stormpy_mdp)
Expand Down
18 changes: 14 additions & 4 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@ def test_simulate():
dtmc = examples.die.create_die_dtmc()
rewardmodel = dtmc.add_rewards("rewardmodel")
for stateid in dtmc.states.keys():
rewardmodel.rewards[stateid] = 5
rewardmodel.rewards[stateid] = 3
rewardmodel2 = dtmc.add_rewards("rewardmodel2")
for stateid in dtmc.states.keys():
rewardmodel2.rewards[stateid] = 2
rewardmodel3 = dtmc.add_rewards("rewardmodel3")
for stateid in dtmc.states.keys():
rewardmodel3.rewards[stateid] = 1
partial_model = stormvogel.simulator.simulate(dtmc, runs=5, steps=1, seed=1)

# we make the partial model that should be created by the simulator
Expand All @@ -20,9 +26,13 @@ def test_simulate():

rewardmodel = other_dtmc.add_rewards("rewardmodel")
for stateid in other_dtmc.states.keys():
rewardmodel.rewards[stateid] = float(5)

# print(partial_model.rewards, other_dtmc.rewards)
rewardmodel.rewards[stateid] = float(3)
rewardmodel2 = other_dtmc.add_rewards("rewardmodel2")
for stateid in other_dtmc.states.keys():
rewardmodel2.rewards[stateid] = float(2)
rewardmodel3 = other_dtmc.add_rewards("rewardmodel3")
for stateid in other_dtmc.states.keys():
rewardmodel3.rewards[stateid] = float(1)

assert partial_model == other_dtmc

Expand Down

0 comments on commit 4dbe498

Please sign in to comment.