diff --git a/beams/serialization.py b/beams/serialization.py new file mode 100644 index 0000000..69c96fd --- /dev/null +++ b/beams/serialization.py @@ -0,0 +1,138 @@ +""" +Serialization helpers for apischema. +""" +# Largely based on issue discussions regarding tagged unions. + +from collections import defaultdict +from collections.abc import Callable, Iterator +from types import new_class +from typing import (Any, Dict, Generic, List, Tuple, TypeVar, get_origin, + get_type_hints) + +from apischema import deserializer, serializer, type_name +from apischema.conversions import Conversion +from apischema.metadata import conversion +from apischema.objects import object_deserialization +from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged +from apischema.utils import to_pascal_case + +_alternative_constructors: Dict[type, List[Callable]] = defaultdict(list) +Func = TypeVar("Func", bound=Callable) + + +def alternative_constructor(func: Func) -> Func: + """Alternative constructor for a given type.""" + return_type = get_type_hints(func)["return"] + _alternative_constructors[get_origin(return_type) or return_type].append(func) + return func + + +def get_all_subclasses(cls: type) -> Iterator[type]: + """Recursive implementation of type.__subclasses__""" + for sub_cls in cls.__subclasses__(): + yield sub_cls + yield from get_all_subclasses(sub_cls) + + +Cls = TypeVar("Cls", bound=type) + + +def _get_generic_name_factory(cls: type, *args: type): + def _capitalized(name: str) -> str: + return name[0].upper() + name[1:] + + return "".join((cls.__name__, *(_capitalized(arg.__name__) for arg in args))) + + +generic_name = type_name(_get_generic_name_factory) + + +def as_tagged_union(cls: Cls) -> Cls: + """ + Tagged union decorator, to be used on base class. + + Supports generics as well, with names generated by way of + `_get_generic_name_factory`. + """ + params = tuple(getattr(cls, "__parameters__", ())) + tagged_union_bases: Tuple[type, ...] = (TaggedUnion,) + + # Generic handling is here: + if params: + tagged_union_bases = (TaggedUnion, Generic[params]) + generic_name(cls) + prev_init_subclass = getattr(cls, "__init_subclass__", None) + + def __init_subclass__(cls, **kwargs): + if prev_init_subclass is not None: + prev_init_subclass(**kwargs) + generic_name(cls) + + cls.__init_subclass__ = classmethod(__init_subclass__) + + def with_params(cls: type) -> Any: + """Specify type of Generic if set.""" + return cls[params] if params else cls + + def serialization() -> Conversion: + """ + Define the serializer Conversion for the tagged union. + + source is the base ``cls`` (or ``cls[T]``). + target is the new tagged union class ``TaggedUnion`` which gets the + dictionary {cls.__name__: obj} as its arguments. + """ + annotations = { + # Assume that subclasses have same generic parameters than cls + sub.__name__: Tagged[with_params(sub)] + for sub in get_all_subclasses(cls) + } + namespace = {"__annotations__": annotations} + tagged_union = new_class( + cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) + ) + return Conversion( + lambda obj: tagged_union(**{obj.__class__.__name__: obj}), + source=with_params(cls), + target=with_params(tagged_union), + # Conversion must not be inherited because it would lead to + # infinite recursion otherwise + inherited=False, + ) + + def deserialization() -> Conversion: + """ + Define the deserializer Conversion for the tagged union. + + Allows for alternative standalone constructors as per the apischema + example. 
+ """ + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {"__annotations__": annotations} + for sub in get_all_subclasses(cls): + annotations[sub.__name__] = Tagged[with_params(sub)] + for constructor in _alternative_constructors.get(sub, ()): + # Build the alias of the field + alias = to_pascal_case(constructor.__name__) + # object_deserialization uses get_type_hints, but the constructor + # return type is stringified and the class not defined yet, + # so it must be assigned manually + constructor.__annotations__["return"] = with_params(sub) + # Use object_deserialization to wrap constructor as deserializer + deserialization = object_deserialization(constructor, generic_name) + # Add constructor tagged field with its conversion + annotations[alias] = Tagged[with_params(sub)] + namespace[alias] = Tagged(conversion(deserialization=deserialization)) + # Create the deserialization tagged union class + tagged_union = new_class( + cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) + ) + return Conversion( + lambda obj: get_tagged(obj)[1], + source=with_params(tagged_union), + target=with_params(cls), + ) + + deserializer(lazy=deserialization, target=cls) + serializer(lazy=serialization, source=cls) + return cls diff --git a/beams/tests/artifacts/eggs.json b/beams/tests/artifacts/eggs.json index edda130..573d1c8 100644 --- a/beams/tests/artifacts/eggs.json +++ b/beams/tests/artifacts/eggs.json @@ -1,13 +1,31 @@ { - "name": "self_test", - "check_and_do_type": "CheckPV", - "check_entry": { - "Pv": "PERC:COMP", - "Thresh": 100 - }, - "do_entry": { - "Pv": "PERC:COMP", - "Mode": "INC", - "Value": 10 + "root": { + "CheckAndDoItem": { + "name": "self_test", + "description": "", + "check": { + "name": "", + "description": "", + "pv": "PERC:COMP", + "value": 100, + "operator": "ge" + }, + "do": { + "IncPVActionItem": { + "name": "", + "description": "", + "loop_period_sec": 1.0, + "pv": "PERC:COMP", + "increment": 10, + "termination_check": { + "name": "", + "description": "", + "pv": "PERC:COMP", + "value": 100, + "operator": "ge" + } + } + } + } } -} \ No newline at end of file +} diff --git a/beams/tests/artifacts/eggs2.json b/beams/tests/artifacts/eggs2.json index 3e41c6d..1ec0fd7 100644 --- a/beams/tests/artifacts/eggs2.json +++ b/beams/tests/artifacts/eggs2.json @@ -1,31 +1,69 @@ { - "name": "fake_reticle", - "children": [ - { - "name": "ret_find", - "check_and_do_type": "CheckPV", - "check_entry": { - "Pv": "RET:FOUND", - "Thresh": 1 - }, - "do_entry": { - "Pv": "RET:FOUND", - "Mode": "SET", - "Value": 1 - } - }, - { - "name": "ret_insert", - "check_and_do_type": "CheckPV", - "check_entry": { - "Pv": "RET:INSERT", - "Thresh": 1 - }, - "do_entry": { - "Pv": "RET:INSERT", - "Mode": "SET", - "Value": 1 - } + "root": { + "SequenceItem": { + "name": "fake_reticle", + "description": "", + "memory": false, + "children": [ + { + "CheckAndDoItem": { + "name": "ret_find", + "description": "", + "check": { + "name": "", + "description": "", + "pv": "RET:FOUND", + "value": 1, + "operator": "ge" + }, + "do": { + "SetPVActionItem": { + "name": "", + "description": "", + "loop_period_sec": 1.0, + "pv": "RET:FOUND", + "value": 1, + "termination_check": { + "name": "", + "description": "", + "pv": "RET:FOUND", + "value": 1, + "operator": "ge" + } + } + } + } + }, + { + "CheckAndDoItem": { + "name": "ret_insert", + "description": "", + "check": { + "name": "", + "description": "", + "pv": "RET:INSERT", + "value": 1, + "operator": "ge" + }, + "do": { + "SetPVActionItem": { + 
"name": "", + "description": "", + "loop_period_sec": 1.0, + "pv": "RET:INSERT", + "value": 1, + "termination_check": { + "name": "", + "description": "", + "pv": "RET:INSERT", + "value": 1, + "operator": "ge" + } + } + } + } + } + ] } - ] -} \ No newline at end of file + } +} diff --git a/beams/tests/test_serialize.py b/beams/tests/test_serialize.py index cd519b0..17270dd 100644 --- a/beams/tests/test_serialize.py +++ b/beams/tests/test_serialize.py @@ -1,57 +1,43 @@ -import json - from apischema import deserialize, serialize -from beams.tree_generator.TreeSerializer import (CheckAndDoNodeEntry, - CheckAndDoNodeTypeMode, - CheckEntry, DoEntry, TreeSpec) +from beams.tree_config import (BehaviorTreeItem, CheckAndDoItem, ConditionItem, + ConditionOperator, IncPVActionItem, + SequenceItem, SetPVActionItem) -class TestSerializer: - def test_serialize_basic(self): +def test_serialize_check_and_do(): # c_obj = load_config("config.json") - ce = CheckEntry(Pv="PERC:COMP", Thresh=100) - de = DoEntry(Pv="PERC:COMP", Mode=CheckAndDoNodeTypeMode.INC, Value=10) - eg = CheckAndDoNodeEntry(name="self_test", check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, check_entry=ce, do_entry=de) - - ser = serialize(CheckAndDoNodeEntry, eg) + cond_item = ConditionItem(pv="PERC:COMP", value=100, + operator=ConditionOperator.greater_equal) + action_item = IncPVActionItem(pv="PERC:COMP", increment=10, + termination_check=cond_item) + cnd_item = CheckAndDoItem(name="self_test", check=cond_item, do=action_item) - fname = "beams/tests/artifacts/eggs.json" + tree_item = BehaviorTreeItem(root=cnd_item) + ser = serialize(BehaviorTreeItem, tree_item) + deser = deserialize(BehaviorTreeItem, ser) - with open(fname, 'w') as fd: - json.dump(ser, fd, indent=2) + assert deser == tree_item - with open(fname, 'r') as fd: - deser = json.load(fd) - eg2 = deserialize(CheckAndDoNodeEntry, deser) - assert eg2 == eg - - def test_serialize_youre_a_father_now(self): +def test_serialize_youre_a_father_now(): """ Build children check and dos """ # insert reticule if ret is not found - ce1 = CheckEntry(Pv="RET:FOUND", Thresh=1) # TODO: should make a check / set mode - de1 = DoEntry(Pv="RET:FOUND", Mode=CheckAndDoNodeTypeMode.SET, Value=1) - # de = DoEntry(Pv="RET:SET", Mode=CheckAndDoNodeTypeMode.SET, Value=1) # TODO: once we have better feel of caproto plumb this up in mock - eg1 = CheckAndDoNodeEntry(name="ret_find", check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, check_entry=ce1, do_entry=de1) + ce1 = ConditionItem(pv="RET:FOUND", value=1, operator=ConditionOperator.equal) + de1 = SetPVActionItem(pv="RET:FOUND", value=1, termination_check=ce1) + eg1 = CheckAndDoItem(name="ret_find", check=ce1, do=de1) # acquire pixel to world frame transform - ce2 = CheckEntry(Pv="RET:INSERT", Thresh=1) # TODO: should make a check / set mode - de2 = DoEntry(Pv="RET:INSERT", Mode=CheckAndDoNodeTypeMode.SET, Value=1) - eg2 = CheckAndDoNodeEntry(name="ret_insert", check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, check_entry=ce2, do_entry=de2) - - eg_root = TreeSpec(name="fake_reticle", - children=[eg1, eg2]) + ce2 = ConditionItem(pv="RET:INSERT", value=1, operator=ConditionOperator.equal) + de2 = SetPVActionItem(pv="RET:INSERT", value=1, termination_check=ce2) + eg2 = CheckAndDoItem(name="ret_insert", check=ce2, do=de2) - fname = "beams/tests/artifacts/eggs2.json" - ser = serialize(TreeSpec, eg_root) - with open(fname, 'w') as fd: - json.dump(ser, fd, indent=2) + root_item = SequenceItem(children=[eg1, eg2]) + 
eg_root = BehaviorTreeItem(root=root_item) - with open(fname, 'r') as fd: - deser = json.load(fd) + ser = serialize(BehaviorTreeItem, eg_root) - eg_deser = deserialize(TreeSpec, deser) + eg_deser = deserialize(BehaviorTreeItem, ser) assert eg_root == eg_deser diff --git a/beams/tests/test_tree_generator.py b/beams/tests/test_tree_generator.py index 4319974..06ebf65 100644 --- a/beams/tests/test_tree_generator.py +++ b/beams/tests/test_tree_generator.py @@ -5,31 +5,21 @@ from caproto.tests.conftest import run_example_ioc from epics import caget -from beams.tree_generator.TreeGenerator import TreeGenerator -from beams.tree_generator.TreeSerializer import (CheckAndDoNodeEntry, - CheckAndDoNodeTypeMode, - CheckEntry, DoEntry, TreeSpec) +from beams.behavior_tree.CheckAndDo import CheckAndDo +from beams.tree_config import get_tree_from_path def test_tree_obj_ser(): fname = Path(__file__).parent / "artifacts" / "eggs.json" - tg = TreeGenerator(fname, CheckAndDoNodeEntry) + tg = get_tree_from_path(fname) - ce = CheckEntry(Pv="PERC:COMP", Thresh=100) - de = DoEntry(Pv="PERC:COMP", Mode=CheckAndDoNodeTypeMode.INC, Value=10) - eg = CheckAndDoNodeEntry( - name="self_test", - check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, - check_entry=ce, - do_entry=de, - ) - - assert tg.tree_spec == eg + assert isinstance(tg, py_trees.trees.BehaviourTree) + assert isinstance(tg.root, CheckAndDo) def test_tree_obj_execution(request): fname = Path(__file__).parent / "artifacts" / "eggs.json" - tg = TreeGenerator(fname, CheckAndDoNodeEntry) + tree = get_tree_from_path(fname) # start mock IOC # NOTE: assumes test is being run from top level of run_example_ioc( @@ -38,16 +28,13 @@ def test_tree_obj_execution(request): pv_to_check="PERC:COMP", ) - tree = tg.get_tree_from_config() - tree.setup_with_descendants() + tree.setup() while ( - tree.status != py_trees.common.Status.SUCCESS - and tree.status != py_trees.common.Status.FAILURE + tree.root.status not in (py_trees.common.Status.SUCCESS, + py_trees.common.Status.FAILURE) ): - for n in tree.tick(): - print(f"ticking: {n}") - time.sleep(0.05) - print(f"status of tick: {n.status}") + tree.tick() + time.sleep(0.05) rel_val = caget("PERC:COMP") assert rel_val >= 100 @@ -60,22 +47,19 @@ def test_father_tree_execution(request): pv_to_check="RET:INSERT", ) - fname = "beams/tests/artifacts/eggs2.json" - tg = TreeGenerator(fname, TreeSpec) - tree = tg.get_tree_from_config() - + fname = Path(__file__).parent / "artifacts" / "eggs2.json" + tree = get_tree_from_path(fname) + tree.setup() ct = 0 while ( - tree.root.status != py_trees.common.Status.SUCCESS - and tree.root.status != py_trees.common.Status.FAILURE + tree.root.status not in (py_trees.common.Status.SUCCESS, + py_trees.common.Status.FAILURE) and ct < 50 ): ct += 1 print((tree.root.status, tree.root.status, ct)) - for n in tree.root.tick(): - print(f"ticking: {n}") - time.sleep(0.05) - print(f"status of tick: {n.status}") + tree.tick() + time.sleep(0.05) check_insert = caget("RET:INSERT") diff --git a/beams/tree_config.py b/beams/tree_config.py new file mode 100644 index 0000000..ac7c28c --- /dev/null +++ b/beams/tree_config.py @@ -0,0 +1,298 @@ +from __future__ import annotations + +import json +import operator +import os +import time +from dataclasses import dataclass, field +from enum import Enum +from multiprocessing import Event, Lock +from pathlib import Path +from typing import Any, Callable, List, Optional + +import py_trees +from apischema import deserialize +from epics import caget, caput +from 
py_trees.behaviour import Behaviour
+from py_trees.common import ParallelPolicy
+from py_trees.composites import Parallel, Selector, Sequence
+
+from beams.behavior_tree.ActionNode import ActionNode
+from beams.behavior_tree.CheckAndDo import CheckAndDo
+from beams.behavior_tree.ConditionNode import ConditionNode
+from beams.serialization import as_tagged_union
+
+
+def get_tree_from_path(path: Path) -> py_trees.trees.BehaviourTree:
+    """Deserialize a json file, return the tree it specifies"""
+    with open(path, 'r') as fd:
+        deser = json.load(fd)
+        tree_item = deserialize(BehaviorTreeItem, deser)
+
+    return tree_item.get_tree()
+
+
+@dataclass
+class BehaviorTreeItem:
+    root: BaseItem
+
+    def get_tree(self) -> py_trees.trees.BehaviourTree:
+        return py_trees.trees.BehaviourTree(self.root.get_tree())
+
+
+@as_tagged_union
+@dataclass
+class BaseItem:
+    name: str = ''
+    description: str = ''
+
+    def get_tree(self) -> Behaviour:
+        """Get the tree node that this dataclass represents"""
+        raise NotImplementedError
+
+
+@dataclass
+class ExternalItem(BaseItem):
+    path: str = ''
+
+    def get_tree(self) -> Behaviour:
+        # grab file
+        # de-serialize tree, return it
+        raise NotImplementedError
+
+
+class ParallelMode(Enum):
+    """Simple enum mimicking the ``py_trees.common.ParallelPolicy`` options"""
+    Base = "Base"
+    SuccessOnAll = "SuccessOnAll"
+    SuccessOnOne = "SuccessOnOne"
+    SuccessOnSelected = "SuccessOnSelected"
+
+
+@dataclass
+class ParallelItem(BaseItem):
+    policy: ParallelMode = ParallelMode.Base
+    children: Optional[List[BaseItem]] = field(default_factory=list)
+
+    def get_tree(self) -> Parallel:
+        children = []
+        for child in self.children:
+            children.append(child.get_tree())
+
+        node = Parallel(
+            name=self.name,
+            policy=getattr(ParallelPolicy, self.policy.value),
+            children=children
+        )
+
+        return node
+
+
+@dataclass
+class SelectorItem(BaseItem):
+    """aka fallback node"""
+    memory: bool = False
+    children: Optional[List[BaseItem]] = field(default_factory=list)
+
+    def get_tree(self) -> Selector:
+        children = []
+        for child in self.children:
+            children.append(child.get_tree())
+
+        node = Selector(
+            name=self.name,
+            memory=self.memory,
+            children=children
+        )
+        return node
+
+
+@dataclass
+class SequenceItem(BaseItem):
+    memory: bool = False
+    children: Optional[List[BaseItem]] = field(default_factory=list)
+
+    def get_tree(self) -> Sequence:
+        children = []
+        for child in self.children:
+            children.append(child.get_tree())
+
+        node = Sequence(
+            name=self.name,
+            memory=self.memory,
+            children=children
+        )
+
+        return node
+
+
+# Custom LCLS-built Behaviors (idioms)
+class ConditionOperator(Enum):
+    equal = 'eq'
+    not_equal = 'ne'
+    less = 'lt'
+    greater = 'gt'
+    less_equal = 'le'
+    greater_equal = 'ge'
+
+
+@dataclass
+class ConditionItem(BaseItem):
+    pv: str = ''
+    value: Any = 1
+    operator: ConditionOperator = ConditionOperator.equal
+
+    def get_tree(self) -> ConditionNode:
+        cond_func = self.get_condition_function()
+        return ConditionNode(self.name, cond_func)
+
+    def get_condition_function(self) -> Callable[[], bool]:
+        op = getattr(operator, self.operator.value)
+
+        def cond_func():
+            val = caget(self.pv)
+            if val is None:
+                return False
+
+            return op(val, self.value)
+
+        return cond_func
+
+
+@as_tagged_union
+@dataclass
+class ActionItem(BaseItem):
+    loop_period_sec: float = 1.0
+
+
+@dataclass
+class SetPVActionItem(ActionItem):
+    pv: str = ''
+    value: Any = 1
+
+    termination_check: ConditionItem = field(default_factory=ConditionItem)
+
+    def get_tree(self) -> ActionNode:
+
wait_for_tick = Event() + wait_for_tick_lock = Lock() + + def work_func(comp_condition, volatile_status): + py_trees.console.logdebug(f"WAITING FOR INIT {os.getpid()} " + f"from node: {self.name}") + wait_for_tick.wait() + + # Set to running + value = 0 + + # While termination_check is not True + while not comp_condition(): # TODO check work_gate.is_set() + py_trees.console.logdebug( + f"CALLING CAGET FROM {os.getpid()} from node: " + f"{self.name}" + ) + value = caget(self.termination_check.pv) + + if comp_condition(): + volatile_status.set_value(py_trees.common.Status.SUCCESS) + py_trees.console.logdebug( + f"{self.name}: Value is {value}, BT Status: " + f"{volatile_status.get_value()}" + ) + + # specific caput logic to SetPVActionItem + caput(self.pv, self.value) + time.sleep(self.loop_period_sec) + + # one last check + if comp_condition(): + volatile_status.set_value(py_trees.common.Status.SUCCESS) + else: + volatile_status.set_value(py_trees.common.Status.FAILURE) + + comp_cond = self.termination_check.get_condition_function() + + node = ActionNode( + name=self.name, + work_func=work_func, + completion_condition=comp_cond, + work_gate=wait_for_tick, + work_lock=wait_for_tick_lock, + ) + + return node + + +@dataclass +class IncPVActionItem(ActionItem): + pv: str = '' + increment: float = 1 + + termination_check: ConditionItem = field(default_factory=ConditionItem) + + # TODO: DRY this out a bit + def get_tree(self) -> ActionNode: + wait_for_tick = Event() + wait_for_tick_lock = Lock() + + def work_func(comp_condition, volatile_status): + py_trees.console.logdebug(f"WAITING FOR INIT {os.getpid()} " + f"from node: {self.name}") + wait_for_tick.wait() + + # Set to running + value = 0 + + # While termination_check is not True + while not comp_condition(): # TODO check work_gate.is_set() + py_trees.console.logdebug( + f"CALLING CAGET FROM {os.getpid()} from node: " + f"{self.name}" + ) + value = caget(self.pv) + + if comp_condition(): + volatile_status.set_value(py_trees.common.Status.SUCCESS) + py_trees.console.logdebug( + f"{self.name}: Value is {value}, BT Status: " + f"{volatile_status.get_value()}" + ) + + # specific caput logic to IncPVActionItem + caput(self.pv, value + self.increment) + time.sleep(self.loop_period_sec) + + # one last check + if comp_condition(): + volatile_status.set_value(py_trees.common.Status.SUCCESS) + else: + volatile_status.set_value(py_trees.common.Status.FAILURE) + + comp_cond = self.termination_check.get_condition_function() + + node = ActionNode( + name=self.name, + work_func=work_func, + completion_condition=comp_cond, + work_gate=wait_for_tick, + work_lock=wait_for_tick_lock, + ) + + return node + + +@dataclass +class CheckAndDoItem(BaseItem): + check: ConditionItem = field(default_factory=ConditionItem) + do: ActionItem = field(default_factory=ActionItem) + + def get_tree(self) -> CheckAndDo: + check_node = self.check.get_tree() + do_node = self.do.get_tree() + + node = CheckAndDo( + name=self.name, + check=check_node, + do=do_node + ) + + return node diff --git a/beams/tree_generator/README.md b/beams/tree_generator/README.md deleted file mode 100644 index d3072a0..0000000 --- a/beams/tree_generator/README.md +++ /dev/null @@ -1,11 +0,0 @@ -# Tree Generator -Contained is mechanisms to: -1. Parse and write to config files that specify tree structure -2. Generate tree objects to be ticked from file contents above - -### TreeSerializer -Uses `apischema` to define file structure of config file and parse contents. 
-TODO: make resilient to poorly formatted files. hash config file to get some form of version control - -### TreeGenerator -Pass a node type and a file path for the config document and it will generate a tree \ No newline at end of file diff --git a/beams/tree_generator/TreeGenerator.py b/beams/tree_generator/TreeGenerator.py deleted file mode 100644 index a45916a..0000000 --- a/beams/tree_generator/TreeGenerator.py +++ /dev/null @@ -1,18 +0,0 @@ -from beams.sequencer.remote_calls.sequencer_pb2 import GenericMessage - -from apischema import deserialize -import json - - -def GenerateTreeFromRequest(request: GenericMessage): - pass - - -class TreeGenerator(): - def __init__(self, config_fname, node_type): - with open(config_fname, "r+") as fd: - self.tree_spec = deserialize(node_type, json.load(fd)) - - def get_tree_from_config(self): - # check if requested sequence name is in entries - return self.tree_spec.get_tree() diff --git a/beams/tree_generator/TreeSerializer.py b/beams/tree_generator/TreeSerializer.py deleted file mode 100644 index c70acc3..0000000 --- a/beams/tree_generator/TreeSerializer.py +++ /dev/null @@ -1,143 +0,0 @@ -from collections.abc import Collection -from dataclasses import dataclass, field -from uuid import UUID, uuid4 -from apischema import ValidationError, deserialize, serialize -from apischema.json_schema import deserialization_schema -from typing import Optional, List, Union -from enum import Enum -from multiprocessing import Event, Lock -import json -import time -import os - -from epics import caput, caget -import py_trees - - -from beams.behavior_tree.ActionNode import ActionNode -from beams.behavior_tree.CheckAndDo import CheckAndDo -from beams.behavior_tree.ConditionNode import ConditionNode - - -class CheckAndDoNodeTypeMode(Enum): - INC = "INC" # self.Value will be interpreted as the value to INCREMENT the current value by - SET = "SET" # self.Value will be interpreted as the value to SET the current value to - - -# Define a schema with standard dataclasses -@dataclass -class _NodeEntry: - name: str - - def get_tree(self): - raise NotImplementedError("Cannot get tree from abstract base class!") - - -@dataclass -class ActionNodeEntry(_NodeEntry): - - class ActionNodeType(Enum): - CHECKPV = "CheckPV" - - -@dataclass -class ConditonNodeEntry(_NodeEntry): - - class ConditionNodeType(Enum): - CHECKPV = "CheckPV" - - -@dataclass -class CheckEntry(): - Pv: str - Thresh: int - - -@dataclass -class DoEntry(): - Pv: str - Mode: CheckAndDoNodeTypeMode - Value: int - - -@dataclass -class CheckAndDoNodeEntry(_NodeEntry): - - class CheckAndDoNodeType(Enum): - CHECKPV = "CheckPV" - - check_and_do_type: CheckAndDoNodeType - check_entry: CheckEntry - do_entry: DoEntry - - def get_tree(self): - if (self.check_and_do_type == CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV): - # Determine what the lambda will caput: - caput_lambda = lambda : 0 - # if we are in increment mode, produce a function that can increment current value - if (self.do_entry.Mode == CheckAndDoNodeTypeMode.INC): - caput_lambda = lambda x : x + self.do_entry.Value - # if we are in set mode just set it to a value - elif (self.do_entry.Mode == CheckAndDoNodeTypeMode.SET): - caput_lambda = lambda x : self.do_entry.Value - - wait_for_tick = Event() - wait_for_tick_lock = Lock() - - # Work function generator for DO of check and do - def update_pv(comp_condition, volatile_status, **kwargs): - py_trees.console.logdebug(f"WAITING FOR INIT {os.getpid()} from node: {self.name}") - wait_for_tick.wait() - - # Set to 
running - - value = 0 - while not comp_condition(value): # TODO check work_gate.is_set() - py_trees.console.logdebug(f"CALLING CAGET FROM {os.getpid()} from node: {self.name}") - value = caget(self.check_entry.Pv) - if (value >= self.check_entry.Thresh): # TODO: we are implicitly connecting the check thresh value with the lamda produced from the do. Maybe fix - volatile_status.set_value(py_trees.common.Status.SUCCESS) - py_trees.console.logdebug(f"{self.name}: Value is {value}, BT Status: {volatile_status.get_value()}") - caput(self.do_entry.Pv, caput_lambda(value)) - time.sleep(0.1) # TODO(josh): this is a very important hard coded constant, reflcect on where to place with more visibility - - # TODO: here is where we can build more complex trees - # Build Check Node - def check_func(): - val = caget(self.check_entry.Pv) - if val is None: - return False - return val >= self.check_entry.Thresh - - condition_node = ConditionNode(f"{self.name}_check", check_func) - - # Build Do Node - comp_cond = lambda check_val: check_val > self.check_entry.Thresh - action_node = ActionNode(name=f"{self.name}_do", - work_func=update_pv, - completion_condition=comp_cond, - work_gate=wait_for_tick, - work_lock=wait_for_tick_lock) - - check_and_do_node = CheckAndDo(f"{self.name}_check_and_do_root", condition_node, action_node) - check_and_do_node.setup() - - return check_and_do_node - - -# TODO: Ask if we want this and beams.beahvior_tree.CheckAndDo to share a baseclass... -@dataclass -class TreeSpec(): - name : str - children: Optional[List[CheckAndDoNodeEntry]] = None - - def get_tree(self): - children_trees = [x.get_tree() for x in self.children] - print(children_trees) - self.root = py_trees.composites.Sequence(self.name, memory=True) - self.root.add_children(children_trees) - self.setup() - return self - - def setup(self): - self.root.setup_with_descendants() diff --git a/beams/tree_generator/__init__.py b/beams/tree_generator/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/docs/source/upcoming_release_notes/30-enh_tree_ser.rst b/docs/source/upcoming_release_notes/30-enh_tree_ser.rst new file mode 100644 index 0000000..f9e419a --- /dev/null +++ b/docs/source/upcoming_release_notes/30-enh_tree_ser.rst @@ -0,0 +1,23 @@ +30 enh_tree_ser +############### + +API Breaks +---------- +- Refactors tree serialization, replacing TreeGenerator and TreeSerializer with dataclasses + that each produce the py_trees Behaviour specified by the dataclass. + +Features +-------- +- N/A + +Bugfixes +-------- +- Adjusts ActionNode logic to always set completion status + +Maintenance +----------- +- N/A + +Contributors +------------ +- tangkong
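
Reviewer's note (not part of the patch): the round trip below is a condensed sketch assembled from `test_serialize.py` and `test_tree_generator.py` in this diff. It builds the same item graph that `beams/tests/artifacts/eggs.json` encodes, shows the `{"ClassName": {...}}` wrapping that `as_tagged_union` produces for polymorphic fields, and loads the artifact back into a runnable tree with `get_tree_from_path`. Ticking the tree assumes the referenced PVs are actually being served (the tests spin up a caproto example IOC); nothing here goes beyond what the tests already exercise.

```python
from pathlib import Path

from apischema import deserialize, serialize

from beams.tree_config import (BehaviorTreeItem, CheckAndDoItem, ConditionItem,
                               ConditionOperator, IncPVActionItem,
                               get_tree_from_path)

# The same item graph that beams/tests/artifacts/eggs.json encodes.
check = ConditionItem(pv="PERC:COMP", value=100,
                      operator=ConditionOperator.greater_equal)
do = IncPVActionItem(pv="PERC:COMP", increment=10, termination_check=check)
item = BehaviorTreeItem(root=CheckAndDoItem(name="self_test", check=check, do=do))

# The @as_tagged_union decorator wraps each polymorphic field in a single-key
# dict keyed by the concrete class name, e.g. {"CheckAndDoItem": {...}}.
blob = serialize(BehaviorTreeItem, item)
assert deserialize(BehaviorTreeItem, blob) == item

# Loading from disk goes through the same deserializer, then builds the
# py_trees tree; ticking requires PERC:COMP to be reachable over EPICS.
tree = get_tree_from_path(Path("beams/tests/artifacts/eggs.json"))
tree.setup()
tree.tick()
```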
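
A second reviewer's note: `beams/serialization.py` also ports apischema's `alternative_constructor` hook, which nothing in this diff uses yet. The sketch below is an assumption about how it appears intended to be wired up, not existing behavior: `high_limit_check` is a hypothetical helper, the PascalCase tag comes from `to_pascal_case` applied to the function name, and the decorator must run (i.e. the defining module must be imported) before the first `deserialize()` call so the lazily built tagged union picks it up.

```python
from beams.serialization import alternative_constructor
from beams.tree_config import ConditionItem, ConditionOperator


@alternative_constructor
def high_limit_check(pv: str, limit: float) -> ConditionItem:
    """Hypothetical standalone constructor registered for ConditionItem."""
    return ConditionItem(name=f"{pv}_high_limit", pv=pv, value=limit,
                         operator=ConditionOperator.greater_equal)


# If the assumption above holds, a config could then tag a condition by the
# constructor's alias instead of the class name, e.g.:
#   "check": {"HighLimitCheck": {"pv": "PERC:COMP", "limit": 100}}
```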