Skip to content

Commit

Permalink
Transitions and paths get compared better
Browse files Browse the repository at this point in the history
  • Loading branch information
PimLeerkes committed Nov 14, 2024
1 parent f46c2c8 commit 10ee288
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 19 deletions.
50 changes: 35 additions & 15 deletions stormvogel/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ class Action:
def __str__(self):
return f"Action {self.name} with labels {self.labels}"

def __eq__(self, other):
if isinstance(other, Action):
return self.labels == other.labels
return False


# The empty action. Used for DTMCs and empty action transitions in mdps.
EmptyAction = Action("empty", frozenset())
Expand Down Expand Up @@ -270,17 +275,22 @@ def __str__(self):
parts.append(f"{action} => {branch}")
return "; ".join(parts + [])

def __eq__(self, other):
if isinstance(other, Transition):
return sorted(list(self.transition.values())) == sorted(
list(other.transition.values())
)
return False

def has_empty_action(self) -> bool:
# Note that we don't have to deal with the corner case where there are both empty and non-empty transitions. This is dealt with at __init__.
return self.transition.keys() == {EmptyAction}

def __eq__(self, other):
if isinstance(other, Transition):
if len(self.transition) != len(other.transition):
return False
for item, other_item in zip(
sorted(self.transition.items()), sorted(other.transition.items())
):
if not (item[0] == other_item[0] and item[1] == other_item[1]):
return False
return True
return False


TransitionShorthand = list[tuple[Number, State]] | list[tuple[Action, State]]

Expand Down Expand Up @@ -542,7 +552,7 @@ def set_transitions(self, s: State, transitions: Transition | TransitionShorthan
self.transitions[s.id] = transitions

def add_transitions(self, s: State, transitions: Transition | TransitionShorthand):
"""Add new transitions from a state. If no transition currently exists, the result will be the same as set_transitions."""
"""Add new transitions from a state to the model. If no transition currently exists, the result will be the same as set_transitions."""

if not isinstance(transitions, Transition):
transitions = transition_from_shorthand(transitions)
Expand Down Expand Up @@ -581,8 +591,11 @@ def add_transitions(self, s: State, transitions: Transition | TransitionShorthan
transitions.transition[EmptyAction]
)
else:
for choice, branch in transitions.transition.items():
self.transitions[s.id].transition[choice] = branch
for action, branch in transitions.transition.items():
assert self.actions is not None
if action not in self.actions.values():
self.actions[action.name] = action
self.transitions[s.id].transition[action] = branch

def get_transitions(self, state_or_id: State | int) -> Transition:
"""Get the transition at state s. Throws a KeyError if not present."""
Expand Down Expand Up @@ -610,10 +623,7 @@ def new_action(self, name: str, labels: frozenset[str] | None = None) -> Action:
raise RuntimeError(
f"Tried to add action {name} but that action already exists"
)
if labels:
action = Action(name, labels)
else:
action = Action(name, frozenset())
action = Action(name, labels if labels else frozenset())
self.actions[name] = action
return action

Expand Down Expand Up @@ -887,14 +897,24 @@ def __str__(self) -> str:

def __eq__(self, other) -> bool:
if isinstance(other, Model):
# TODO compare action dicts
# if self.supports_actions():
# actions_equal = sorted(self.actions.values()) == sorted(
# other.actions.values()
# )
# else:
# actions_equal = True

# if not actions_equal:
# print(self.actions,'\n', other.actions)
return (
self.type == other.type
and self.states == other.states
and self.transitions == other.transitions
and sorted(self.rewards) == sorted(other.rewards)
and self.exit_rates == other.exit_rates
and self.markovian_states == other.markovian_states
# TODO: and self.actions == other.actions
# and actions_equal
)
return False

Expand Down
16 changes: 15 additions & 1 deletion stormvogel/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,21 @@ def __str__(self) -> str:

def __eq__(self, other):
if isinstance(other, Path):
return self.path == other.path and self.model == other.model
if not self.model.supports_actions():
return self.path == other.path and self.model == other.model
else:
if len(self.path) != len(other.path):
return False
for tuple, other_tuple in zip(
sorted(self.path.values()), sorted(other.path.values())
):
assert not (
isinstance(tuple, stormvogel.model.State)
or isinstance(other_tuple, stormvogel.model.State)
)
if not (tuple[0] == other_tuple[0] and tuple[1] == other_tuple[1]):
return False
return self.model == other.model
else:
return False

Expand Down
6 changes: 3 additions & 3 deletions tests/test_simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def test_simulate_path():
other_path = stormvogel.simulator.Path(
{
1: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(3)),
2: (pomdp.actions["open2"], pomdp.get_state_by_id(12)),
3: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(23)),
4: (pomdp.actions["switch"], pomdp.get_state_by_id(46)),
2: (pomdp.actions["open0"], pomdp.get_state_by_id(10)),
3: (stormvogel.model.EmptyAction, pomdp.get_state_by_id(21)),
4: (pomdp.actions["stay"], pomdp.get_state_by_id(41)),
},
pomdp,
)
Expand Down

0 comments on commit 10ee288

Please sign in to comment.