Skip to content

Commit

Permalink
Merge pull request #133 from moves-rwth/125-change-remove_state-funci…
Browse files Browse the repository at this point in the history
…on-to-allow-for-reasigning-of-ids-without-normalizing

125 change remove state funcion to allow for reasigning of ids without normalizing
  • Loading branch information
PimLeerkes authored Nov 7, 2024
2 parents 10fda32 + 462bfdd commit 67f7ff0
Show file tree
Hide file tree
Showing 4 changed files with 178 additions and 95 deletions.
5 changes: 1 addition & 4 deletions examples/simple_ma.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def create_simple_ma():

init = ma.get_initial_state()

# We have 2 actions
# We have 5 actions
init.set_transitions(
[
(
Expand All @@ -29,9 +29,6 @@ def create_simple_ma():
# we add self loops to all states with no outgoing transitions
ma.add_self_loops()

# we delete a state
ma.remove_state(ma.get_state_by_id(3), True)

return ma


Expand Down
173 changes: 100 additions & 73 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def __init__(
"There is already a state with this id. Make sure the id is unique."
)

names = [state.name for state in self.model.states.values()]
if name in names:
used_names = [state.name for state in self.model.states.values()]
if name in used_names:
raise RuntimeError(
"There is already a state with this name. Make sure the name is unique."
)
Expand Down Expand Up @@ -142,16 +142,28 @@ def available_actions(self) -> list["Action"]:

def get_outgoing_transitions(
self, action: "Action | None" = None
) -> list[tuple[Number, "State"]]:
) -> list[tuple[Number, "State"]] | None:
"""gets the outgoing transitions"""
if action and self.model.supports_actions():
branch = self.model.transitions[self.id].transition[action]
if self.id in self.model.transitions.keys():
branch = self.model.transitions[self.id].transition[action]
return branch.branch
elif self.model.supports_actions() and not action:
raise RuntimeError("You need to provide a specific action")
else:
branch = self.model.transitions[self.id].transition[EmptyAction]

return branch.branch
if self.id in self.model.transitions.keys():
branch = self.model.transitions[self.id].transition[EmptyAction]
return branch.branch
return None

def is_absorbing(self, action: "Action | None" = None) -> bool:
"""returns if the state has a nonzero transition going to another state or not"""
transitions = self.get_outgoing_transitions(action)
if transitions is not None:
for transition in transitions:
if float(transition[0]) > 0 and transition[1] != self:
return False
return True

def __str__(self):
res = f"State {self.id} with labels {self.labels} and features {self.features}"
Expand All @@ -162,16 +174,16 @@ def __str__(self):
def __eq__(self, other):
if isinstance(other, State):
if self.id == other.id:
self.labels.sort()
other.labels.sort()
if self.model.supports_observations():
if self.observation is not None and other.observation is not None:
observations_equal = self.observation == other.observation
else:
observations_equal = True
else:
observations_equal = True
return self.labels == other.labels and observations_equal
return (
sorted(self.labels) == sorted(other.labels) and observations_equal
)
return False
return False

Expand Down Expand Up @@ -222,9 +234,7 @@ def __str__(self):

def __eq__(self, other):
if isinstance(other, Branch):
self.branch.sort()
other.branch.sort()
return self.branch == other.branch
return sorted(self.branch) == sorted(other.branch)
return False

def __add__(self, other):
Expand Down Expand Up @@ -261,11 +271,9 @@ def __str__(self):

def __eq__(self, other):
if isinstance(other, Transition):
self_values = list(self.transition.values())
other_values = list(other.transition.values())
self_values.sort()
other_values.sort()
return self_values == other_values
return sorted(list(self.transition.values())) == sorted(
list(other.transition.values())
)
return False

def has_empty_action(self) -> bool:
Expand Down Expand Up @@ -309,7 +317,6 @@ def transition_from_shorthand(shorthand: TransitionShorthand) -> Transition:
@dataclass(order=True)
class RewardModel:
"""Represents a state-exit reward model.
dtmc.delete_state(dtmc.get_state_by_id(1), True, True)
Args:
name: Name of the reward model.
rewards: The rewards, the keys are the state's ids (or state action pair ids).
Expand Down Expand Up @@ -340,10 +347,10 @@ class Model:
name: An optional name for this model.
type: The model type.
states: The states of the model. The keys are the state's ids.
transitions: The transitions of this model.
actions: The actions of the model, if this is a model that supports actions.
rewards: The rewardsmodels of this model.
exit_rates: The exit rates of the model, optional if this model supports rates.
transitions: The transitions of this model.
markovian_states: list of markovian states in the case of a ma.
"""

Expand Down Expand Up @@ -392,7 +399,7 @@ def __init__(
else:
self.markovian_states = None

# Add the initial state
# Add the initial state if specified to do so
if create_initial_state:
self.new_state(["init"])

Expand All @@ -408,14 +415,18 @@ def supports_observations(self):
"""Returns whether this model supports observations."""
return self.type == ModelType.POMDP

def is_well_defined(self) -> bool:
"""Checks if all sums of outgoing transition probabilities for all states equal 1"""
def is_stochastic(self) -> bool:
"""For discrete models: Checks if all sums of outgoing transition probabilities for all states equal 1
For continuous models: Checks if all sums of outgoing rates sum to 0
"""

if self.get_type() in (ModelType.DTMC, ModelType.MDP, ModelType.POMDP):
if not self.supports_rates():
for state in self.states.values():
for action in state.available_actions():
sum_prob = 0
for transition in state.get_outgoing_transitions(action):
transitions = state.get_outgoing_transitions(action)
assert transitions is not None
for transition in transitions:
if (
isinstance(transition[0], float)
or isinstance(transition[0], Fraction)
Expand All @@ -424,23 +435,34 @@ def is_well_defined(self) -> bool:
sum_prob += transition[0]
if sum_prob != 1:
return False
elif self.get_type() in (
ModelType.CTMC,
ModelType.MA,
):
# TODO make it work for these models
raise RuntimeError("Not implemented")
else:
for state in self.states.values():
for action in state.available_actions():
sum_rates = 0
transitions = state.get_outgoing_transitions(action)
assert transitions is not None
for transition in transitions:
if (
isinstance(transition[0], float)
or isinstance(transition[0], Fraction)
or isinstance(transition[0], int)
):
sum_rates += transition[0]
if sum_rates != 0:
return False

return True

def normalize(self):
"""Normalizes a model (for states where outgoing transition probabilities don't sum to 1, we divide each probability by the sum)"""
if self.get_type() in (ModelType.DTMC, ModelType.POMDP, ModelType.MDP):
if not self.supports_rates():
self.add_self_loops()
for state in self.states.values():
for action in state.available_actions():
sum_prob = 0
for tuple in state.get_outgoing_transitions(action):
transitions = state.get_outgoing_transitions(action)
assert transitions is not None
for tuple in transitions:
if (
isinstance(tuple[0], float)
or isinstance(tuple[0], Fraction)
Expand All @@ -449,7 +471,7 @@ def normalize(self):
sum_prob += tuple[0]

new_transitions = []
for tuple in state.get_outgoing_transitions(action):
for tuple in transitions:
if (
isinstance(tuple[0], float)
or isinstance(tuple[0], Fraction)
Expand All @@ -463,11 +485,8 @@ def normalize(self):
self.transitions[state.id].transition[
action
].branch = new_transitions
elif self.get_type() in (
ModelType.CTMC,
ModelType.MA,
):
# TODO: As of now, for the CTMCs and MAs we only add self loops
else:
# for ctmcs and mas we currently only add self loops
self.add_self_loops()

def __free_state_id(self):
Expand All @@ -482,15 +501,16 @@ def add_self_loops(self):
"""adds self loops to all states that do not have an outgoing transition"""
for id, state in self.states.items():
if self.transitions.get(id) is None:
self.set_transitions(state, [(float(1), state)])
self.set_transitions(
state, [(float(0) if self.supports_rates() else float(1), state)]
)

def all_states_outgoing_transition(self) -> bool:
"""checks if all states have an outgoing transition"""
all_states_outgoing_transition = True
for state in self.states.items():
if self.transitions.get(state[0]) is None:
all_states_outgoing_transition = False
return all_states_outgoing_transition
return False
return True

def add_markovian_state(self, markovian_state: State):
"""Adds a state to the markovian states."""
Expand Down Expand Up @@ -581,27 +601,51 @@ def new_action(self, name: str, labels: frozenset[str] | None = None) -> Action:
self.actions[name] = action
return action

def remove_state(self, state: State, normalize_and_reassign_ids: bool = True):
def reassign_ids(self):
"""reassigns the ids of states, transitions and rates to be in order again"""
self.states = {
new_id: value
for new_id, (old_id, value) in enumerate(sorted(self.states.items()))
}

self.transitions = {
new_id: value
for new_id, (old_id, value) in enumerate(sorted(self.transitions.items()))
}

if self.supports_rates and self.exit_rates is not None:
self.exit_rates = {
new_id: value
for new_id, (old_id, value) in enumerate(
sorted(self.exit_rates.items())
)
}

def remove_state(
self, state: State, normalize: bool = True, reassign_ids: bool = True
):
"""properly removes a state, it can optionally normalize the model and reassign ids automatically"""
if state in self.states.values():
# we remove the state from the transitions
# first we remove transitions that go into the state
remove_actions_index = []
for index, transition in self.transitions.items():
for action in transition.transition.items():
for index_tuple, tuple in enumerate(action[1].branch):
for action, branch in transition.transition.items():
for index_tuple, tuple in enumerate(branch.branch):
# remove the tuple if it goes to the state
if tuple[1].id == state.id:
self.transitions[index].transition[action[0]].branch.pop(
self.transitions[index].transition[action].branch.pop(
index_tuple
)

# if we have empty objects we need to remove those as well
if self.transitions[index].transition[action[0]].branch == []:
remove_actions_index.append((action[0], index))
# here we remove those empty objects
# if we have empty actions we need to remove those as well (later)
if branch.branch == []:
remove_actions_index.append((action, index))
# here we remove those empty actions (this needs to happen after the other for loops)
for action, index in remove_actions_index:
self.transitions[index].transition.pop(action)
if self.transitions[index].transition == {}:
# if we have no actions at all anymore, delete the transition
if self.transitions[index].transition == {} and not index == state.id:
self.transitions.pop(index)

# we remove transitions that come out of the state
Expand All @@ -620,33 +664,16 @@ def remove_state(self, state: State, normalize_and_reassign_ids: bool = True):
self.markovian_states.remove(state)

# we normalize the model if specified to do so
if normalize_and_reassign_ids:
if normalize:
self.normalize()

self.states = {
new_id: value
for new_id, (old_id, value) in enumerate(
sorted(self.states.items())
)
}
# we reassign the ids if specified to do so
if reassign_ids:
self.reassign_ids()
for other_state in self.states.values():
if other_state.id > state.id:
other_state.id -= 1

self.transitions = {
new_id: value
for new_id, (old_id, value) in enumerate(
sorted(self.transitions.items())
)
}
if self.supports_rates and self.exit_rates is not None:
self.exit_rates = {
new_id: value
for new_id, (old_id, value) in enumerate(
sorted(self.exit_rates.items())
)
}

def remove_transitions_between_states(
self, state0: State, state1: State, normalize: bool = True
):
Expand Down
27 changes: 15 additions & 12 deletions stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,18 @@ def get_range_index(stateid: int):
assert simulator is not None

# we start adding states or state action pairs to the path
state = 0
path = {}
simulator.restart()
if not model.supports_actions():
path = {}
simulator.restart()
for i in range(steps):
# for each step we add a state to the path
state, reward, labels = simulator.step()
path[i + 1] = model.states[state]
if simulator.is_done():
if not model.states[state].is_absorbing() and not simulator.is_done():
state, reward, labels = simulator.step()
path[i + 1] = model.states[state]
else:
break
else:
state = 0
path = {}
simulator.restart()
for i in range(steps):
# we first choose an action (randomly or according to scheduler)
actions = simulator.available_actions()
Expand All @@ -155,10 +154,14 @@ def get_range_index(stateid: int):

# we add the state action pair to the path
stormvogel_action = model.states[state].available_actions()[select_action]
next_step = simulator.step(actions[select_action])
state, reward, labels = next_step
path[i + 1] = (stormvogel_action, model.states[state])
if simulator.is_done():

if (
not model.states[state].is_absorbing(stormvogel_action)
and not simulator.is_done()
):
state, reward, labels = simulator.step(actions[select_action])
path[i + 1] = (stormvogel_action, model.states[state])
else:
break

path_object = Path(path, model)
Expand Down
Loading

0 comments on commit 67f7ff0

Please sign in to comment.