Skip to content

Commit

Permalink
Merge pull request #30 from tangkong/enh_tree_ser
Browse files Browse the repository at this point in the history
ENH: rework tree serialization, accompanying tests
  • Loading branch information
tangkong authored Aug 9, 2024
2 parents 113c68f + cd8bb71 commit 02e53ff
Show file tree
Hide file tree
Showing 11 changed files with 597 additions and 284 deletions.
138 changes: 138 additions & 0 deletions beams/serialization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
"""
Serialization helpers for apischema.
"""
# Largely based on issue discussions regarding tagged unions.

from collections import defaultdict
from collections.abc import Callable, Iterator
from types import new_class
from typing import (Any, Dict, Generic, List, Tuple, TypeVar, get_origin,
get_type_hints)

from apischema import deserializer, serializer, type_name
from apischema.conversions import Conversion
from apischema.metadata import conversion
from apischema.objects import object_deserialization
from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged
from apischema.utils import to_pascal_case

_alternative_constructors: Dict[type, List[Callable]] = defaultdict(list)
Func = TypeVar("Func", bound=Callable)


def alternative_constructor(func: Func) -> Func:
"""Alternative constructor for a given type."""
return_type = get_type_hints(func)["return"]
_alternative_constructors[get_origin(return_type) or return_type].append(func)
return func


def get_all_subclasses(cls: type) -> Iterator[type]:
"""Recursive implementation of type.__subclasses__"""
for sub_cls in cls.__subclasses__():
yield sub_cls
yield from get_all_subclasses(sub_cls)


Cls = TypeVar("Cls", bound=type)


def _get_generic_name_factory(cls: type, *args: type):
def _capitalized(name: str) -> str:
return name[0].upper() + name[1:]

return "".join((cls.__name__, *(_capitalized(arg.__name__) for arg in args)))


generic_name = type_name(_get_generic_name_factory)


def as_tagged_union(cls: Cls) -> Cls:
"""
Tagged union decorator, to be used on base class.
Supports generics as well, with names generated by way of
`_get_generic_name_factory`.
"""
params = tuple(getattr(cls, "__parameters__", ()))
tagged_union_bases: Tuple[type, ...] = (TaggedUnion,)

# Generic handling is here:
if params:
tagged_union_bases = (TaggedUnion, Generic[params])
generic_name(cls)
prev_init_subclass = getattr(cls, "__init_subclass__", None)

def __init_subclass__(cls, **kwargs):
if prev_init_subclass is not None:
prev_init_subclass(**kwargs)
generic_name(cls)

cls.__init_subclass__ = classmethod(__init_subclass__)

def with_params(cls: type) -> Any:
"""Specify type of Generic if set."""
return cls[params] if params else cls

def serialization() -> Conversion:
"""
Define the serializer Conversion for the tagged union.
source is the base ``cls`` (or ``cls[T]``).
target is the new tagged union class ``TaggedUnion`` which gets the
dictionary {cls.__name__: obj} as its arguments.
"""
annotations = {
# Assume that subclasses have same generic parameters than cls
sub.__name__: Tagged[with_params(sub)]
for sub in get_all_subclasses(cls)
}
namespace = {"__annotations__": annotations}
tagged_union = new_class(
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace)
)
return Conversion(
lambda obj: tagged_union(**{obj.__class__.__name__: obj}),
source=with_params(cls),
target=with_params(tagged_union),
# Conversion must not be inherited because it would lead to
# infinite recursion otherwise
inherited=False,
)

def deserialization() -> Conversion:
"""
Define the deserializer Conversion for the tagged union.
Allows for alternative standalone constructors as per the apischema
example.
"""
annotations: dict[str, Any] = {}
namespace: dict[str, Any] = {"__annotations__": annotations}
for sub in get_all_subclasses(cls):
annotations[sub.__name__] = Tagged[with_params(sub)]
for constructor in _alternative_constructors.get(sub, ()):
# Build the alias of the field
alias = to_pascal_case(constructor.__name__)
# object_deserialization uses get_type_hints, but the constructor
# return type is stringified and the class not defined yet,
# so it must be assigned manually
constructor.__annotations__["return"] = with_params(sub)
# Use object_deserialization to wrap constructor as deserializer
deserialization = object_deserialization(constructor, generic_name)
# Add constructor tagged field with its conversion
annotations[alias] = Tagged[with_params(sub)]
namespace[alias] = Tagged(conversion(deserialization=deserialization))
# Create the deserialization tagged union class
tagged_union = new_class(
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace)
)
return Conversion(
lambda obj: get_tagged(obj)[1],
source=with_params(tagged_union),
target=with_params(cls),
)

deserializer(lazy=deserialization, target=cls)
serializer(lazy=serialization, source=cls)
return cls
40 changes: 29 additions & 11 deletions beams/tests/artifacts/eggs.json
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
{
"name": "self_test",
"check_and_do_type": "CheckPV",
"check_entry": {
"Pv": "PERC:COMP",
"Thresh": 100
},
"do_entry": {
"Pv": "PERC:COMP",
"Mode": "INC",
"Value": 10
"root": {
"CheckAndDoItem": {
"name": "self_test",
"description": "",
"check": {
"name": "",
"description": "",
"pv": "PERC:COMP",
"value": 100,
"operator": "ge"
},
"do": {
"IncPVActionItem": {
"name": "",
"description": "",
"loop_period_sec": 1.0,
"pv": "PERC:COMP",
"increment": 10,
"termination_check": {
"name": "",
"description": "",
"pv": "PERC:COMP",
"value": 100,
"operator": "ge"
}
}
}
}
}
}
}
96 changes: 67 additions & 29 deletions beams/tests/artifacts/eggs2.json
Original file line number Diff line number Diff line change
@@ -1,31 +1,69 @@
{
"name": "fake_reticle",
"children": [
{
"name": "ret_find",
"check_and_do_type": "CheckPV",
"check_entry": {
"Pv": "RET:FOUND",
"Thresh": 1
},
"do_entry": {
"Pv": "RET:FOUND",
"Mode": "SET",
"Value": 1
}
},
{
"name": "ret_insert",
"check_and_do_type": "CheckPV",
"check_entry": {
"Pv": "RET:INSERT",
"Thresh": 1
},
"do_entry": {
"Pv": "RET:INSERT",
"Mode": "SET",
"Value": 1
}
"root": {
"SequenceItem": {
"name": "fake_reticle",
"description": "",
"memory": false,
"children": [
{
"CheckAndDoItem": {
"name": "ret_find",
"description": "",
"check": {
"name": "",
"description": "",
"pv": "RET:FOUND",
"value": 1,
"operator": "ge"
},
"do": {
"SetPVActionItem": {
"name": "",
"description": "",
"loop_period_sec": 1.0,
"pv": "RET:FOUND",
"value": 1,
"termination_check": {
"name": "",
"description": "",
"pv": "RET:FOUND",
"value": 1,
"operator": "ge"
}
}
}
}
},
{
"CheckAndDoItem": {
"name": "ret_insert",
"description": "",
"check": {
"name": "",
"description": "",
"pv": "RET:INSERT",
"value": 1,
"operator": "ge"
},
"do": {
"SetPVActionItem": {
"name": "",
"description": "",
"loop_period_sec": 1.0,
"pv": "RET:INSERT",
"value": 1,
"termination_check": {
"name": "",
"description": "",
"pv": "RET:INSERT",
"value": 1,
"operator": "ge"
}
}
}
}
}
]
}
]
}
}
}
62 changes: 24 additions & 38 deletions beams/tests/test_serialize.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,43 @@
import json

from apischema import deserialize, serialize

from beams.tree_generator.TreeSerializer import (CheckAndDoNodeEntry,
CheckAndDoNodeTypeMode,
CheckEntry, DoEntry, TreeSpec)
from beams.tree_config import (BehaviorTreeItem, CheckAndDoItem, ConditionItem,
ConditionOperator, IncPVActionItem,
SequenceItem, SetPVActionItem)


class TestSerializer:
def test_serialize_basic(self):
def test_serialize_check_and_do():
# c_obj = load_config("config.json")
ce = CheckEntry(Pv="PERC:COMP", Thresh=100)
de = DoEntry(Pv="PERC:COMP", Mode=CheckAndDoNodeTypeMode.INC, Value=10)
eg = CheckAndDoNodeEntry(name="self_test", check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, check_entry=ce, do_entry=de)

ser = serialize(CheckAndDoNodeEntry, eg)
cond_item = ConditionItem(pv="PERC:COMP", value=100,
operator=ConditionOperator.greater_equal)
action_item = IncPVActionItem(pv="PERC:COMP", increment=10,
termination_check=cond_item)
cnd_item = CheckAndDoItem(name="self_test", check=cond_item, do=action_item)

fname = "beams/tests/artifacts/eggs.json"
tree_item = BehaviorTreeItem(root=cnd_item)
ser = serialize(BehaviorTreeItem, tree_item)
deser = deserialize(BehaviorTreeItem, ser)

with open(fname, 'w') as fd:
json.dump(ser, fd, indent=2)
assert deser == tree_item

with open(fname, 'r') as fd:
deser = json.load(fd)

eg2 = deserialize(CheckAndDoNodeEntry, deser)
assert eg2 == eg

def test_serialize_youre_a_father_now(self):
def test_serialize_youre_a_father_now():
"""
Build children check and dos
"""
# insert reticule if ret is not found
ce1 = CheckEntry(Pv="RET:FOUND", Thresh=1) # TODO: should make a check / set mode
de1 = DoEntry(Pv="RET:FOUND", Mode=CheckAndDoNodeTypeMode.SET, Value=1)
# de = DoEntry(Pv="RET:SET", Mode=CheckAndDoNodeTypeMode.SET, Value=1) # TODO: once we have better feel of caproto plumb this up in mock
eg1 = CheckAndDoNodeEntry(name="ret_find", check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, check_entry=ce1, do_entry=de1)
ce1 = ConditionItem(pv="RET:FOUND", value=1, operator=ConditionOperator.equal)
de1 = SetPVActionItem(pv="RET:FOUND", value=1, termination_check=ce1)
eg1 = CheckAndDoItem(name="ret_find", check=ce1, do=de1)

# acquire pixel to world frame transform
ce2 = CheckEntry(Pv="RET:INSERT", Thresh=1) # TODO: should make a check / set mode
de2 = DoEntry(Pv="RET:INSERT", Mode=CheckAndDoNodeTypeMode.SET, Value=1)
eg2 = CheckAndDoNodeEntry(name="ret_insert", check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, check_entry=ce2, do_entry=de2)

eg_root = TreeSpec(name="fake_reticle",
children=[eg1, eg2])
ce2 = ConditionItem(pv="RET:INSERT", value=1, operator=ConditionOperator.equal)
de2 = SetPVActionItem(pv="RET:INSERT", value=1, termination_check=ce2)
eg2 = CheckAndDoItem(name="ret_insert", check=ce2, do=de2)

fname = "beams/tests/artifacts/eggs2.json"
ser = serialize(TreeSpec, eg_root)
with open(fname, 'w') as fd:
json.dump(ser, fd, indent=2)
root_item = SequenceItem(children=[eg1, eg2])
eg_root = BehaviorTreeItem(root=root_item)

with open(fname, 'r') as fd:
deser = json.load(fd)
ser = serialize(BehaviorTreeItem, eg_root)

eg_deser = deserialize(TreeSpec, deser)
eg_deser = deserialize(BehaviorTreeItem, ser)
assert eg_root == eg_deser
Loading

0 comments on commit 02e53ff

Please sign in to comment.