diff --git a/beams/behavior_tree/ActionNode.py b/beams/behavior_tree/ActionNode.py index 647d3d0..a33efd8 100644 --- a/beams/behavior_tree/ActionNode.py +++ b/beams/behavior_tree/ActionNode.py @@ -37,6 +37,7 @@ def __init__( comp_cond=completion_condition, stop_func=None ) # TODO: some standard notion of stop function could be valuable + self.is_set_up = False logger.debug("%s.__init__()" % (self.__class__.__name__)) def setup(self, **kwargs: int) -> None: @@ -51,8 +52,14 @@ def setup(self, **kwargs: int) -> None: # Having this in setup means the workthread should always be running. self.worker.start_work() atexit.register( - self.worker.stop_work + self.shutdown ) # TODO(josh): make sure this cleans up resources when it dies + self.is_set_up = True + + def shutdown(self) -> None: + if self.is_set_up: + self.worker.stop_work() + self.is_set_up = False def initialise(self) -> None: """ diff --git a/beams/tests/conftest.py b/beams/tests/conftest.py index 6d1d224..9743750 100644 --- a/beams/tests/conftest.py +++ b/beams/tests/conftest.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os import sys @@ -6,6 +8,8 @@ import py_trees.logging import pytest +from py_trees.behaviour import Behaviour +from py_trees.trees import BehaviourTree from beams.logging import setup_logging @@ -31,6 +35,41 @@ def ca_env_vars(): os.environ["EPICS_CA_ADDR_LIST"] = "localhost" +class BTCleaner: + """ + Helper to call shutdown early to avoid pytest atexit spam + """ + nodes: list[Behaviour] + trees: list[BehaviourTree] + + def __init__(self): + self.nodes = [] + self.trees = [] + + def register(self, node_or_tree: Behaviour | BehaviourTree): + if isinstance(node_or_tree, Behaviour): + self.nodes.append(node_or_tree) + elif isinstance(node_or_tree, BehaviourTree): + self.trees.append(node_or_tree) + else: + raise TypeError("Can only register Behavior and BehaviorTree instances!") + + def clean(self): + for node in self.nodes: + node.shutdown() + for child_node in node.children: + child_node.shutdown() + for tree in self.trees: + tree.shutdown() + + +@pytest.fixture(scope="function") +def bt_cleaner(): + cleaner = BTCleaner() + yield cleaner + cleaner.clean() + + @contextmanager def cli_args(args): """ diff --git a/beams/tests/test_check_and_do.py b/beams/tests/test_check_and_do.py index 4a82aac..70d5d1a 100644 --- a/beams/tests/test_check_and_do.py +++ b/beams/tests/test_check_and_do.py @@ -12,7 +12,7 @@ logger = logging.getLogger(__name__) -def test_check_and_do(): +def test_check_and_do(bt_cleaner): percentage_complete = Value("i", 0) @wrapped_action_work(loop_period_sec=0.001) @@ -38,6 +38,7 @@ def check_fn(x: Value): check = ConditionNode("check", check_fn, percentage_complete) candd = CheckAndDo("yuhh", check, action) + bt_cleaner.register(candd) candd.setup_with_descendants() while ( diff --git a/beams/tests/test_leaf_node.py b/beams/tests/test_leaf_node.py index a834194..ca1ac58 100644 --- a/beams/tests/test_leaf_node.py +++ b/beams/tests/test_leaf_node.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -def test_action_node(): +def test_action_node(bt_cleaner): # For test percentage_complete = Value("i", 0) @@ -28,6 +28,7 @@ def comp_cond(): action = ActionNode(name="action", work_func=work_func, completion_condition=comp_cond) + bt_cleaner.register(action) action.setup() for _ in range(20): time.sleep(0.01) @@ -35,7 +36,7 @@ def comp_cond(): assert percentage_complete.value == 100 -def test_action_node_timeout(): +def test_action_node_timeout(bt_cleaner): # For test percentage_complete = Value("i", 0) @@ -52,6 +53,7 @@ def comp_cond(): action = ActionNode(name="action", work_func=work_func, completion_condition=comp_cond) + bt_cleaner.register(action) action.setup() while action.status not in ( @@ -64,11 +66,12 @@ def comp_cond(): assert percentage_complete.value != 100 -def test_condition_node(): +def test_condition_node(bt_cleaner): def condition_fn(): return True con = ConditionNode("con", condition_fn) + bt_cleaner.register(con) con.setup() assert con.status == Status.INVALID for _ in range(3): @@ -78,12 +81,13 @@ def condition_fn(): assert con.status == Status.SUCCESS -def test_condition_node_with_arg(): +def test_condition_node_with_arg(bt_cleaner): def check(val): return val value = False con = ConditionNode("con", check, value) + bt_cleaner.register(con) con.setup() assert con.status == Status.INVALID for _ in range(3): diff --git a/beams/tests/test_tree_generator.py b/beams/tests/test_tree_generator.py index fd8dec7..3d836ea 100644 --- a/beams/tests/test_tree_generator.py +++ b/beams/tests/test_tree_generator.py @@ -19,9 +19,10 @@ def test_tree_obj_ser(): assert isinstance(tg.root, CheckAndDo) -def test_tree_obj_execution(request): +def test_tree_obj_execution(request, bt_cleaner): fname = Path(__file__).parent / "artifacts" / "eggs.json" tree = get_tree_from_path(fname) + bt_cleaner.register(tree) # start mock IOC # NOTE: assumes test is being run from top level of run_example_ioc( @@ -43,7 +44,7 @@ def test_tree_obj_execution(request): assert rel_val >= 100 -def test_father_tree_execution(request): +def test_father_tree_execution(request, bt_cleaner): run_example_ioc( "beams.tests.mock_iocs.ImagerNaysh", request=request, @@ -52,6 +53,7 @@ def test_father_tree_execution(request): fname = Path(__file__).parent / "artifacts" / "eggs2.json" tree = get_tree_from_path(fname) + bt_cleaner.register(tree) tree.setup() ct = 0 while ( @@ -79,7 +81,7 @@ def test_save_tree_item_round_trip(tmp_path: Path): assert loaded_tree.root.name == item.name -def test_stop_hitting_yourself(request): +def test_stop_hitting_yourself(request, bt_cleaner): run_example_ioc( "beams.tests.mock_iocs.IM2L0", request=request, @@ -88,6 +90,7 @@ def test_stop_hitting_yourself(request): fname = Path(__file__).parent / "artifacts" / "im2l0_test.json" tree = get_tree_from_path(fname) + bt_cleaner.register(tree) tree.setup() ct = 0 while (