Skip to content

Commit

Permalink
remove transitions and get outgoing transitions methods
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Sep 25, 2024
1 parent 35a4e85 commit 95f9c2a
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
6 changes: 3 additions & 3 deletions examples/monty_hall.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def create_monty_hall_mdp():
)

# we choose a door in each case
for s in mdp.get_states_with("carchosen"):
for s in mdp.get_states_with_label("carchosen"):
s.set_transitions(
[
(
Expand All @@ -24,7 +24,7 @@ def create_monty_hall_mdp():
)

# the other goat is revealed
for s in mdp.get_states_with("open"):
for s in mdp.get_states_with_label("open"):
car_pos = s.features["car_pos"]
chosen_pos = s.features["chosen_pos"]
other_pos = {0, 1, 2} - {car_pos, chosen_pos}
Expand All @@ -39,7 +39,7 @@ def create_monty_hall_mdp():
)

# we must choose whether we want to switch
for s in mdp.get_states_with("goatrevealed"):
for s in mdp.get_states_with_label("goatrevealed"):
car_pos = s.features["car_pos"]
chosen_pos = s.features["chosen_pos"]
reveal_pos = s.features["reveal_pos"]
Expand Down
6 changes: 3 additions & 3 deletions examples/monty_hall_pomdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def create_monty_hall_pomdp():
)

# we choose a door in each case
for s in pomdp.get_states_with("carchosen"):
for s in pomdp.get_states_with_label("carchosen"):
s.set_transitions(
[
(
Expand All @@ -25,7 +25,7 @@ def create_monty_hall_pomdp():
)

# the other goat is revealed
for s in pomdp.get_states_with("open"):
for s in pomdp.get_states_with_label("open"):
car_pos = s.features["car_pos"]
chosen_pos = s.features["chosen_pos"]
other_pos = {0, 1, 2} - {car_pos, chosen_pos}
Expand All @@ -40,7 +40,7 @@ def create_monty_hall_pomdp():
)

# we must choose whether we want to switch
for s in pomdp.get_states_with("goatrevealed"):
for s in pomdp.get_states_with_label("goatrevealed"):
car_pos = s.features["car_pos"]
chosen_pos = s.features["chosen_pos"]
reveal_pos = s.features["reveal_pos"]
Expand Down
33 changes: 29 additions & 4 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,16 @@ def available_actions(self) -> list["Action"]:
"The model this state belongs to does not support actions"
)

def get_outgoing_transitions(self) -> list[tuple[Number, "State"]] | None:
"""gets the outgoing transitions (only works if the model does not support actions)"""
if not self.model.supports_actions():
branch = list(self.model.transitions[self.id].transition.values())[0]
return branch.branch
else:
raise RuntimeError(
"This method does not yet work for models that support actions"
)

def __str__(self):
res = f"State {self.id} with labels {self.labels} and features {self.features}"
if self.model.supports_observations() and self.observation is not None:
Expand Down Expand Up @@ -445,9 +455,9 @@ def new_action(self, name: str, labels: frozenset[str] | None = None) -> Action:
return action

def delete_state(self, state: State):
"""properly deletes a state and reasigns ids"""
"""properly deletes a state and reassigns ids"""
if state in self.states.values():
# We remove the state and reasign ids
# We remove the state and reassign ids
self.states.pop(state.id)
self.states = {
new_id: value
Expand All @@ -457,7 +467,7 @@ def delete_state(self, state: State):
if other_state.id > state.id:
other_state.id -= 1

# we remove the state from the transitions and reasign the ids
# we remove the state from the transitions and reassign the ids
self.transitions.pop(state.id)
self.transitions = {
new_id: value
Expand All @@ -482,6 +492,21 @@ def delete_state(self, state: State):
if state in self.markovian_states:
self.markovian_states.remove(state)

def delete_transitions_between_states(self, state0: State, state1: State):
"""
Deletes the transition(s) present between the two given states.
Only works on models that don't support actions.
"""
if not self.supports_actions():
for branch in self.transitions[state0.id].transition.values():
for tuple in branch.branch:
if tuple[1] == state1:
branch.branch.remove(tuple)
else:
raise RuntimeError(
"This method only works for models that don't support actions."
)

def get_action(self, name: str) -> Action:
"""Gets an existing action."""
if not self.supports_actions():
Expand Down Expand Up @@ -525,7 +550,7 @@ def new_state(

return state

def get_states_with(self, label: str) -> list[State]:
def get_states_with_label(self, label: str) -> list[State]:
"""Get all states with a given label."""
# TODO: slow, not sure if that will become a problem though
collected_states = []
Expand Down

0 comments on commit 95f9c2a

Please sign in to comment.