From 4becf5aaf87fa6728ff08383e5589920db394b00 Mon Sep 17 00:00:00 2001 From: Sacrimento Date: Sat, 1 Jul 2023 20:59:36 +0200 Subject: [PATCH] feat: support `functools.partial` handlers Handlers can now be wrapped by `functools.partial`, and take arbitrary arguments. --- CHANGELOG.md | 8 ++++ docs/src/handlers.md | 10 +++++ docs/src/recipes.md | 41 +++++++++++++++++-- src/bigxml/handler_creator.py | 20 ++++++++-- src/bigxml/marks.py | 12 +++++- tests/unit/test_handler_creator.py | 64 ++++++++++++++++++++++++++++++ tests/unit/test_handler_marker.py | 11 +++++ tests/unit/test_marks.py | 20 ++++++++++ 8 files changed, 177 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e1c69d..d211718 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/), and this project adheres to [Semantic Versioning](https://semver.org/). +## [Unreleased] + +- [unreleased]: https://github.com/rogdham/bigxml/compare/v0.10.0...HEAD + +### :rocket: Added + +- Support `functools.partial` handlers + ## [0.10.0] - 2023-04-22 [0.10.0]: https://github.com/rogdham/bigxml/compare/v0.9.0...v0.10.0 diff --git a/docs/src/handlers.md b/docs/src/handlers.md index 79ddbd2..3fd2d03 100644 --- a/docs/src/handlers.md +++ b/docs/src/handlers.md @@ -6,6 +6,11 @@ The methods `iter_from` and `return_from` take _handlers_ as arguments. A handler can be a generator function taking a _node_ as an argument. +!!! Tip + + A handler can take more than a node argument, by using `functools.partial`. + See [this recipe](recipes.md#arbitrary-handler-args) for examples. + Such functions are usually decorated with `xml_handle_element` or `xml_handle_text`, to restrict the type of nodes they are called with. @@ -97,6 +102,11 @@ argument is supplied with the encountered node: <__main__.Cart object...> for user Alice <__main__.Cart object...> for user Bob +!!! Tip + + `__init__` can take more than a node argument, by using `functools.partial`. + See [this recipe](recipes.md#arbitrary-handler-args) for examples. + ### Class methods as sub-handlers The methods decorated with `xml_handle_element` or `xml_handle_text` are used as diff --git a/docs/src/recipes.md b/docs/src/recipes.md index e251471..86610d0 100644 --- a/docs/src/recipes.md +++ b/docs/src/recipes.md @@ -125,9 +125,7 @@ the `__post_init__` method: !!! Warning The `node` attribute is an `InitVar`, so that it is passed to `__post_init__` but - not stored in class attributes. It must be the only mandatory field, since the class - is automatically instantiated with only one argument (the node). For more details, - see [class handlers](handlers.md#classes). + not stored in class attributes. ## Yielding data in a class `__init__` {: #yield-in-init } @@ -178,6 +176,43 @@ Instead, you can define a custom `xml_handler` method: product: 9780099580485 END cart parsing for user Bob +## Passing arbitrary arguments to handlers {: #arbitrary-handler-args} + +You may want to pass arbitrary arguments to your handlers. This is achievable by +using `functools.partial`. + +For example, let's say you only want to yield fast vehicles from this vehicle file: + + :::xml filename=vehicles.xml + + Train + Boat + Car + Plane + + +You can pass the desired speed threshold directly to the handler +by using `functools.partial`: + + :::python + >>> from functools import partial + >>> @xml_handle_element("vehicles", "vehicle") + ... def handler(speed_threshold, node): + ... if int(node.attributes['speed']) > speed_threshold: + ... yield node.text + + >>> with open("vehicles.xml", "rb") as stream: + ... for vehicle in Parser(stream).iter_from(partial(handler, 150)): + ... print(vehicle) + Train + Plane + +!!! Warning + + If the parameter is positional-only, it must come before the `node` argument. + +This behavior also work for class (and dataclasses) handlers. + ## Streams without root {: #no-root } In some cases, you may be parsing a stream of XML elements that follow each other diff --git a/src/bigxml/handler_creator.py b/src/bigxml/handler_creator.py index 4982238..c4fae2d 100644 --- a/src/bigxml/handler_creator.py +++ b/src/bigxml/handler_creator.py @@ -1,4 +1,5 @@ from dataclasses import is_dataclass +from functools import partial from inspect import getmembers, isclass from typing import ( TYPE_CHECKING, @@ -55,7 +56,12 @@ class _HandlerTree: def __init__(self, path: Tuple[str, ...] = ()) -> None: self.path: Tuple[str, ...] = path self.children: Dict[str, _HandlerTree] = {} - self.handler: Optional[Callable[..., Iterable[object]]] = None + self.handler: Optional[ + Union[ + Callable[..., Iterable[object]], + partial[Callable[..., Iterable[object]]], + ] + ] = None def add_handler( self, @@ -126,9 +132,15 @@ def handle( self, node: Union["XMLElement", "XMLText"] ) -> Optional[Iterable[object]]: if self.handler: - if isclass(self.handler): - return self._handle_from_class(self.handler, node) - return self.handler(node) + unwrapped_handler = None + if isinstance(self.handler, partial): + unwrapped_handler = self.handler.func + if isclass(unwrapped_handler or self.handler): + return self._handle_from_class(cast(Type[Any], self.handler), node) + return cast( + Callable[[Union["XMLElement", "XMLText"]], Iterable[object]], + self.handler, + )(node) child: Optional[_HandlerTree] = None namespace = getattr(node, "namespace", None) diff --git a/src/bigxml/marks.py b/src/bigxml/marks.py index ca1921e..af8838e 100644 --- a/src/bigxml/marks.py +++ b/src/bigxml/marks.py @@ -1,17 +1,25 @@ +from functools import partial from typing import Tuple __ATTR_MARK_NAME = "_xml_handlers_on" +def _unwrap_partials(obj: object) -> object: + if isinstance(obj, partial): + return obj.func + return obj + + def has_marks(obj: object) -> bool: - return hasattr(obj, __ATTR_MARK_NAME) + return hasattr(_unwrap_partials(obj), __ATTR_MARK_NAME) def get_marks(obj: object) -> Tuple[Tuple[str, ...], ...]: - return getattr(obj, __ATTR_MARK_NAME, ()) + return getattr(_unwrap_partials(obj), __ATTR_MARK_NAME, ()) def add_mark(obj: object, mark: Tuple[str, ...]) -> None: + obj = _unwrap_partials(obj) marks = get_marks(obj) marks += (mark,) setattr(obj, __ATTR_MARK_NAME, marks) diff --git a/tests/unit/test_handler_creator.py b/tests/unit/test_handler_creator.py index 8fc9f39..9160b84 100644 --- a/tests/unit/test_handler_creator.py +++ b/tests/unit/test_handler_creator.py @@ -149,6 +149,24 @@ def catchall( test_create_handler(catchall) +@cases( + (("a",), "foo: catchall", "a"), + (("{foo}a",), "foo: catchall", "{foo}a"), + (("d0", "d1"), "foo: catchall", "d0"), + (("d0", "d1", "d2"), "foo: catchall", "d0"), + ((":text:",), "foo: catchall", ":text:"), +) +def test_one_partial_catchall(test_create_handler: TEST_CREATE_HANDLER_TYPE) -> None: + def catchall( + ctx: str, node: Union[XMLElement, XMLText] + ) -> Iterator[Tuple[str, Union[XMLElement, XMLText]]]: + yield (f"{ctx}: catchall", node) + + partial_handler = partial(catchall, "foo") + + test_create_handler(partial_handler) + + @cases( (("a",), "0", "a"), (("{foo}a",), "1", "{foo}a"), @@ -448,6 +466,20 @@ class Handler: assert isinstance(items[0], Handler) +def test_partial_class_without_subhandler() -> None: + @xml_handle_element("x") + class Handler: + def __init__(self, ctx: str) -> None: + self.ctx = ctx + + partial_handler = partial(Handler, "foo") + nodes = create_nodes("x", "y") + handler = create_handler(partial_handler) + items = list(handler(nodes[0])) + assert len(items) == 1 + assert isinstance(items[0], Handler) + + @pytest.mark.parametrize("init_mandatory", [False, True]) @pytest.mark.parametrize("init_optional", [False, True]) def test_class_init(init_mandatory: bool, init_optional: bool) -> None: @@ -581,6 +613,21 @@ def __init__(self, node: XMLElement, answer: int) -> None: assert "Add a default value for dataclass fields" not in str(excinfo.value) +def test_partial_class_multiple_mandatory_parameters() -> None: + @xml_handle_element("x") + class Handler: + def __init__(self, before: str, node: XMLElement, after: str) -> None: + pass + + partial_handler = partial(Handler, "before", after="after") + nodes = create_nodes("x", "y") + handler = create_handler(partial_handler) + items = list(handler(nodes[0])) + + assert len(items) == 1 + assert isinstance(items[0], Handler) + + def test_dataclass_init_two_mandatory_parameters() -> None: @xml_handle_element("x") @dataclass @@ -598,6 +645,23 @@ class Handler: assert "Add a default value for dataclass fields" in str(excinfo.value) +def test_partial_dataclass_two_mandatory_parameters() -> None: + @xml_handle_element("x") + @dataclass + class Handler: + before: str + node: XMLElement + after: str + + partial_handler = partial(Handler, "before", after="after") + nodes = create_nodes("x", "y") + handler = create_handler(partial_handler) + items = list(handler(nodes[0])) + + assert len(items) == 1 + assert isinstance(items[0], Handler) + + def test_class_init_crash() -> None: @xml_handle_element("x") class Handler: diff --git a/tests/unit/test_handler_marker.py b/tests/unit/test_handler_marker.py index e27e077..92a77ed 100644 --- a/tests/unit/test_handler_marker.py +++ b/tests/unit/test_handler_marker.py @@ -1,3 +1,4 @@ +from functools import partial from typing import Iterator, Union import pytest @@ -15,6 +16,16 @@ def fct(node: XMLElement) -> Iterator[str]: assert get_marks(fct) == (("abc", "def"),) +def test_one_marker_element_on_partial_func() -> None: + @xml_handle_element("abc", "def") + def fct(ctx: str, node: XMLElement) -> Iterator[str]: + yield f"{ctx}: <{node.text}>" + + partial_fct = partial(fct, "foo") + + assert get_marks(partial_fct) == (("abc", "def"),) + + def test_one_maker_element_on_method() -> None: class Klass: def __init__(self, multiplier: int) -> None: diff --git a/tests/unit/test_marks.py b/tests/unit/test_marks.py index f19f37b..09b7db9 100644 --- a/tests/unit/test_marks.py +++ b/tests/unit/test_marks.py @@ -1,3 +1,5 @@ +from functools import partial + import pytest from bigxml.marks import add_mark, get_marks, has_marks @@ -20,3 +22,21 @@ class Markable: add_mark(obj, ("def", "ghi", "jkl")) assert has_marks(obj) assert get_marks(obj) == (("abc",), ("def", "ghi", "jkl")) + + +def test_marks_on_partial() -> None: + class Markable: + pass + + obj = partial(Markable, "foo") + + assert not has_marks(obj) + assert not get_marks(obj) + + add_mark(obj, ("abc",)) + assert has_marks(obj) + assert get_marks(obj) == (("abc",),) + + add_mark(obj, ("def", "ghi", "jkl")) + assert has_marks(obj) + assert get_marks(obj) == (("abc",), ("def", "ghi", "jkl"))