Skip to content

Commit

Permalink
normalize, delete_state, is_well_defined and get_outgoing_transitions…
Browse files Browse the repository at this point in the history
… now also work for mdps and pomdps
  • Loading branch information
PimLeerkes committed Sep 28, 2024
1 parent 7c418b8 commit 87e9044
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 17 deletions.
3 changes: 2 additions & 1 deletion examples/die.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def create_die_dtmc():
# we add self loops to all states with no outgoing transitions
dtmc.add_self_loops()

dtmc.delete_state(dtmc.get_state_by_id(1), True, True)
# test if state deletion works
# dtmc.delete_state(dtmc.get_state_by_id(1), True, True)

return dtmc

Expand Down
63 changes: 47 additions & 16 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,17 +141,21 @@ def available_actions(self) -> list["Action"]:
raise RuntimeError(
"The model this state belongs to does not support actions"
)

def get_outgoing_transitions(self) -> list[tuple[Number, "State"]]:
"""gets the outgoing transitions (only works if the model does not support actions)"""
if not self.model.supports_actions():
branch = list(self.model.transitions[self.id].transition.values())[0]
return branch.branch
return EmptyAction

def get_outgoing_transitions(
self, action: "Action | None" = None
) -> list[tuple[Number, "State"]]:
"""gets the outgoing transitions"""
if action and self.model.supports_actions():
branch = self.model.transitions[self.id].transition[action]
else:
raise RuntimeError(
"This method does not yet work for models that support actions"
)
return []
branch = self.model.transitions[self.id].transition[EmptyAction]
if action:
print(
"This model does not support actions, so you don't need to provide one"
)
return branch.branch

def __str__(self):
res = f"State {self.id} with labels {self.labels} and features {self.features}"
Expand Down Expand Up @@ -295,7 +299,7 @@ def transition_from_shorthand(shorthand: TransitionShorthand) -> Transition:
@dataclass
class RewardModel:
"""Represents a state-exit reward model.
dtmc.delete_state(dtmc.get_state_by_id(1), True, True)
Args:
name: Name of the reward model.
rewards: The rewards, the keys are the state's ids (or state action pair ids).
Expand Down Expand Up @@ -402,9 +406,16 @@ def is_well_defined(self) -> bool:
sum_prob += transition[0]
if sum_prob != 1:
return False
elif self.get_type() in (ModelType.POMDP, ModelType.MDP):
for state in self.states.values():
sum_prob = 0
for action in state.available_actions():
for transition in state.get_outgoing_transitions(action):
if isinstance(transition[0], float):
sum_prob += transition[0]
if sum_prob != 1:
return False
elif self.get_type() in (
ModelType.POMDP,
ModelType.MDP,
ModelType.CTMC,
ModelType.MA,
):
Expand Down Expand Up @@ -432,10 +443,30 @@ def normalize(self):
self.transitions[state.id].transition[
EmptyAction
].branch = new_transitions

elif self.get_type() in (ModelType.POMDP, ModelType.MDP):
for state in self.states.values():
sum_prob = 0
for action in state.available_actions():
for tuple in state.get_outgoing_transitions(action):
if isinstance(tuple[0], float) or isinstance(
tuple[0], Fraction
):
sum_prob += tuple[0]

new_transitions = []
for tuple in state.get_outgoing_transitions(action):
if isinstance(tuple[0], float) or isinstance(
tuple[0], Fraction
):
normalized_transition = (
tuple[0] / sum_prob,
tuple[1],
)
new_transitions.append(normalized_transition)
self.transitions[state.id].transition[
action
].branch = new_transitions
elif self.get_type() in (
ModelType.POMDP,
ModelType.MDP,
ModelType.CTMC,
ModelType.MA,
):
Expand Down

0 comments on commit 87e9044

Please sign in to comment.