Skip to content

Commit

Permalink
Added RealNumber with custom caster for testing typing classes.
Browse files Browse the repository at this point in the history
  • Loading branch information
timohl committed Nov 25, 2024
1 parent 251aeb7 commit 413d685
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
79 changes: 79 additions & 0 deletions tests/test_pytypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,44 @@ typedef py::typing::TypeVar<"V"> TypeVarV;
} // namespace typevar
#endif

// Custom type for testing arg_name/return_name type hints
// RealNumber:
// in arguments -> float | int,
// in return -> float
// fallback -> complex (just for testing, not really useful here)

struct RealNumber {
double value;
};

namespace pybind11 {
namespace detail {

template <>
struct type_caster<RealNumber> {
PYBIND11_TYPE_CASTER(RealNumber, const_name("complex"));
static constexpr auto arg_name = const_name("Union[float, int]");
static constexpr auto return_name = const_name("float");

static handle cast(const RealNumber &number, return_value_policy, handle) {
return PyFloat_FromDouble(number.value);
}

bool load(handle src, bool) {
if (!src) {
return false;
}
if (!PyFloat_Check(src.ptr()) && !PyLong_Check(src.ptr())) {
return false;
}
value = RealNumber{PyFloat_AsDouble(src.ptr())};
return true;
}
};

} // namespace detail
} // namespace pybind11

TEST_SUBMODULE(pytypes, m) {
m.def("obj_class_name", [](py::handle obj) { return py::detail::obj_class_name(obj.ptr()); });

Expand Down Expand Up @@ -998,4 +1036,45 @@ TEST_SUBMODULE(pytypes, m) {
#else
m.attr("defined_PYBIND11_TEST_PYTYPES_HAS_RANGES") = false;
#endif
m.def("half_of_number", [](const RealNumber &x) { return RealNumber{x.value / 2}; });
m.def("half_of_number_tuple", [](const py::typing::Tuple<RealNumber, RealNumber> &x) {
py::typing::Tuple<RealNumber, RealNumber> result
= py::make_tuple(RealNumber{x[0].cast<RealNumber>().value / 2},
RealNumber{x[1].cast<RealNumber>().value / 2});
return result;
});
m.def("half_of_number_tuple_ellipsis",
[](const py::typing::Tuple<RealNumber, py::ellipsis> &x) {
py::typing::Tuple<RealNumber, py::ellipsis> result(x.size());
for (size_t i = 0; i < x.size(); ++i) {
result[i] = x[i].cast<RealNumber>().value / 2;
}
return result;
});
m.def("half_of_number_list", [](const py::typing::List<RealNumber> &x) {
py::typing::List<RealNumber> result;
for (auto num : x) {
result.append(RealNumber{num.cast<RealNumber>().value / 2});
}
return result;
});
m.def("half_of_number_nested_list",
[](const py::typing::List<py::typing::List<RealNumber>> &x) {
py::typing::List<py::typing::List<RealNumber>> result_lists;
for (auto nums : x) {
py::typing::List<RealNumber> result;
for (auto num : nums) {
result.append(RealNumber{num.cast<RealNumber>().value / 2});
}
result_lists.append(result);
}
return result_lists;
});
m.def("half_of_number_dict", [](const py::typing::Dict<std::string, RealNumber> &x) {
py::typing::Dict<std::string, RealNumber> result;
for (auto it : x) {
result[it.first] = RealNumber{it.second.cast<RealNumber>().value / 2};
}
return result;
});
}
29 changes: 29 additions & 0 deletions tests/test_pytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,3 +1101,32 @@ def test_list_ranges(tested_list, expected):
def test_dict_ranges(tested_dict, expected):
assert m.dict_iterator_default_initialization()
assert m.transform_dict_plus_one(tested_dict) == expected


def test_arg_return_type_hints(doc):
assert doc(m.half_of_number) == "half_of_number(arg0: Union[float, int]) -> float"
assert m.half_of_number(2.0) == 1.0
assert m.half_of_number(2) == 1.0
assert m.half_of_number(0) == 0
assert isinstance(m.half_of_number(0), float)
assert not isinstance(m.half_of_number(0), int)
assert (
doc(m.half_of_number_tuple)
== "half_of_number_tuple(arg0: tuple[Union[float, int], Union[float, int]]) -> tuple[float, float]"
)
assert (
doc(m.half_of_number_tuple_ellipsis)
== "half_of_number_tuple_ellipsis(arg0: tuple[Union[float, int], ...]) -> tuple[float, ...]"
)
assert (
doc(m.half_of_number_list)
== "half_of_number_list(arg0: list[Union[float, int]]) -> list[float]"
)
assert (
doc(m.half_of_number_nested_list)
== "half_of_number_nested_list(arg0: list[list[Union[float, int]]]) -> list[list[float]]"
)
assert (
doc(m.half_of_number_dict)
== "half_of_number_dict(arg0: dict[str, Union[float, int]]) -> dict[str, float]"
)

0 comments on commit 413d685

Please sign in to comment.