diff --git a/pybind11_abseil/absl_casters.h b/pybind11_abseil/absl_casters.h index 3449afc..fa823c1 100644 --- a/pybind11_abseil/absl_casters.h +++ b/pybind11_abseil/absl_casters.h @@ -387,7 +387,7 @@ namespace internal { template static constexpr bool is_buffer_interface_compatible_type = detail::is_same_ignoring_cvref::value || - std::is_arithmetic::value || + std::is_arithmetic>::value || std::is_same>::value || std::is_same>::value; @@ -405,7 +405,8 @@ std::tuple> LoadSpanFromBuffer(handle src) { if (PyObject_GetBuffer(src.ptr(), &view, flags) == 0) { auto cleanup = absl::MakeCleanup([&view] { PyBuffer_Release(&view); }); if (view.ndim == 1 && view.strides[0] == sizeof(T) && - buffer_info(&view, /*ownview=*/false).item_type_is_equivalent_to()) { + buffer_info(&view, /*ownview=*/false) + .item_type_is_equivalent_to>()) { return {true, absl::MakeSpan(static_cast(view.buf), view.shape[0])}; } } else { @@ -421,6 +422,29 @@ constexpr std::tuple> LoadSpanFromBuffer(handle /*src*/) { return {false, absl::Span()}; } +template , bool>::value, int>::type = 0> +std::tuple> LoadSpanOpaqueVector(handle src) { + // Attempt to unwrap an opaque std::vector. + using value_type = std::remove_cv_t; + type_caster_base> caster; + if (caster.load(src, false)) { + return {true, + absl::MakeSpan(static_cast&>(caster))}; + } + return {false, absl::Span()}; +} + +template , bool>::value, int>::type = 0> +std::tuple> LoadSpanOpaqueVector(handle src) { + // std::vector is special and cannot directly be converted to a Span + // (see https://en.cppreference.com/w/cpp/container/vector_bool). + return {false, absl::Span()}; +} + // Helper to determine whether T is a span. template struct is_absl_span : std::false_type {}; @@ -433,7 +457,7 @@ template struct type_caster> { public: // The type referenced by the span, with const removed. - using value_type = typename std::remove_cv::type; + using value_type = std::remove_cv_t; static_assert(!is_absl_span::value, "Nested absl spans are not supported."); @@ -479,19 +503,17 @@ struct type_caster> { std::tie(loaded, value_) = LoadSpanFromBuffer(src); if (loaded) return true; - // Attempt to unwrap an opaque std::vector. - type_caster_base> caster; - if (caster.load(src, false)) { - value_ = get_value(caster); - return true; - } + std::tie(loaded, value_) = LoadSpanOpaqueVector(src); + if (loaded) return true; - // Attempt to convert a native sequence. If the is_base_of_v check passes, + // Attempt to convert a native sequence. If the is_base_of check passes, // the elements do not require converting and pointers do not reference a // temporary object owned by the element caster. Pointers to converted // types are not allowed because they would result a dangling reference // when the element caster is destroyed. if (convert && std::is_const::value && + // See comment for ephemeral_storage_type below. + !std::is_same::value && (!std::is_pointer::value || std::is_base_of>::value)) { list_caster_.emplace(); @@ -512,12 +534,28 @@ struct type_caster> { } private: - template + // Unfortunately using std::vector as ephemeral_storage_type creates + // complications for std::vector + // (https://en.cppreference.com/w/cpp/container/vector_bool). + using ephemeral_storage_type = std::vector; + + template < + typename Caster, typename VT = value_type, + typename std::enable_if::value, int>::type = 0> absl::Span get_value(Caster& caster) { - return absl::MakeSpan(static_cast&>(caster)); + return absl::MakeSpan(static_cast(caster)); + } + + // This template specialization is needed to avoid compilation errors. + // The conditions in load() make this code unreachable. + template < + typename Caster, typename VT = value_type, + typename std::enable_if::value, int>::type = 0> + absl::Span get_value(Caster&) { + throw std::runtime_error("Expected to be unreachable."); } - using ListCaster = list_caster, value_type>; + using ListCaster = list_caster; absl::optional list_caster_; absl::Span value_; }; diff --git a/pybind11_abseil/tests/absl_example.cc b/pybind11_abseil/tests/absl_example.cc index 6ba376d..91fdec4 100644 --- a/pybind11_abseil/tests/absl_example.cc +++ b/pybind11_abseil/tests/absl_example.cc @@ -279,6 +279,18 @@ std::string PassSpanPyObjectPtr(absl::Span input_span) { return result; } +std::string PassSpanBool(absl::Span input_span) { + std::string result; + for (const auto& i : input_span) result += (i ? "t" : "f"); + return result; +} + +std::string PassSpanConstBool(absl::Span input_span) { + std::string result; + for (const auto& i : input_span) result += (i ? "T" : "F"); + return result; +} + struct ObjectForSpan { explicit ObjectForSpan(int v) : value(v) {} int value; @@ -404,6 +416,8 @@ PYBIND11_MODULE(absl_example, m) { m.def("sum_span_const_complex128", &SumSpanComplex>, arg("input_span")); m.def("pass_span_pyobject_ptr", &PassSpanPyObjectPtr, arg("span")); + m.def("pass_span_bool", &PassSpanBool, arg("span")); + m.def("pass_span_const_bool", &PassSpanConstBool, arg("span")); // Span of objects. class_(m, "ObjectForSpan") diff --git a/pybind11_abseil/tests/absl_test.py b/pybind11_abseil/tests/absl_test.py index 085c044..555ad87 100644 --- a/pybind11_abseil/tests/absl_test.py +++ b/pybind11_abseil/tests/absl_test.py @@ -312,7 +312,7 @@ def make_read_only_numpy_array(): return values -def make_srided_numpy_array(stride): +def make_strided_numpy_array(stride): return np.zeros(10, dtype=np.int32)[::stride] @@ -373,10 +373,10 @@ def test_fill_span_from_numpy(self): @parameterized.named_parameters( ('float_numpy', np.zeros(5, dtype=float)), - ('two_d_numpy', np.zeros( - (5, 5), dtype=np.int32)), ('read_only', make_read_only_numpy_array()), - ('strided_skip', make_srided_numpy_array(2)), - ('strided_reverse', make_srided_numpy_array(-1)), + ('two_d_numpy', np.zeros((5, 5), dtype=np.int32)), + ('read_only', make_read_only_numpy_array()), + ('strided_skip', make_strided_numpy_array(2)), + ('strided_reverse', make_strided_numpy_array(-1)), ('non_supported_type', np.zeros(5, dtype=np.unicode_)), ('native_list', [0] * 5)) def test_fill_span_fails_from(self, values): @@ -397,6 +397,28 @@ def test_pass_span_pyobject_ptr(self): arr = np.array([-3, 'four', 5.0], dtype=object) self.assertEqual(absl_example.pass_span_pyobject_ptr(arr), '-3four5.0') + @parameterized.parameters( + ([], ''), + ([False], 'f'), + ([True], 't'), + ([False, True, True, False], 'fttf'), + ) + def test_pass_span_bool(self, bools, expected): + arr = np.array(bools, dtype=bool) + s = absl_example.pass_span_bool(arr) + self.assertEqual(s, expected) + + @parameterized.parameters( + ([], ''), + ([False], 'F'), + ([True], 'T'), + ([False, True, True, False], 'FTTF'), + ) + def test_pass_span_const_bool(self, bools, expected): + arr = np.array(bools, dtype=bool) + s = absl_example.pass_span_const_bool(arr) + self.assertEqual(s, expected) + def make_native_list_of_objects(): return [absl_example.ObjectForSpan(3), absl_example.ObjectForSpan(5)]