Skip to content

Commit

Permalink
refactoring + is_stochastic function for continous time models
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Nov 3, 2024
1 parent b612e31 commit e876e34
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 30 deletions.
53 changes: 28 additions & 25 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,8 @@ def __init__(
"There is already a state with this id. Make sure the id is unique."
)

names = [state.name for state in self.model.states.values()]
if name in names:
used_names = [state.name for state in self.model.states.values()]
if name in used_names:
raise RuntimeError(
"There is already a state with this name. Make sure the name is unique."
)
Expand Down Expand Up @@ -162,16 +162,16 @@ def __str__(self):
def __eq__(self, other):
if isinstance(other, State):
if self.id == other.id:
self.labels.sort()
other.labels.sort()
if self.model.supports_observations():
if self.observation is not None and other.observation is not None:
observations_equal = self.observation == other.observation
else:
observations_equal = True
else:
observations_equal = True
return self.labels == other.labels and observations_equal
return (
sorted(self.labels) == sorted(other.labels) and observations_equal
)
return False
return False

Expand Down Expand Up @@ -222,9 +222,7 @@ def __str__(self):

def __eq__(self, other):
if isinstance(other, Branch):
self.branch.sort()
other.branch.sort()
return self.branch == other.branch
return sorted(self.branch) == sorted(other.branch)
return False

def __add__(self, other):
Expand Down Expand Up @@ -261,11 +259,9 @@ def __str__(self):

def __eq__(self, other):
if isinstance(other, Transition):
self_values = list(self.transition.values())
other_values = list(other.transition.values())
self_values.sort()
other_values.sort()
return self_values == other_values
return sorted(list(self.transition.values())) == sorted(
list(other.transition.values())
)
return False

def has_empty_action(self) -> bool:
Expand Down Expand Up @@ -309,7 +305,6 @@ def transition_from_shorthand(shorthand: TransitionShorthand) -> Transition:
@dataclass(order=True)
class RewardModel:
"""Represents a state-exit reward model.
dtmc.delete_state(dtmc.get_state_by_id(1), True, True)
Args:
name: Name of the reward model.
rewards: The rewards, the keys are the state's ids (or state action pair ids).
Expand Down Expand Up @@ -340,10 +335,10 @@ class Model:
name: An optional name for this model.
type: The model type.
states: The states of the model. The keys are the state's ids.
transitions: The transitions of this model.
actions: The actions of the model, if this is a model that supports actions.
rewards: The rewardsmodels of this model.
exit_rates: The exit rates of the model, optional if this model supports rates.
transitions: The transitions of this model.
markovian_states: list of markovian states in the case of a ma.
"""

Expand Down Expand Up @@ -392,7 +387,7 @@ def __init__(
else:
self.markovian_states = None

# Add the initial state
# Add the initial state if specified to do so
if create_initial_state:
self.new_state(["init"])

Expand All @@ -408,10 +403,12 @@ def supports_observations(self):
"""Returns whether this model supports observations."""
return self.type == ModelType.POMDP

def is_well_defined(self) -> bool:
"""Checks if all sums of outgoing transition probabilities for all states equal 1"""
def is_stochastic(self) -> bool:
"""For discrete models: Checks if all sums of outgoing transition probabilities for all states equal 1
For continuous models: Checks if all sums of outgoing rates sum to 0
"""

if self.get_type() in (ModelType.DTMC, ModelType.MDP, ModelType.POMDP):
if not self.supports_rates():
for state in self.states.values():
for action in state.available_actions():
sum_prob = 0
Expand All @@ -424,12 +421,18 @@ def is_well_defined(self) -> bool:
sum_prob += transition[0]
if sum_prob != 1:
return False
elif self.get_type() in (
ModelType.CTMC,
ModelType.MA,
):
# TODO make it work for these models
raise RuntimeError("Not implemented")
else:
for state in self.states.values():
sum_rates = 0
for transition in state.get_outgoing_transitions():
if (
isinstance(transition[0], float)
or isinstance(transition[0], Fraction)
or isinstance(transition[0], int)
):
sum_rates += transition[0]
if sum_rates != 0:
return False

return True

Expand Down
10 changes: 5 additions & 5 deletions tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,26 @@ def test_transition_from_shorthand():
)


def test_is_well_defined():
# we check for an instance where it is not well defined
def test_is_stochastic():
# we check for an instance where it is not stochastic
dtmc = stormvogel.model.new_dtmc()
state = dtmc.new_state()
dtmc.set_transitions(
dtmc.get_initial_state(),
[(1 / 2, state)],
)

assert not dtmc.is_well_defined()
assert not dtmc.is_stochastic()

# we check for an instance where it is well defined
# we check for an instance where it is stochastic
dtmc.set_transitions(
dtmc.get_initial_state(),
[(1 / 2, state), (1 / 2, state)],
)

dtmc.add_self_loops()

assert dtmc.is_well_defined()
assert dtmc.is_stochastic()


def test_normalize():
Expand Down

0 comments on commit e876e34

Please sign in to comment.