-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #30 from tangkong/enh_tree_ser
ENH: rework tree serialization, accompanying tests
- Loading branch information
Showing
11 changed files
with
597 additions
and
284 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
""" | ||
Serialization helpers for apischema. | ||
""" | ||
# Largely based on issue discussions regarding tagged unions. | ||
|
||
from collections import defaultdict | ||
from collections.abc import Callable, Iterator | ||
from types import new_class | ||
from typing import (Any, Dict, Generic, List, Tuple, TypeVar, get_origin, | ||
get_type_hints) | ||
|
||
from apischema import deserializer, serializer, type_name | ||
from apischema.conversions import Conversion | ||
from apischema.metadata import conversion | ||
from apischema.objects import object_deserialization | ||
from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged | ||
from apischema.utils import to_pascal_case | ||
|
||
_alternative_constructors: Dict[type, List[Callable]] = defaultdict(list) | ||
Func = TypeVar("Func", bound=Callable) | ||
|
||
|
||
def alternative_constructor(func: Func) -> Func: | ||
"""Alternative constructor for a given type.""" | ||
return_type = get_type_hints(func)["return"] | ||
_alternative_constructors[get_origin(return_type) or return_type].append(func) | ||
return func | ||
|
||
|
||
def get_all_subclasses(cls: type) -> Iterator[type]: | ||
"""Recursive implementation of type.__subclasses__""" | ||
for sub_cls in cls.__subclasses__(): | ||
yield sub_cls | ||
yield from get_all_subclasses(sub_cls) | ||
|
||
|
||
Cls = TypeVar("Cls", bound=type) | ||
|
||
|
||
def _get_generic_name_factory(cls: type, *args: type): | ||
def _capitalized(name: str) -> str: | ||
return name[0].upper() + name[1:] | ||
|
||
return "".join((cls.__name__, *(_capitalized(arg.__name__) for arg in args))) | ||
|
||
|
||
generic_name = type_name(_get_generic_name_factory) | ||
|
||
|
||
def as_tagged_union(cls: Cls) -> Cls: | ||
""" | ||
Tagged union decorator, to be used on base class. | ||
Supports generics as well, with names generated by way of | ||
`_get_generic_name_factory`. | ||
""" | ||
params = tuple(getattr(cls, "__parameters__", ())) | ||
tagged_union_bases: Tuple[type, ...] = (TaggedUnion,) | ||
|
||
# Generic handling is here: | ||
if params: | ||
tagged_union_bases = (TaggedUnion, Generic[params]) | ||
generic_name(cls) | ||
prev_init_subclass = getattr(cls, "__init_subclass__", None) | ||
|
||
def __init_subclass__(cls, **kwargs): | ||
if prev_init_subclass is not None: | ||
prev_init_subclass(**kwargs) | ||
generic_name(cls) | ||
|
||
cls.__init_subclass__ = classmethod(__init_subclass__) | ||
|
||
def with_params(cls: type) -> Any: | ||
"""Specify type of Generic if set.""" | ||
return cls[params] if params else cls | ||
|
||
def serialization() -> Conversion: | ||
""" | ||
Define the serializer Conversion for the tagged union. | ||
source is the base ``cls`` (or ``cls[T]``). | ||
target is the new tagged union class ``TaggedUnion`` which gets the | ||
dictionary {cls.__name__: obj} as its arguments. | ||
""" | ||
annotations = { | ||
# Assume that subclasses have same generic parameters than cls | ||
sub.__name__: Tagged[with_params(sub)] | ||
for sub in get_all_subclasses(cls) | ||
} | ||
namespace = {"__annotations__": annotations} | ||
tagged_union = new_class( | ||
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) | ||
) | ||
return Conversion( | ||
lambda obj: tagged_union(**{obj.__class__.__name__: obj}), | ||
source=with_params(cls), | ||
target=with_params(tagged_union), | ||
# Conversion must not be inherited because it would lead to | ||
# infinite recursion otherwise | ||
inherited=False, | ||
) | ||
|
||
def deserialization() -> Conversion: | ||
""" | ||
Define the deserializer Conversion for the tagged union. | ||
Allows for alternative standalone constructors as per the apischema | ||
example. | ||
""" | ||
annotations: dict[str, Any] = {} | ||
namespace: dict[str, Any] = {"__annotations__": annotations} | ||
for sub in get_all_subclasses(cls): | ||
annotations[sub.__name__] = Tagged[with_params(sub)] | ||
for constructor in _alternative_constructors.get(sub, ()): | ||
# Build the alias of the field | ||
alias = to_pascal_case(constructor.__name__) | ||
# object_deserialization uses get_type_hints, but the constructor | ||
# return type is stringified and the class not defined yet, | ||
# so it must be assigned manually | ||
constructor.__annotations__["return"] = with_params(sub) | ||
# Use object_deserialization to wrap constructor as deserializer | ||
deserialization = object_deserialization(constructor, generic_name) | ||
# Add constructor tagged field with its conversion | ||
annotations[alias] = Tagged[with_params(sub)] | ||
namespace[alias] = Tagged(conversion(deserialization=deserialization)) | ||
# Create the deserialization tagged union class | ||
tagged_union = new_class( | ||
cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) | ||
) | ||
return Conversion( | ||
lambda obj: get_tagged(obj)[1], | ||
source=with_params(tagged_union), | ||
target=with_params(cls), | ||
) | ||
|
||
deserializer(lazy=deserialization, target=cls) | ||
serializer(lazy=serialization, source=cls) | ||
return cls |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,31 @@ | ||
{ | ||
"name": "self_test", | ||
"check_and_do_type": "CheckPV", | ||
"check_entry": { | ||
"Pv": "PERC:COMP", | ||
"Thresh": 100 | ||
}, | ||
"do_entry": { | ||
"Pv": "PERC:COMP", | ||
"Mode": "INC", | ||
"Value": 10 | ||
"root": { | ||
"CheckAndDoItem": { | ||
"name": "self_test", | ||
"description": "", | ||
"check": { | ||
"name": "", | ||
"description": "", | ||
"pv": "PERC:COMP", | ||
"value": 100, | ||
"operator": "ge" | ||
}, | ||
"do": { | ||
"IncPVActionItem": { | ||
"name": "", | ||
"description": "", | ||
"loop_period_sec": 1.0, | ||
"pv": "PERC:COMP", | ||
"increment": 10, | ||
"termination_check": { | ||
"name": "", | ||
"description": "", | ||
"pv": "PERC:COMP", | ||
"value": 100, | ||
"operator": "ge" | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,69 @@ | ||
{ | ||
"name": "fake_reticle", | ||
"children": [ | ||
{ | ||
"name": "ret_find", | ||
"check_and_do_type": "CheckPV", | ||
"check_entry": { | ||
"Pv": "RET:FOUND", | ||
"Thresh": 1 | ||
}, | ||
"do_entry": { | ||
"Pv": "RET:FOUND", | ||
"Mode": "SET", | ||
"Value": 1 | ||
} | ||
}, | ||
{ | ||
"name": "ret_insert", | ||
"check_and_do_type": "CheckPV", | ||
"check_entry": { | ||
"Pv": "RET:INSERT", | ||
"Thresh": 1 | ||
}, | ||
"do_entry": { | ||
"Pv": "RET:INSERT", | ||
"Mode": "SET", | ||
"Value": 1 | ||
} | ||
"root": { | ||
"SequenceItem": { | ||
"name": "fake_reticle", | ||
"description": "", | ||
"memory": false, | ||
"children": [ | ||
{ | ||
"CheckAndDoItem": { | ||
"name": "ret_find", | ||
"description": "", | ||
"check": { | ||
"name": "", | ||
"description": "", | ||
"pv": "RET:FOUND", | ||
"value": 1, | ||
"operator": "ge" | ||
}, | ||
"do": { | ||
"SetPVActionItem": { | ||
"name": "", | ||
"description": "", | ||
"loop_period_sec": 1.0, | ||
"pv": "RET:FOUND", | ||
"value": 1, | ||
"termination_check": { | ||
"name": "", | ||
"description": "", | ||
"pv": "RET:FOUND", | ||
"value": 1, | ||
"operator": "ge" | ||
} | ||
} | ||
} | ||
} | ||
}, | ||
{ | ||
"CheckAndDoItem": { | ||
"name": "ret_insert", | ||
"description": "", | ||
"check": { | ||
"name": "", | ||
"description": "", | ||
"pv": "RET:INSERT", | ||
"value": 1, | ||
"operator": "ge" | ||
}, | ||
"do": { | ||
"SetPVActionItem": { | ||
"name": "", | ||
"description": "", | ||
"loop_period_sec": 1.0, | ||
"pv": "RET:INSERT", | ||
"value": 1, | ||
"termination_check": { | ||
"name": "", | ||
"description": "", | ||
"pv": "RET:INSERT", | ||
"value": 1, | ||
"operator": "ge" | ||
} | ||
} | ||
} | ||
} | ||
} | ||
] | ||
} | ||
] | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,57 +1,43 @@ | ||
import json | ||
|
||
from apischema import deserialize, serialize | ||
|
||
from beams.tree_generator.TreeSerializer import (CheckAndDoNodeEntry, | ||
CheckAndDoNodeTypeMode, | ||
CheckEntry, DoEntry, TreeSpec) | ||
from beams.tree_config import (BehaviorTreeItem, CheckAndDoItem, ConditionItem, | ||
ConditionOperator, IncPVActionItem, | ||
SequenceItem, SetPVActionItem) | ||
|
||
|
||
class TestSerializer: | ||
def test_serialize_basic(self): | ||
def test_serialize_check_and_do(): | ||
# c_obj = load_config("config.json") | ||
ce = CheckEntry(Pv="PERC:COMP", Thresh=100) | ||
de = DoEntry(Pv="PERC:COMP", Mode=CheckAndDoNodeTypeMode.INC, Value=10) | ||
eg = CheckAndDoNodeEntry(name="self_test", check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, check_entry=ce, do_entry=de) | ||
|
||
ser = serialize(CheckAndDoNodeEntry, eg) | ||
cond_item = ConditionItem(pv="PERC:COMP", value=100, | ||
operator=ConditionOperator.greater_equal) | ||
action_item = IncPVActionItem(pv="PERC:COMP", increment=10, | ||
termination_check=cond_item) | ||
cnd_item = CheckAndDoItem(name="self_test", check=cond_item, do=action_item) | ||
|
||
fname = "beams/tests/artifacts/eggs.json" | ||
tree_item = BehaviorTreeItem(root=cnd_item) | ||
ser = serialize(BehaviorTreeItem, tree_item) | ||
deser = deserialize(BehaviorTreeItem, ser) | ||
|
||
with open(fname, 'w') as fd: | ||
json.dump(ser, fd, indent=2) | ||
assert deser == tree_item | ||
|
||
with open(fname, 'r') as fd: | ||
deser = json.load(fd) | ||
|
||
eg2 = deserialize(CheckAndDoNodeEntry, deser) | ||
assert eg2 == eg | ||
|
||
def test_serialize_youre_a_father_now(self): | ||
def test_serialize_youre_a_father_now(): | ||
""" | ||
Build children check and dos | ||
""" | ||
# insert reticule if ret is not found | ||
ce1 = CheckEntry(Pv="RET:FOUND", Thresh=1) # TODO: should make a check / set mode | ||
de1 = DoEntry(Pv="RET:FOUND", Mode=CheckAndDoNodeTypeMode.SET, Value=1) | ||
# de = DoEntry(Pv="RET:SET", Mode=CheckAndDoNodeTypeMode.SET, Value=1) # TODO: once we have better feel of caproto plumb this up in mock | ||
eg1 = CheckAndDoNodeEntry(name="ret_find", check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, check_entry=ce1, do_entry=de1) | ||
ce1 = ConditionItem(pv="RET:FOUND", value=1, operator=ConditionOperator.equal) | ||
de1 = SetPVActionItem(pv="RET:FOUND", value=1, termination_check=ce1) | ||
eg1 = CheckAndDoItem(name="ret_find", check=ce1, do=de1) | ||
|
||
# acquire pixel to world frame transform | ||
ce2 = CheckEntry(Pv="RET:INSERT", Thresh=1) # TODO: should make a check / set mode | ||
de2 = DoEntry(Pv="RET:INSERT", Mode=CheckAndDoNodeTypeMode.SET, Value=1) | ||
eg2 = CheckAndDoNodeEntry(name="ret_insert", check_and_do_type=CheckAndDoNodeEntry.CheckAndDoNodeType.CHECKPV, check_entry=ce2, do_entry=de2) | ||
|
||
eg_root = TreeSpec(name="fake_reticle", | ||
children=[eg1, eg2]) | ||
ce2 = ConditionItem(pv="RET:INSERT", value=1, operator=ConditionOperator.equal) | ||
de2 = SetPVActionItem(pv="RET:INSERT", value=1, termination_check=ce2) | ||
eg2 = CheckAndDoItem(name="ret_insert", check=ce2, do=de2) | ||
|
||
fname = "beams/tests/artifacts/eggs2.json" | ||
ser = serialize(TreeSpec, eg_root) | ||
with open(fname, 'w') as fd: | ||
json.dump(ser, fd, indent=2) | ||
root_item = SequenceItem(children=[eg1, eg2]) | ||
eg_root = BehaviorTreeItem(root=root_item) | ||
|
||
with open(fname, 'r') as fd: | ||
deser = json.load(fd) | ||
ser = serialize(BehaviorTreeItem, eg_root) | ||
|
||
eg_deser = deserialize(TreeSpec, deser) | ||
eg_deser = deserialize(BehaviorTreeItem, ser) | ||
assert eg_root == eg_deser |
Oops, something went wrong.