diff --git a/stormvogel/simulator.py b/stormvogel/simulator.py index 0d10b61..b366a81 100644 --- a/stormvogel/simulator.py +++ b/stormvogel/simulator.py @@ -18,8 +18,9 @@ class Path: Path object that represents a path created by a simulator on a certain model. Args: - path: The path itself is a dictionary where we either store for each step a state or a state action pair. - model: The model the path traverses + path: The path itself is a dictionary where we either store for each step a state or a state action pair, + depending on if we are working with a dtmc or an mdp. + model: model that the path traverses through """ path: ( @@ -34,8 +35,12 @@ def __init__( | dict[int, stormvogel.model.State], model: stormvogel.model.Model, ): - self.path = path - self.model = model + if model.get_type() != stormvogel.model.ModelType.MA: + self.path = path + self.model = model + else: + # TODO make the simulators work for markov automata + raise NotImplementedError def get_state_in_step(self, step: int) -> stormvogel.model.State | None: """returns the state discovered in the given step in the path""" @@ -89,6 +94,12 @@ def __str__(self) -> str: path += f" --> state: {state.id}" return path + def __eq__(self, other): + if isinstance(other, Path): + return self.path == other.path and self.model == other.model + else: + return False + def simulate_path( model: stormvogel.model.Model, @@ -97,8 +108,7 @@ def simulate_path( seed: int | None = None, ) -> Path: """ - Simulates the model a given number of steps. - Returns the resulting path of the simulator. + Simulates the model a given number of steps and returns the path created by the process. """ def get_range_index(stateid: int): @@ -130,8 +140,6 @@ def get_range_index(stateid: int): break else: state = 0 - if model.get_type() == stormvogel.model.ModelType.POMDP: - simulator.set_full_observability(True) path = {} simulator.restart() for i in range(steps): @@ -221,8 +229,8 @@ def get_range_index(stateid: int): break else: state = 0 + for i in range(runs): - # state, reward, labels = simulator.restart() simulator.restart() for j in range(steps): # we first choose an action @@ -280,7 +288,7 @@ def get_range_index(stateid: int): path = simulate_path(dtmc, 5) print(path) """ - + """ # then we test it with an mdp mdp = examples.monty_hall.create_monty_hall_mdp() rewardmodel = mdp.add_rewards("rewardmodel") @@ -301,8 +309,8 @@ def get_range_index(stateid: int): assert partial_model is not None print(path) print(partial_model.rewards) - """ + # then we test it with a pomdp pomdp = examples.monty_hall_pomdp.create_monty_hall_pomdp() @@ -316,6 +324,7 @@ def get_range_index(stateid: int): print(partial_model) print(path) + """ # then we test it with a ctmc ctmc = examples.nuclear_fusion_ctmc.create_nuclear_fusion_ctmc() partial_model = simulate(ctmc, 10, 10) diff --git a/tests/test_simulator.py b/tests/test_simulator.py index ebf78bd..b839d97 100644 --- a/tests/test_simulator.py +++ b/tests/test_simulator.py @@ -84,7 +84,7 @@ def test_simulate_path(): ctmc, ) - assert str(path) == str(other_path) + assert path == other_path # we make the monty hall pomdp and run simulate path with it pomdp = examples.monty_hall_pomdp.create_monty_hall_pomdp() @@ -109,4 +109,4 @@ def test_simulate_path(): pomdp, ) - assert str(path) == str(other_path) + assert path == other_path