Skip to content

Commit

Permalink
Add support for array_t<handle> and array_t<object>
Browse files Browse the repository at this point in the history
  • Loading branch information
MaartenBaert committed Nov 1, 2024
1 parent 75e48c5 commit fe27f7f
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 31 deletions.
18 changes: 18 additions & 0 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,24 @@ struct npy_format_descriptor<T, enable_if_t<is_same_ignoring_cvref<T, PyObject *
static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
};

template <>
struct npy_format_descriptor<handle, enable_if_t<sizeof(handle) == sizeof(PyObject*)>> {
static constexpr auto name = const_name("object");

static constexpr int value = npy_api::NPY_OBJECT_;

static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
};

template <>
struct npy_format_descriptor<object, enable_if_t<sizeof(object) == sizeof(PyObject*)>> {
static constexpr auto name = const_name("object");

static constexpr int value = npy_api::NPY_OBJECT_;

static pybind11::dtype dtype() { return pybind11::dtype(/*typenum*/ value); }
};

#define PYBIND11_DECL_CHAR_FMT \
static constexpr auto name = const_name("S") + const_name<N>(); \
static pybind11::dtype dtype() { \
Expand Down
53 changes: 52 additions & 1 deletion tests/test_numpy_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -528,20 +528,71 @@ TEST_SUBMODULE(numpy_array, sm) {
return sum_str_values;
});

sm.def("pass_array_handle_return_sum_str_values",
[](const py::array_t<py::handle> &objs) {
std::string sum_str_values;
for (const auto &obj : objs) {
sum_str_values += py::str(obj.attr("value"));
}
return sum_str_values;
});

sm.def("pass_array_object_return_sum_str_values",
[](const py::array_t<py::object> &objs) {
std::string sum_str_values;
for (const auto &obj : objs) {
sum_str_values += py::str(obj.attr("value"));
}
return sum_str_values;
});

sm.def("pass_array_pyobject_ptr_return_as_list",
[](const py::array_t<PyObject *> &objs) -> py::list { return objs; });

sm.def("pass_array_handle_return_as_list",
[](const py::array_t<py::handle> &objs) -> py::list { return objs; });

sm.def("pass_array_object_return_as_list",
[](const py::array_t<py::object> &objs) -> py::list { return objs; });

sm.def("return_array_pyobject_ptr_cpp_loop", [](const py::list &objs) {
py::size_t arr_size = py::len(objs);
py::array_t<PyObject *> arr_from_list(static_cast<py::ssize_t>(arr_size));
PyObject **data = arr_from_list.mutable_data();
for (py::size_t i = 0; i < arr_size; i++) {
assert(data[i] == nullptr);
data[i] = py::cast<PyObject *>(objs[i].attr("value"));
data[i] = py::cast<PyObject *>(objs[i]);
}
return arr_from_list;
});

sm.def("return_array_handle_cpp_loop", [](const py::list &objs) {
py::size_t arr_size = py::len(objs);
py::array_t<py::handle> arr_from_list(static_cast<py::ssize_t>(arr_size));
py::handle *data = arr_from_list.mutable_data();
for (py::size_t i = 0; i < arr_size; i++) {
assert(data[i] == nullptr);
data[i] = py::object(objs[i]).release();
}
return arr_from_list;
});

sm.def("return_array_object_cpp_loop", [](const py::list &objs) {
py::size_t arr_size = py::len(objs);
py::array_t<py::object> arr_from_list(static_cast<py::ssize_t>(arr_size));
py::object *data = arr_from_list.mutable_data();
for (py::size_t i = 0; i < arr_size; i++) {
data[i] = objs[i];
}
return arr_from_list;
});

sm.def("return_array_pyobject_ptr_from_list",
[](const py::list &objs) -> py::array_t<PyObject *> { return objs; });

sm.def("return_array_handle_from_list",
[](const py::list &objs) -> py::array_t<py::handle> { return objs; });

sm.def("return_array_object_from_list",
[](const py::list &objs) -> py::array_t<py::object> { return objs; });
}
105 changes: 75 additions & 30 deletions tests/test_numpy_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,12 @@ def test_round_trip_float():
# * Sanitizers are much more likely to detect heap-use-after-free due to
# other ref-count bugs.
class PyValueHolder:
counter = 0
def __init__(self, value):
self.value = value
PyValueHolder.counter += 1
def __del__(self):
PyValueHolder.counter -= 1


def WrapWithPyValueHolder(*values):
Expand All @@ -629,45 +633,86 @@ def UnwrapPyValueHolder(vhs):
return [vh.value for vh in vhs]


def test_pass_array_pyobject_ptr_return_sum_str_values_ndarray():
# Intentionally all temporaries, do not change.
assert (
m.pass_array_pyobject_ptr_return_sum_str_values(
np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object)
@pytest.mark.parametrize(
"func",
[
m.pass_array_pyobject_ptr_return_sum_str_values,
m.pass_array_handle_return_sum_str_values,
m.pass_array_object_return_sum_str_values,
],
)
def test_pass_array_object_return_sum_str_values_ndarray(func):
initial_counter = PyValueHolder.counter
for loop in range(100):
# Intentionally all temporaries, do not change.
assert (
func(
np.array(WrapWithPyValueHolder(-3, "four", 5.0), dtype=object)
)
== "-3four5.0"
)
== "-3four5.0"
)
assert PyValueHolder.counter == initial_counter


def test_pass_array_pyobject_ptr_return_sum_str_values_list():
# Intentionally all temporaries, do not change.
assert (
m.pass_array_pyobject_ptr_return_sum_str_values(
WrapWithPyValueHolder(2, "three", -4.0)
@pytest.mark.parametrize(
"func",
[
m.pass_array_pyobject_ptr_return_sum_str_values,
m.pass_array_handle_return_sum_str_values,
m.pass_array_object_return_sum_str_values,
],
)
def test_pass_array_object_return_sum_str_values_list(func):
initial_counter = PyValueHolder.counter
for loop in range(100):
# Intentionally all temporaries, do not change.
assert (
func(
WrapWithPyValueHolder(2, "three", -4.0)
)
== "2three-4.0"
)
== "2three-4.0"
)
assert PyValueHolder.counter == initial_counter


def test_pass_array_pyobject_ptr_return_as_list():
# Intentionally all temporaries, do not change.
assert UnwrapPyValueHolder(
m.pass_array_pyobject_ptr_return_as_list(
np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object)
)
) == [-1, "two", 3.0]
@pytest.mark.parametrize(
"func",
[
m.pass_array_pyobject_ptr_return_as_list,
m.pass_array_handle_return_as_list,
m.pass_array_object_return_as_list,
],
)
def test_pass_array_object_return_as_list(func):
initial_counter = PyValueHolder.counter
for loop in range(100):
# Intentionally all temporaries, do not change.
assert UnwrapPyValueHolder(
func(
np.array(WrapWithPyValueHolder(-1, "two", 3.0), dtype=object)
)
) == [-1, "two", 3.0]
assert PyValueHolder.counter == initial_counter


@pytest.mark.parametrize(
("return_array_pyobject_ptr", "unwrap"),
"func",
[
(m.return_array_pyobject_ptr_cpp_loop, list),
(m.return_array_pyobject_ptr_from_list, UnwrapPyValueHolder),
m.return_array_pyobject_ptr_cpp_loop,
m.return_array_handle_cpp_loop,
m.return_array_object_cpp_loop,
m.return_array_pyobject_ptr_from_list,
m.return_array_handle_from_list,
m.return_array_object_from_list,
],
)
def test_return_array_pyobject_ptr_cpp_loop(return_array_pyobject_ptr, unwrap):
# Intentionally all temporaries, do not change.
arr_from_list = return_array_pyobject_ptr(WrapWithPyValueHolder(6, "seven", -8.0))
assert isinstance(arr_from_list, np.ndarray)
assert arr_from_list.dtype == np.dtype("O")
assert unwrap(arr_from_list) == [6, "seven", -8.0]
def test_return_array_object_cpp_loop(func):
initial_counter = PyValueHolder.counter
for loop in range(100):
# Intentionally all temporaries, do not change.
arr_from_list = func(WrapWithPyValueHolder(6, "seven", -8.0))
assert isinstance(arr_from_list, np.ndarray)
assert arr_from_list.dtype == np.dtype("O")
assert UnwrapPyValueHolder(arr_from_list) == [6, "seven", -8.0]
del arr_from_list
assert PyValueHolder.counter == initial_counter

0 comments on commit fe27f7f

Please sign in to comment.