From 413d685aae32c7fe8a8967a8d2fc6aff9daf0043 Mon Sep 17 00:00:00 2001 From: Tim Ohliger Date: Mon, 25 Nov 2024 01:22:44 +0100 Subject: [PATCH] Added RealNumber with custom caster for testing typing classes. --- tests/test_pytypes.cpp | 79 ++++++++++++++++++++++++++++++++++++++++++ tests/test_pytypes.py | 29 ++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/tests/test_pytypes.cpp b/tests/test_pytypes.cpp index 8df4cdd3f6..abd49c29e3 100644 --- a/tests/test_pytypes.cpp +++ b/tests/test_pytypes.cpp @@ -137,6 +137,44 @@ typedef py::typing::TypeVar<"V"> TypeVarV; } // namespace typevar #endif +// Custom type for testing arg_name/return_name type hints +// RealNumber: +// in arguments -> float | int, +// in return -> float +// fallback -> complex (just for testing, not really useful here) + +struct RealNumber { + double value; +}; + +namespace pybind11 { +namespace detail { + +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(RealNumber, const_name("complex")); + static constexpr auto arg_name = const_name("Union[float, int]"); + static constexpr auto return_name = const_name("float"); + + static handle cast(const RealNumber &number, return_value_policy, handle) { + return PyFloat_FromDouble(number.value); + } + + bool load(handle src, bool) { + if (!src) { + return false; + } + if (!PyFloat_Check(src.ptr()) && !PyLong_Check(src.ptr())) { + return false; + } + value = RealNumber{PyFloat_AsDouble(src.ptr())}; + return true; + } +}; + +} // namespace detail +} // namespace pybind11 + TEST_SUBMODULE(pytypes, m) { m.def("obj_class_name", [](py::handle obj) { return py::detail::obj_class_name(obj.ptr()); }); @@ -998,4 +1036,45 @@ TEST_SUBMODULE(pytypes, m) { #else m.attr("defined_PYBIND11_TEST_PYTYPES_HAS_RANGES") = false; #endif + m.def("half_of_number", [](const RealNumber &x) { return RealNumber{x.value / 2}; }); + m.def("half_of_number_tuple", [](const py::typing::Tuple &x) { + py::typing::Tuple result + = py::make_tuple(RealNumber{x[0].cast().value / 2}, + RealNumber{x[1].cast().value / 2}); + return result; + }); + m.def("half_of_number_tuple_ellipsis", + [](const py::typing::Tuple &x) { + py::typing::Tuple result(x.size()); + for (size_t i = 0; i < x.size(); ++i) { + result[i] = x[i].cast().value / 2; + } + return result; + }); + m.def("half_of_number_list", [](const py::typing::List &x) { + py::typing::List result; + for (auto num : x) { + result.append(RealNumber{num.cast().value / 2}); + } + return result; + }); + m.def("half_of_number_nested_list", + [](const py::typing::List> &x) { + py::typing::List> result_lists; + for (auto nums : x) { + py::typing::List result; + for (auto num : nums) { + result.append(RealNumber{num.cast().value / 2}); + } + result_lists.append(result); + } + return result_lists; + }); + m.def("half_of_number_dict", [](const py::typing::Dict &x) { + py::typing::Dict result; + for (auto it : x) { + result[it.first] = RealNumber{it.second.cast().value / 2}; + } + return result; + }); } diff --git a/tests/test_pytypes.py b/tests/test_pytypes.py index 9fd24b34f1..6dd01ee31e 100644 --- a/tests/test_pytypes.py +++ b/tests/test_pytypes.py @@ -1101,3 +1101,32 @@ def test_list_ranges(tested_list, expected): def test_dict_ranges(tested_dict, expected): assert m.dict_iterator_default_initialization() assert m.transform_dict_plus_one(tested_dict) == expected + + +def test_arg_return_type_hints(doc): + assert doc(m.half_of_number) == "half_of_number(arg0: Union[float, int]) -> float" + assert m.half_of_number(2.0) == 1.0 + assert m.half_of_number(2) == 1.0 + assert m.half_of_number(0) == 0 + assert isinstance(m.half_of_number(0), float) + assert not isinstance(m.half_of_number(0), int) + assert ( + doc(m.half_of_number_tuple) + == "half_of_number_tuple(arg0: tuple[Union[float, int], Union[float, int]]) -> tuple[float, float]" + ) + assert ( + doc(m.half_of_number_tuple_ellipsis) + == "half_of_number_tuple_ellipsis(arg0: tuple[Union[float, int], ...]) -> tuple[float, ...]" + ) + assert ( + doc(m.half_of_number_list) + == "half_of_number_list(arg0: list[Union[float, int]]) -> list[float]" + ) + assert ( + doc(m.half_of_number_nested_list) + == "half_of_number_nested_list(arg0: list[list[Union[float, int]]]) -> list[list[float]]" + ) + assert ( + doc(m.half_of_number_dict) + == "half_of_number_dict(arg0: dict[str, Union[float, int]]) -> dict[str, float]" + )