diff --git a/examples/die.py b/examples/die.py index e51534e..1ba5262 100644 --- a/examples/die.py +++ b/examples/die.py @@ -16,7 +16,7 @@ def create_die_dtmc(): dtmc.add_self_loops() # test if state deletion works - # dtmc.delete_state(dtmc.get_state_by_id(1), True, True) + dtmc.delete_state(dtmc.get_state_by_id(1), True, True) return dtmc diff --git a/examples/monty_hall_pomdp.py b/examples/monty_hall_pomdp.py index 2a2be56..251ee77 100644 --- a/examples/monty_hall_pomdp.py +++ b/examples/monty_hall_pomdp.py @@ -69,7 +69,9 @@ def create_monty_hall_pomdp(): # we add the observations TODO: let it make sense for state in pomdp.states.values(): - state.new_observation(0) + state.set_observation(0) + + pomdp.normalize() return pomdp diff --git a/examples/simple_ma.py b/examples/simple_ma.py index ac6b72b..4034c62 100644 --- a/examples/simple_ma.py +++ b/examples/simple_ma.py @@ -30,7 +30,7 @@ def create_simple_ma(): ma.add_self_loops() # we delete a state - ma.delete_state(ma.get_state_by_id(3), True, True) + ma.delete_state(ma.get_state_by_id(3), False, True) return ma diff --git a/stormvogel/mapping.py b/stormvogel/mapping.py index e1e33b1..bc4e852 100644 --- a/stormvogel/mapping.py +++ b/stormvogel/mapping.py @@ -283,7 +283,7 @@ def map_ma(model: stormvogel.model.Model) -> stormpy.storage.SparseMA: return ma if model.all_states_outgoing_transition(): - # we make a mapping between stormvogel and stormpy ids in case they arent in order. + # we make a mapping between stormvogel and stormpy ids in case they are out of order. stormpy_id = {} for index, stormvogel_id in enumerate(model.states.keys()): stormpy_id[stormvogel_id] = index @@ -501,7 +501,7 @@ def map_pomdp(sparsepomdp: stormpy.storage.SparsePomdp) -> stormvogel.model.Mode # we add the observations: for state in model.states.values(): - state.new_observation(sparsepomdp.get_observation(state.id)) + state.set_observation(sparsepomdp.get_observation(state.id)) return model diff --git a/stormvogel/model.py b/stormvogel/model.py index 3920314..ce0dbd2 100644 --- a/stormvogel/model.py +++ b/stormvogel/model.py @@ -7,7 +7,7 @@ Parameter = str -Number = float | Parameter | Fraction +Number = float | Parameter | Fraction | int class ModelType(Enum): @@ -100,7 +100,7 @@ def add_label(self, label: str): if label not in self.labels: self.labels.append(label) - def new_observation(self, observation: int) -> Observation: + def set_observation(self, observation: int) -> Observation: """sets the observation for this state""" if self.model.get_type() == ModelType.POMDP: self.observation = Observation(observation) @@ -138,10 +138,7 @@ def available_actions(self) -> list["Action"]: action_list.append(action) return action_list else: - raise RuntimeError( - "The model this state belongs to does not support actions" - ) - return EmptyAction + return [EmptyAction] def get_outgoing_transitions( self, action: "Action | None" = None @@ -149,12 +146,11 @@ def get_outgoing_transitions( """gets the outgoing transitions""" if action and self.model.supports_actions(): branch = self.model.transitions[self.id].transition[action] + elif self.model.supports_actions() and not action: + raise RuntimeError("You need to provide a specific action") else: branch = self.model.transitions[self.id].transition[EmptyAction] - if action: - print( - "This model does not support actions, so you don't need to provide one" - ) + return branch.branch def __str__(self): @@ -398,20 +394,16 @@ def supports_observations(self): def is_well_defined(self) -> bool: """Checks if all sums of outgoing transition probabilities for all states equal 1""" - if self.get_type() == ModelType.DTMC: + if self.get_type() in (ModelType.DTMC, ModelType.MDP, ModelType.POMDP): for state in self.states.values(): - sum_prob = 0 - for transition in state.get_outgoing_transitions(): - if isinstance(transition[0], float): - sum_prob += transition[0] - if sum_prob != 1: - return False - elif self.get_type() in (ModelType.POMDP, ModelType.MDP): - for state in self.states.values(): - sum_prob = 0 for action in state.available_actions(): + sum_prob = 0 for transition in state.get_outgoing_transitions(action): - if isinstance(transition[0], float): + if ( + isinstance(transition[0], float) + or isinstance(transition[0], Fraction) + or isinstance(transition[0], int) + ): sum_prob += transition[0] if sum_prob != 1: return False @@ -419,34 +411,16 @@ def is_well_defined(self) -> bool: ModelType.CTMC, ModelType.MA, ): - print("Not implemented") + raise RuntimeError("Not implemented") return True def normalize(self): """Normalizes a model (for states where outgoing transition probabilities don't sum to 1, we divide each probability by the sum)""" - if self.get_type() == ModelType.DTMC: - for state in self.states.values(): - sum_prob = 0 - for tuple in state.get_outgoing_transitions(): - if isinstance(tuple[0], float) or isinstance(tuple[0], Fraction): - sum_prob += tuple[0] - - new_transitions = [] - for tuple in state.get_outgoing_transitions(): - if isinstance(tuple[0], float) or isinstance(tuple[0], Fraction): - normalized_transition = ( - tuple[0] / sum_prob, - tuple[1], - ) - new_transitions.append(normalized_transition) - self.transitions[state.id].transition[ - EmptyAction - ].branch = new_transitions - elif self.get_type() in (ModelType.POMDP, ModelType.MDP): + if self.get_type() in (ModelType.DTMC, ModelType.POMDP, ModelType.MDP): for state in self.states.values(): - sum_prob = 0 for action in state.available_actions(): + sum_prob = 0 for tuple in state.get_outgoing_transitions(action): if isinstance(tuple[0], float) or isinstance( tuple[0], Fraction @@ -470,7 +444,7 @@ def normalize(self): ModelType.CTMC, ModelType.MA, ): - print("Not implemented") + raise RuntimeError("Not implemented") def __free_state_id(self): """Gets a free id in the states dict."""