Skip to content

Commit

Permalink
small changes
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Oct 20, 2024
1 parent 7f2a0ba commit 257bb30
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
31 changes: 20 additions & 11 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ class Path:
Path object that represents a path created by a simulator on a certain model.
Args:
path: The path itself is a dictionary where we either store for each step a state or a state action pair.
model: The model the path traverses
path: The path itself is a dictionary where we either store for each step a state or a state action pair,
depending on if we are working with a dtmc or an mdp.
model: model that the path traverses through
"""

path: (
Expand All @@ -34,8 +35,12 @@ def __init__(
| dict[int, stormvogel.model.State],
model: stormvogel.model.Model,
):
self.path = path
self.model = model
if model.get_type() != stormvogel.model.ModelType.MA:
self.path = path
self.model = model
else:
# TODO make the simulators work for markov automata
raise NotImplementedError

def get_state_in_step(self, step: int) -> stormvogel.model.State | None:
"""returns the state discovered in the given step in the path"""
Expand Down Expand Up @@ -89,6 +94,12 @@ def __str__(self) -> str:
path += f" --> state: {state.id}"
return path

def __eq__(self, other):
if isinstance(other, Path):
return self.path == other.path and self.model == other.model
else:
return False


def simulate_path(
model: stormvogel.model.Model,
Expand All @@ -97,8 +108,7 @@ def simulate_path(
seed: int | None = None,
) -> Path:
"""
Simulates the model a given number of steps.
Returns the resulting path of the simulator.
Simulates the model a given number of steps and returns the path created by the process.
"""

def get_range_index(stateid: int):
Expand Down Expand Up @@ -130,8 +140,6 @@ def get_range_index(stateid: int):
break
else:
state = 0
if model.get_type() == stormvogel.model.ModelType.POMDP:
simulator.set_full_observability(True)
path = {}
simulator.restart()
for i in range(steps):
Expand Down Expand Up @@ -221,8 +229,8 @@ def get_range_index(stateid: int):
break
else:
state = 0

for i in range(runs):
# state, reward, labels = simulator.restart()
simulator.restart()
for j in range(steps):
# we first choose an action
Expand Down Expand Up @@ -280,7 +288,7 @@ def get_range_index(stateid: int):
path = simulate_path(dtmc, 5)
print(path)
"""

"""
# then we test it with an mdp
mdp = examples.monty_hall.create_monty_hall_mdp()
rewardmodel = mdp.add_rewards("rewardmodel")
Expand All @@ -301,8 +309,8 @@ def get_range_index(stateid: int):
assert partial_model is not None
print(path)
print(partial_model.rewards)

"""

# then we test it with a pomdp
pomdp = examples.monty_hall_pomdp.create_monty_hall_pomdp()

Expand All @@ -316,6 +324,7 @@ def get_range_index(stateid: int):
print(partial_model)
print(path)

"""
# then we test it with a ctmc
ctmc = examples.nuclear_fusion_ctmc.create_nuclear_fusion_ctmc()
partial_model = simulate(ctmc, 10, 10)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_simulate_path():
ctmc,
)

assert str(path) == str(other_path)
assert path == other_path

# we make the monty hall pomdp and run simulate path with it
pomdp = examples.monty_hall_pomdp.create_monty_hall_pomdp()
Expand All @@ -109,4 +109,4 @@ def test_simulate_path():
pomdp,
)

assert str(path) == str(other_path)
assert path == other_path

0 comments on commit 257bb30

Please sign in to comment.