Skip to content

Commit

Permalink
Enable passing absl::Span<bool> and absl::Span<const bool>
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 590290022
  • Loading branch information
Ralf W. Grosse-Kunstleve authored and copybara-github committed Dec 12, 2023
1 parent f37d445 commit 67491a4
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 18 deletions.
64 changes: 51 additions & 13 deletions pybind11_abseil/absl_casters.h
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ namespace internal {
template <typename T>
static constexpr bool is_buffer_interface_compatible_type =
detail::is_same_ignoring_cvref<T, PyObject*>::value ||
std::is_arithmetic<T>::value ||
std::is_arithmetic<std::remove_cv_t<T>>::value ||
std::is_same<T, std::complex<float>>::value ||
std::is_same<T, std::complex<double>>::value;

Expand All @@ -405,7 +405,8 @@ std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle src) {
if (PyObject_GetBuffer(src.ptr(), &view, flags) == 0) {
auto cleanup = absl::MakeCleanup([&view] { PyBuffer_Release(&view); });
if (view.ndim == 1 && view.strides[0] == sizeof(T) &&
buffer_info(&view, /*ownview=*/false).item_type_is_equivalent_to<T>()) {
buffer_info(&view, /*ownview=*/false)
.item_type_is_equivalent_to<std::remove_cv_t<T>>()) {
return {true, absl::MakeSpan(static_cast<T*>(view.buf), view.shape[0])};
}
} else {
Expand All @@ -421,6 +422,29 @@ constexpr std::tuple<bool, absl::Span<T>> LoadSpanFromBuffer(handle /*src*/) {
return {false, absl::Span<T>()};
}

template <typename T,
typename std::enable_if<
!std::is_same<std::remove_cv_t<T>, bool>::value, int>::type = 0>
std::tuple<bool, absl::Span<T>> LoadSpanOpaqueVector(handle src) {
// Attempt to unwrap an opaque std::vector.
using value_type = std::remove_cv_t<T>;
type_caster_base<std::vector<value_type>> caster;
if (caster.load(src, false)) {
return {true,
absl::MakeSpan(static_cast<std::vector<value_type>&>(caster))};
}
return {false, absl::Span<T>()};
}

template <typename T,
typename std::enable_if<
std::is_same<std::remove_cv_t<T>, bool>::value, int>::type = 0>
std::tuple<bool, absl::Span<T>> LoadSpanOpaqueVector(handle src) {
// std::vector<bool> is special and cannot directly be converted to a Span
// (see https://en.cppreference.com/w/cpp/container/vector_bool).
return {false, absl::Span<T>()};
}

// Helper to determine whether T is a span.
template <typename T>
struct is_absl_span : std::false_type {};
Expand All @@ -433,7 +457,7 @@ template <typename T>
struct type_caster<absl::Span<T>> {
public:
// The type referenced by the span, with const removed.
using value_type = typename std::remove_cv<T>::type;
using value_type = std::remove_cv_t<T>;
static_assert(!is_absl_span<value_type>::value,
"Nested absl spans are not supported.");

Expand Down Expand Up @@ -479,19 +503,17 @@ struct type_caster<absl::Span<T>> {
std::tie(loaded, value_) = LoadSpanFromBuffer<T>(src);
if (loaded) return true;

// Attempt to unwrap an opaque std::vector.
type_caster_base<std::vector<value_type>> caster;
if (caster.load(src, false)) {
value_ = get_value(caster);
return true;
}
std::tie(loaded, value_) = LoadSpanOpaqueVector<T>(src);
if (loaded) return true;

// Attempt to convert a native sequence. If the is_base_of_v check passes,
// Attempt to convert a native sequence. If the is_base_of check passes,
// the elements do not require converting and pointers do not reference a
// temporary object owned by the element caster. Pointers to converted
// types are not allowed because they would result a dangling reference
// when the element caster is destroyed.
if (convert && std::is_const<T>::value &&
// See comment for ephemeral_storage_type below.
!std::is_same<T, const bool>::value &&
(!std::is_pointer<T>::value ||
std::is_base_of<type_caster_generic, make_caster<T>>::value)) {
list_caster_.emplace();
Expand All @@ -512,12 +534,28 @@ struct type_caster<absl::Span<T>> {
}

private:
template <typename Caster>
// Unfortunately using std::vector as ephemeral_storage_type creates
// complications for std::vector<bool>
// (https://en.cppreference.com/w/cpp/container/vector_bool).
using ephemeral_storage_type = std::vector<value_type>;

template <
typename Caster, typename VT = value_type,
typename std::enable_if<!std::is_same<VT, bool>::value, int>::type = 0>
absl::Span<T> get_value(Caster& caster) {
return absl::MakeSpan(static_cast<std::vector<value_type>&>(caster));
return absl::MakeSpan(static_cast<ephemeral_storage_type&>(caster));
}

// This template specialization is needed to avoid compilation errors.
// The conditions in load() make this code unreachable.
template <
typename Caster, typename VT = value_type,
typename std::enable_if<std::is_same<VT, bool>::value, int>::type = 0>
absl::Span<T> get_value(Caster&) {
throw std::runtime_error("Expected to be unreachable.");
}

using ListCaster = list_caster<std::vector<value_type>, value_type>;
using ListCaster = list_caster<ephemeral_storage_type, value_type>;
absl::optional<ListCaster> list_caster_;
absl::Span<T> value_;
};
Expand Down
14 changes: 14 additions & 0 deletions pybind11_abseil/tests/absl_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,18 @@ std::string PassSpanPyObjectPtr(absl::Span<PyObject*> input_span) {
return result;
}

std::string PassSpanBool(absl::Span<bool> input_span) {
std::string result;
for (const auto& i : input_span) result += (i ? "t" : "f");
return result;
}

std::string PassSpanConstBool(absl::Span<const bool> input_span) {
std::string result;
for (const auto& i : input_span) result += (i ? "T" : "F");
return result;
}

struct ObjectForSpan {
explicit ObjectForSpan(int v) : value(v) {}
int value;
Expand Down Expand Up @@ -404,6 +416,8 @@ PYBIND11_MODULE(absl_example, m) {
m.def("sum_span_const_complex128",
&SumSpanComplex<const std::complex<double>>, arg("input_span"));
m.def("pass_span_pyobject_ptr", &PassSpanPyObjectPtr, arg("span"));
m.def("pass_span_bool", &PassSpanBool, arg("span"));
m.def("pass_span_const_bool", &PassSpanConstBool, arg("span"));

// Span of objects.
class_<ObjectForSpan>(m, "ObjectForSpan")
Expand Down
32 changes: 27 additions & 5 deletions pybind11_abseil/tests/absl_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def make_read_only_numpy_array():
return values


def make_srided_numpy_array(stride):
def make_strided_numpy_array(stride):
return np.zeros(10, dtype=np.int32)[::stride]


Expand Down Expand Up @@ -373,10 +373,10 @@ def test_fill_span_from_numpy(self):

@parameterized.named_parameters(
('float_numpy', np.zeros(5, dtype=float)),
('two_d_numpy', np.zeros(
(5, 5), dtype=np.int32)), ('read_only', make_read_only_numpy_array()),
('strided_skip', make_srided_numpy_array(2)),
('strided_reverse', make_srided_numpy_array(-1)),
('two_d_numpy', np.zeros((5, 5), dtype=np.int32)),
('read_only', make_read_only_numpy_array()),
('strided_skip', make_strided_numpy_array(2)),
('strided_reverse', make_strided_numpy_array(-1)),
('non_supported_type', np.zeros(5, dtype=np.unicode_)),
('native_list', [0] * 5))
def test_fill_span_fails_from(self, values):
Expand All @@ -397,6 +397,28 @@ def test_pass_span_pyobject_ptr(self):
arr = np.array([-3, 'four', 5.0], dtype=object)
self.assertEqual(absl_example.pass_span_pyobject_ptr(arr), '-3four5.0')

@parameterized.parameters(
([], ''),
([False], 'f'),
([True], 't'),
([False, True, True, False], 'fttf'),
)
def test_pass_span_bool(self, bools, expected):
arr = np.array(bools, dtype=bool)
s = absl_example.pass_span_bool(arr)
self.assertEqual(s, expected)

@parameterized.parameters(
([], ''),
([False], 'F'),
([True], 'T'),
([False, True, True, False], 'FTTF'),
)
def test_pass_span_const_bool(self, bools, expected):
arr = np.array(bools, dtype=bool)
s = absl_example.pass_span_const_bool(arr)
self.assertEqual(s, expected)


def make_native_list_of_objects():
return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)]
Expand Down

0 comments on commit 67491a4

Please sign in to comment.