Skip to content

Commit

Permalink
added state names
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Sep 25, 2024
1 parent 4dc5425 commit 35a4e85
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 8 deletions.
3 changes: 0 additions & 3 deletions examples/monty_hall_pomdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,6 @@ def create_monty_hall_pomdp():
for state in pomdp.states.values():
state.new_observation(0)

# delete a state:
# pomdp.delete_state(pomdp.get_state_by_id(1))

return pomdp


Expand Down
43 changes: 38 additions & 5 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,47 @@ class State:
features: The features of this state. Corresponds to Storm features.
id: The number of this state in the matrix.
model: The model this state belongs to.
observation: the observation of this state in case the model is a pomdp
observation: the observation of this state in case the model is a pomdp.
name: the name of this state.
"""

# name: str | None
labels: list[str]
features: dict[str, int]
id: int
model: "Model"
observation: Observation | None
name: str

def __init__(
self,
labels: list[str],
features: dict[str, int],
id: int,
model,
name: str | None = None,
):
self.model = model

if id in self.model.states.keys():
raise RuntimeError(
"There is already a state with this id. Make sure the id is unique."
)

names = [state.name for state in self.model.states.values()]
if name in names:
raise RuntimeError(
"There is already a state with this name. Make sure the name is unique."
)

def __init__(self, labels: list[str], features: dict[str, int], id: int, model):
self.labels = labels
self.features = features
self.id = id
self.model = model
self.observation = None

# TODO how to handle state names?
if name is None:
self.name = str(id)
else:
self.name = name

def add_label(self, label: str):
"""adds a new label to the state"""
Expand Down Expand Up @@ -517,6 +540,16 @@ def get_state_by_id(self, state_id) -> State:
raise RuntimeError("Requested a non-existing state")
return self.states[state_id]

def get_state_by_name(self, state_name) -> State | None:
"""Get a state by its name."""
names = [state.name for state in self.states.values()]
if state_name not in names:
raise RuntimeError("Requested a non-existing state")

for state in self.states.values():
if state.name == state_name:
return state

def get_initial_state(self) -> State:
"""Gets the initial state (id=0)."""
return self.states[0]
Expand Down

0 comments on commit 35a4e85

Please sign in to comment.