Skip to content

Commit

Permalink
refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Sep 30, 2024
1 parent 87e9044 commit 9df560f
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 48 deletions.
2 changes: 1 addition & 1 deletion examples/die.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def create_die_dtmc():
dtmc.add_self_loops()

# test if state deletion works
# dtmc.delete_state(dtmc.get_state_by_id(1), True, True)
dtmc.delete_state(dtmc.get_state_by_id(1), True, True)

return dtmc

Expand Down
4 changes: 3 additions & 1 deletion examples/monty_hall_pomdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def create_monty_hall_pomdp():

# we add the observations TODO: let it make sense
for state in pomdp.states.values():
state.new_observation(0)
state.set_observation(0)

pomdp.normalize()

return pomdp

Expand Down
2 changes: 1 addition & 1 deletion examples/simple_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def create_simple_ma():
ma.add_self_loops()

# we delete a state
ma.delete_state(ma.get_state_by_id(3), True, True)
ma.delete_state(ma.get_state_by_id(3), False, True)

return ma

Expand Down
4 changes: 2 additions & 2 deletions stormvogel/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def map_ma(model: stormvogel.model.Model) -> stormpy.storage.SparseMA:
return ma

if model.all_states_outgoing_transition():
# we make a mapping between stormvogel and stormpy ids in case they arent in order.
# we make a mapping between stormvogel and stormpy ids in case they are out of order.
stormpy_id = {}
for index, stormvogel_id in enumerate(model.states.keys()):
stormpy_id[stormvogel_id] = index
Expand Down Expand Up @@ -501,7 +501,7 @@ def map_pomdp(sparsepomdp: stormpy.storage.SparsePomdp) -> stormvogel.model.Mode

# we add the observations:
for state in model.states.values():
state.new_observation(sparsepomdp.get_observation(state.id))
state.set_observation(sparsepomdp.get_observation(state.id))

return model

Expand Down
60 changes: 17 additions & 43 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

Parameter = str

Number = float | Parameter | Fraction
Number = float | Parameter | Fraction | int


class ModelType(Enum):
Expand Down Expand Up @@ -100,7 +100,7 @@ def add_label(self, label: str):
if label not in self.labels:
self.labels.append(label)

def new_observation(self, observation: int) -> Observation:
def set_observation(self, observation: int) -> Observation:
"""sets the observation for this state"""
if self.model.get_type() == ModelType.POMDP:
self.observation = Observation(observation)
Expand Down Expand Up @@ -138,23 +138,19 @@ def available_actions(self) -> list["Action"]:
action_list.append(action)
return action_list
else:
raise RuntimeError(
"The model this state belongs to does not support actions"
)
return EmptyAction
return [EmptyAction]

def get_outgoing_transitions(
self, action: "Action | None" = None
) -> list[tuple[Number, "State"]]:
"""gets the outgoing transitions"""
if action and self.model.supports_actions():
branch = self.model.transitions[self.id].transition[action]
elif self.model.supports_actions() and not action:
raise RuntimeError("You need to provide a specific action")
else:
branch = self.model.transitions[self.id].transition[EmptyAction]
if action:
print(
"This model does not support actions, so you don't need to provide one"
)

return branch.branch

def __str__(self):
Expand Down Expand Up @@ -398,55 +394,33 @@ def supports_observations(self):
def is_well_defined(self) -> bool:
"""Checks if all sums of outgoing transition probabilities for all states equal 1"""

if self.get_type() == ModelType.DTMC:
if self.get_type() in (ModelType.DTMC, ModelType.MDP, ModelType.POMDP):
for state in self.states.values():
sum_prob = 0
for transition in state.get_outgoing_transitions():
if isinstance(transition[0], float):
sum_prob += transition[0]
if sum_prob != 1:
return False
elif self.get_type() in (ModelType.POMDP, ModelType.MDP):
for state in self.states.values():
sum_prob = 0
for action in state.available_actions():
sum_prob = 0
for transition in state.get_outgoing_transitions(action):
if isinstance(transition[0], float):
if (
isinstance(transition[0], float)
or isinstance(transition[0], Fraction)
or isinstance(transition[0], int)
):
sum_prob += transition[0]
if sum_prob != 1:
return False
elif self.get_type() in (
ModelType.CTMC,
ModelType.MA,
):
print("Not implemented")
raise RuntimeError("Not implemented")

return True

def normalize(self):
"""Normalizes a model (for states where outgoing transition probabilities don't sum to 1, we divide each probability by the sum)"""
if self.get_type() == ModelType.DTMC:
for state in self.states.values():
sum_prob = 0
for tuple in state.get_outgoing_transitions():
if isinstance(tuple[0], float) or isinstance(tuple[0], Fraction):
sum_prob += tuple[0]

new_transitions = []
for tuple in state.get_outgoing_transitions():
if isinstance(tuple[0], float) or isinstance(tuple[0], Fraction):
normalized_transition = (
tuple[0] / sum_prob,
tuple[1],
)
new_transitions.append(normalized_transition)
self.transitions[state.id].transition[
EmptyAction
].branch = new_transitions
elif self.get_type() in (ModelType.POMDP, ModelType.MDP):
if self.get_type() in (ModelType.DTMC, ModelType.POMDP, ModelType.MDP):
for state in self.states.values():
sum_prob = 0
for action in state.available_actions():
sum_prob = 0
for tuple in state.get_outgoing_transitions(action):
if isinstance(tuple[0], float) or isinstance(
tuple[0], Fraction
Expand All @@ -470,7 +444,7 @@ def normalize(self):
ModelType.CTMC,
ModelType.MA,
):
print("Not implemented")
raise RuntimeError("Not implemented")

def __free_state_id(self):
"""Gets a free id in the states dict."""
Expand Down

0 comments on commit 9df560f

Please sign in to comment.