diff --git a/stormvogel/model.py b/stormvogel/model.py index ad34686..ef9cd3e 100644 --- a/stormvogel/model.py +++ b/stormvogel/model.py @@ -227,8 +227,10 @@ def __eq__(self, other): return self.branch == other.branch return False + def __add__(self, other): + return Branch(self.branch + other.branch) + -@dataclass class Transition: """Represents a transition, which map actions to branches. Note that an EmptyAction may be used if we want a non-action transition. @@ -240,6 +242,14 @@ class Transition: transition: dict[Action, Branch] + def __init__(self, transition: dict[Action, Branch]): + # Input validation, see RuntimeError. + if len(transition) > 1 and EmptyAction in transition: + raise RuntimeError( + "It is impossible to create a transition that contains more than one action, and an emtpy action" + ) + self.transition = transition + def __str__(self): parts = [] for action, branch in self.transition.items(): @@ -258,6 +268,10 @@ def __eq__(self, other): return self_values == other_values return False + def has_empty_action(self) -> bool: + # Note that we don't have to deal with the corner case where there are both empty and non-empty transitions. This is dealt with at __init__. + return self.transition.keys() == {EmptyAction} + TransitionShorthand = list[tuple[Number, State]] | list[tuple[Action, State]] @@ -489,23 +503,56 @@ def set_transitions(self, s: State, transitions: Transition | TransitionShorthan self.transitions[s.id] = transitions def add_transitions(self, s: State, transitions: Transition | TransitionShorthand): - """Add new transitions from a state.""" + """Add new transitions from a state. If no transition currently exists, the result will be the same as set_transitions.""" if not self.supports_actions(): raise RuntimeError( - "In a model that does not support actions, you have to set transitions, not add them" + "Models without actions do not support add_transitions. Use set_transitions instead." ) if not isinstance(transitions, Transition): transitions = transition_from_shorthand(transitions) - for choice, branch in transitions.transition.items(): - self.transitions[s.id].transition[choice] = branch + try: + existing_transitions = self.get_transitions(s) + except KeyError: + # Empty transitions case, act like set_transitions. + self.set_transitions(s, transitions) + return + + # Adding a transition is only valid if they are both empty or both non-empty. + if ( + not transitions.has_empty_action() + and existing_transitions.has_empty_action() + ): + raise RuntimeError( + "You cannot add a transition with an non-empty action to a transition which has an empty action. Use set_transition instead." + ) + if ( + transitions.has_empty_action() + and not existing_transitions.has_empty_action() + ): + raise RuntimeError( + "You cannot add a transition with an empty action to a transition which has no empty action. Use set_transition instead." + ) + + # Empty action case, add the branches together. + if transitions.has_empty_action(): + self.transitions[s.id].transition[EmptyAction] += transitions.transition[ + EmptyAction + ] + else: + for choice, branch in transitions.transition.items(): + self.transitions[s.id].transition[choice] = branch - def get_transitions(self, s: State) -> Transition: - """Get the transition at state s.""" - return self.transitions[s.id] + def get_transitions(self, state_or_id: State | int) -> Transition: + """Get the transition at state s. Throws a KeyError if not present.""" + if isinstance(state_or_id, State): + return self.transitions[state_or_id.id] + else: + return self.transitions[state_or_id] - def get_branch(self, s: State) -> Branch: + def get_branch(self, state_or_id: State | int) -> Branch: """Get the branch at state s. Only intended for emtpy transitions, otherwise a RuntimeError is thrown.""" - transition = self.transitions[s.id].transition + s_id = state_or_id if isinstance(state_or_id, int) else state_or_id.id + transition = self.transitions[s_id].transition if EmptyAction not in transition: raise RuntimeError("Called get_branch on a non-empty transition.") return transition[EmptyAction] diff --git a/tests/test_model_creation.py b/tests/test_model_creation.py index 547e300..084e059 100644 --- a/tests/test_model_creation.py +++ b/tests/test_model_creation.py @@ -1,6 +1,5 @@ # content of test_sysexit.py import stormvogel.model -from stormvogel.model import EmptyAction def test_mdp_creation(): @@ -21,4 +20,4 @@ def test_mdp_creation(): # Check that all states 1..6 have self loops for i in range(1, 7): # yeah we need transition getting syntax - assert dtmc.transitions[i].transition[EmptyAction].branch[0][1].id == i + assert dtmc.get_branch(i).branch[0][1].id == i diff --git a/tests/test_model_methods.py b/tests/test_model_methods.py index bc3074a..2b3dee1 100644 --- a/tests/test_model_methods.py +++ b/tests/test_model_methods.py @@ -164,3 +164,84 @@ def test_remove_transitions_between_states(): new_dtmc.add_self_loops() assert dtmc == new_dtmc + + +def test_add_transitions(): + dtmc = stormvogel.model.new_dtmc() + state = dtmc.new_state() + # A non-action model should throw an exception. + with pytest.raises(RuntimeError) as excinfo: + dtmc.add_transitions( + dtmc.get_initial_state(), + [(0.5, state)], + ) + assert ( + str(excinfo.value) + == "Models without actions do not support add_transitions. Use set_transitions instead." + ) + + # Empty transition case, act exactly like set_transitions. + mdp = stormvogel.model.new_mdp() + state = mdp.new_state() + mdp.add_transitions( + mdp.get_initial_state(), + [(0.5, state)], + ) + mdp2 = stormvogel.model.new_mdp() + state2 = mdp2.new_state() + mdp2.set_transitions( + mdp2.get_initial_state(), + [(0.5, state2)], + ) + assert mdp == mdp2 + + # Fail to add a real action to an empty action. + mdp3 = stormvogel.model.new_mdp() + state3 = mdp2.new_state() + mdp3.set_transitions( + mdp3.get_initial_state(), + [(0.5, state3)], + ) + action3 = mdp3.new_action("action") + with pytest.raises(RuntimeError) as excinfo: + mdp3.add_transitions(mdp3.get_initial_state(), [(action3, state3)]) + assert ( + str(excinfo.value) + == "You cannot add a transition with an non-empty action to a transition which has an empty action. Use set_transition instead." + ) + # And the other way round. + mdp3 = stormvogel.model.new_mdp() + state3 = mdp2.new_state() + action3 = mdp3.new_action("action") + mdp3.set_transitions( + mdp3.get_initial_state(), + [(action3, state3)], + ) + + with pytest.raises(RuntimeError) as excinfo: + mdp3.add_transitions(mdp3.get_initial_state(), [(0.5, state3)]) + assert ( + str(excinfo.value) + == "You cannot add a transition with an empty action to a transition which has no empty action. Use set_transition instead." + ) + + # Empty action case, add the branches together. + mdp5 = stormvogel.model.new_mdp() + state5 = mdp5.new_state() + mdp5.set_transitions(mdp5.get_initial_state(), [((0.4), state5)]) + mdp5.add_transitions(mdp5.get_initial_state(), [(0.6, state5)]) + assert mdp5.get_branch(mdp5.get_initial_state()).branch == [ + ((0.4), state5), + (0.6, state5), + ] + + # Non-empty action case, add the actions to the list. + mdp6 = stormvogel.model.new_mdp() + state6 = mdp6.new_state() + action6a = mdp6.new_action("a") + action6b = mdp6.new_action("b") + mdp6.set_transitions(mdp6.get_initial_state(), [(action6a, state6)]) + mdp6.add_transitions(mdp6.get_initial_state(), [(action6b, state6)]) + print(mdp6.get_transitions(mdp6.get_initial_state()).transition) + print([(action6a, state6), (action6b, state6)]) + assert len(mdp6.get_transitions(mdp6.get_initial_state()).transition) == 2