Skip to content

Commit

Permalink
feat: support functools.partial handlers
Browse files Browse the repository at this point in the history
Handlers can now be wrapped by `functools.partial`,
and take arbitrary arguments.
  • Loading branch information
Sacrimento committed Jul 1, 2023
1 parent 2e9d265 commit 4becf5a
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 9 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/), and this project
adheres to [Semantic Versioning](https://semver.org/).

## [Unreleased]

- [unreleased]: https://github.com/rogdham/bigxml/compare/v0.10.0...HEAD

### :rocket: Added

- Support `functools.partial` handlers

## [0.10.0] - 2023-04-22

[0.10.0]: https://github.com/rogdham/bigxml/compare/v0.9.0...v0.10.0
Expand Down
10 changes: 10 additions & 0 deletions docs/src/handlers.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ The methods `iter_from` and `return_from` take _handlers_ as arguments.

A handler can be a generator function taking a _node_ as an argument.

!!! Tip

A handler can take arguments in addition to the node by using `functools.partial`.
See [this recipe](recipes.md#arbitrary-handler-args) for examples.

Such functions are usually decorated with `xml_handle_element` or `xml_handle_text`, to
restrict the type of nodes they are called with.

Expand Down Expand Up @@ -97,6 +102,11 @@ argument is supplied with the encountered node:
<__main__.Cart object...> for user Alice
<__main__.Cart object...> for user Bob

!!! Tip

`__init__` can take arguments in addition to the node by using `functools.partial`.
See [this recipe](recipes.md#arbitrary-handler-args) for examples.

### Class methods as sub-handlers

The methods decorated with `xml_handle_element` or `xml_handle_text` are used as
Expand Down
41 changes: 38 additions & 3 deletions docs/src/recipes.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,7 @@ the `__post_init__` method:
!!! Warning

The `node` attribute is an `InitVar`, so that it is passed to `__post_init__` but
not stored in class attributes. It must be the only mandatory field, since the class
is automatically instantiated with only one argument (the node). For more details,
see [class handlers](handlers.md#classes).
not stored in class attributes.

## Yielding data in a class `__init__` {: #yield-in-init }

Expand Down Expand Up @@ -178,6 +176,43 @@ Instead, you can define a custom `xml_handler` method:
product: 9780099580485
END cart parsing for user Bob

## Passing arbitrary arguments to handlers {: #arbitrary-handler-args}

You may want to pass arbitrary arguments to your handlers. This is achievable by
using `functools.partial`.

For example, let's say you only want to yield fast vehicles from this vehicle file:

:::xml filename=vehicles.xml
<vehicles>
<vehicle speed="300">Train</vehicle>
<vehicle speed="30">Boat</vehicle>
<vehicle speed="80">Car</vehicle>
<vehicle speed="900">Plane</vehicle>
</vehicles>

You can pass the desired speed threshold directly to the handler
by using `functools.partial`:

:::python
>>> from functools import partial
>>> @xml_handle_element("vehicles", "vehicle")
... def handler(speed_threshold, node):
... if int(node.attributes['speed']) > speed_threshold:
... yield node.text

>>> with open("vehicles.xml", "rb") as stream:
... for vehicle in Parser(stream).iter_from(partial(handler, 150)):
... print(vehicle)
Train
Plane

!!! Warning

If an argument is passed positionally to `functools.partial`, the corresponding
parameter must come before the `node` parameter.

This behavior also works for class (and dataclass) handlers.

## Streams without root {: #no-root }

In some cases, you may be parsing a stream of XML elements that follow each other
Expand Down
20 changes: 16 additions & 4 deletions src/bigxml/handler_creator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import is_dataclass
from functools import partial
from inspect import getmembers, isclass
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -55,7 +56,12 @@ class _HandlerTree:
def __init__(self, path: Tuple[str, ...] = ()) -> None:
    """Initialize a tree node with no children and no handler.

    Args:
        path: Tuple of strings stored as this node's ``path``; defaults to
            the empty tuple.
    """
    self.path: Tuple[str, ...] = path
    self.children: Dict[str, _HandlerTree] = {}
    # ``functools.partial`` is generic over the *return type* of the call,
    # so a partially applied handler is ``partial[Iterable[object]]``.
    # It is listed explicitly (although it is already a ``Callable``) so
    # that ``isinstance(..., partial)`` checks narrow cleanly.
    self.handler: Optional[
        Union[
            Callable[..., Iterable[object]],
            partial[Iterable[object]],
        ]
    ] = None

def add_handler(
self,
Expand Down Expand Up @@ -126,9 +132,15 @@ def handle(
self, node: Union["XMLElement", "XMLText"]
) -> Optional[Iterable[object]]:
if self.handler:
if isclass(self.handler):
return self._handle_from_class(self.handler, node)
return self.handler(node)
unwrapped_handler = None
if isinstance(self.handler, partial):
unwrapped_handler = self.handler.func
if isclass(unwrapped_handler or self.handler):
return self._handle_from_class(cast(Type[Any], self.handler), node)
return cast(
Callable[[Union["XMLElement", "XMLText"]], Iterable[object]],
self.handler,
)(node)

child: Optional[_HandlerTree] = None
namespace = getattr(node, "namespace", None)
Expand Down
12 changes: 10 additions & 2 deletions src/bigxml/marks.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
from functools import partial
from typing import Tuple

__ATTR_MARK_NAME = "_xml_handlers_on"


def _unwrap_partials(obj: object) -> object:
    """Return the callable wrapped by a ``functools.partial``, else ``obj``.

    Marks are stored on the wrapped function or class, so a ``partial``
    wrapper has to be looked through before marks are read or written.
    """
    if isinstance(obj, partial):
        return obj.func
    return obj


def has_marks(obj: object) -> bool:
    """Tell whether ``obj``, or the callable it partially applies, is marked."""
    return hasattr(_unwrap_partials(obj), __ATTR_MARK_NAME)


def get_marks(obj: object) -> Tuple[Tuple[str, ...], ...]:
    """Return the marks of ``obj``, or of the callable it partially applies.

    Returns:
        The stored tuple of marks, or an empty tuple when there is none.
    """
    return getattr(_unwrap_partials(obj), __ATTR_MARK_NAME, ())


def add_mark(obj: object, mark: Tuple[str, ...]) -> None:
    """Append ``mark`` to the marks of ``obj``.

    When ``obj`` is a ``functools.partial``, the mark is set on the wrapped
    callable, which is the same object ``has_marks`` and ``get_marks``
    inspect.
    """
    target = _unwrap_partials(obj)
    setattr(target, __ATTR_MARK_NAME, get_marks(target) + (mark,))
64 changes: 64 additions & 0 deletions tests/unit/test_handler_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,24 @@ def catchall(
test_create_handler(catchall)


@cases(
    (("a",), "foo: catchall", "a"),
    (("{foo}a",), "foo: catchall", "{foo}a"),
    (("d0", "d1"), "foo: catchall", "d0"),
    (("d0", "d1", "d2"), "foo: catchall", "d0"),
    ((":text:",), "foo: catchall", ":text:"),
)
def test_one_partial_catchall(test_create_handler: TEST_CREATE_HANDLER_TYPE) -> None:
    """Use an undecorated generator wrapped in ``partial`` as the handler."""

    def catchall(
        ctx: str, node: Union[XMLElement, XMLText]
    ) -> Iterator[Tuple[str, Union[XMLElement, XMLText]]]:
        yield (f"{ctx}: catchall", node)

    # "foo" is bound as ``ctx``; only ``node`` is left for the caller.
    test_create_handler(partial(catchall, "foo"))


@cases(
(("a",), "0", "a"),
(("{foo}a",), "1", "{foo}a"),
Expand Down Expand Up @@ -448,6 +466,20 @@ class Handler:
assert isinstance(items[0], Handler)


def test_partial_class_without_subhandler() -> None:
    """A class wrapped in ``partial`` is instantiated with the bound argument."""

    @xml_handle_element("x")
    class Handler:
        def __init__(self, ctx: str) -> None:
            self.ctx = ctx

    partial_handler = partial(Handler, "foo")
    nodes = create_nodes("x", "y")
    handler = create_handler(partial_handler)
    items = list(handler(nodes[0]))
    assert len(items) == 1
    item = items[0]
    assert isinstance(item, Handler)
    # The argument bound through ``partial`` must have reached ``__init__``.
    assert item.ctx == "foo"


@pytest.mark.parametrize("init_mandatory", [False, True])
@pytest.mark.parametrize("init_optional", [False, True])
def test_class_init(init_mandatory: bool, init_optional: bool) -> None:
Expand Down Expand Up @@ -581,6 +613,21 @@ def __init__(self, node: XMLElement, answer: int) -> None:
assert "Add a default value for dataclass fields" not in str(excinfo.value)


def test_partial_class_multiple_mandatory_parameters() -> None:
    """Bind the parameters around ``node`` positionally and by keyword."""

    @xml_handle_element("x")
    class Handler:
        def __init__(self, before: str, node: XMLElement, after: str) -> None:
            pass

    # Only ``node`` is left unbound once ``partial`` has been applied.
    handler = create_handler(partial(Handler, "before", after="after"))
    first_node = create_nodes("x", "y")[0]
    items = list(handler(first_node))

    assert len(items) == 1
    assert isinstance(items[0], Handler)


def test_dataclass_init_two_mandatory_parameters() -> None:
@xml_handle_element("x")
@dataclass
Expand All @@ -598,6 +645,23 @@ class Handler:
assert "Add a default value for dataclass fields" in str(excinfo.value)


def test_partial_dataclass_two_mandatory_parameters() -> None:
    """A dataclass wrapped in ``partial`` receives the bound field values."""

    @xml_handle_element("x")
    @dataclass
    class Handler:
        before: str
        node: XMLElement
        after: str

    partial_handler = partial(Handler, "before", after="after")
    nodes = create_nodes("x", "y")
    handler = create_handler(partial_handler)
    items = list(handler(nodes[0]))

    assert len(items) == 1
    item = items[0]
    assert isinstance(item, Handler)
    # The fields bound through ``partial`` must end up on the instance.
    assert item.before == "before"
    assert item.after == "after"


def test_class_init_crash() -> None:
@xml_handle_element("x")
class Handler:
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/test_handler_marker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import Iterator, Union

import pytest
Expand All @@ -15,6 +16,16 @@ def fct(node: XMLElement) -> Iterator[str]:
assert get_marks(fct) == (("abc", "def"),)


def test_one_marker_element_on_partial_func() -> None:
    """Marks set by the decorator are readable through a ``partial`` wrapper."""

    @xml_handle_element("abc", "def")
    def fct(ctx: str, node: XMLElement) -> Iterator[str]:
        yield f"{ctx}: <{node.text}>"

    expected = (("abc", "def"),)
    assert get_marks(partial(fct, "foo")) == expected


def test_one_maker_element_on_method() -> None:
class Klass:
def __init__(self, multiplier: int) -> None:
Expand Down
20 changes: 20 additions & 0 deletions tests/unit/test_marks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import pytest

from bigxml.marks import add_mark, get_marks, has_marks
Expand All @@ -20,3 +22,21 @@ class Markable:
add_mark(obj, ("def", "ghi", "jkl"))
assert has_marks(obj)
assert get_marks(obj) == (("abc",), ("def", "ghi", "jkl"))


def test_marks_on_partial() -> None:
    """Marks can be added to and read back from a ``partial`` object."""

    class Markable:
        pass

    obj = partial(Markable, "foo")

    # Nothing is marked before the first ``add_mark`` call.
    assert not has_marks(obj)
    assert not get_marks(obj)

    marks = (("abc",), ("def", "ghi", "jkl"))
    for count, mark in enumerate(marks, start=1):
        add_mark(obj, mark)
        assert has_marks(obj)
        # Marks accumulate in insertion order.
        assert get_marks(obj) == marks[:count]

0 comments on commit 4becf5a

Please sign in to comment.