diff --git a/examples/monty_hall_pomdp.py b/examples/monty_hall_pomdp.py index fdaca80..7f2c11b 100644 --- a/examples/monty_hall_pomdp.py +++ b/examples/monty_hall_pomdp.py @@ -71,9 +71,6 @@ def create_monty_hall_pomdp(): for state in pomdp.states.values(): state.new_observation(0) - # delete a state: - # pomdp.delete_state(pomdp.get_state_by_id(1)) - return pomdp diff --git a/stormvogel/model.py b/stormvogel/model.py index d790b49..063d6f9 100644 --- a/stormvogel/model.py +++ b/stormvogel/model.py @@ -53,24 +53,47 @@ class State: features: The features of this state. Corresponds to Storm features. id: The number of this state in the matrix. model: The model this state belongs to. - observation: the observation of this state in case the model is a pomdp + observation: the observation of this state in case the model is a pomdp. + name: the name of this state. """ - # name: str | None labels: list[str] features: dict[str, int] id: int model: "Model" observation: Observation | None + name: str + + def __init__( + self, + labels: list[str], + features: dict[str, int], + id: int, + model, + name: str | None = None, + ): + self.model = model + + if id in self.model.states.keys(): + raise RuntimeError( + "There is already a state with this id. Make sure the id is unique." + ) + + names = [state.name for state in self.model.states.values()] + if name in names: + raise RuntimeError( + "There is already a state with this name. Make sure the name is unique." + ) - def __init__(self, labels: list[str], features: dict[str, int], id: int, model): self.labels = labels self.features = features self.id = id - self.model = model self.observation = None - # TODO how to handle state names? + if name is None: + self.name = str(id) + else: + self.name = name def add_label(self, label: str): """adds a new label to the state""" @@ -517,6 +540,16 @@ def get_state_by_id(self, state_id) -> State: raise RuntimeError("Requested a non-existing state") return self.states[state_id] + def get_state_by_name(self, state_name) -> State | None: + """Get a state by its name.""" + names = [state.name for state in self.states.values()] + if state_name not in names: + raise RuntimeError("Requested a non-existing state") + + for state in self.states.values(): + if state.name == state_name: + return state + def get_initial_state(self) -> State: """Gets the initial state (id=0).""" return self.states[0]