Skip to content

Commit

Permalink
Add static asserts for mdspan_copyable
Browse files Browse the repository at this point in the history
  • Loading branch information
wphicks committed Oct 2, 2023
1 parent bd5a8f8 commit a8b17a8
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 15 deletions.
56 changes: 41 additions & 15 deletions cpp/test/core/mdspan_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ TEST(MDSpanCopy, Mdspan1DHostHost)
}

auto out_right = make_host_vector<double, std::uint32_t, layout_f_contiguous>(res, cols);
// std::copy
static_assert(detail::mdspan_copyable<true,
decltype(out_right.view()),
decltype(in_left.view())>::can_use_std_copy,
"Current implementation should use std::copy for this copy");
copy(res, out_right.view(), in_left.view());
for (auto i = std::uint32_t{}; i < cols; ++i) {
ASSERT_TRUE(match(out_right(i), double(gen_unique_entry(i)), CompareApprox<double>{0.0001}));
Expand All @@ -57,8 +60,11 @@ TEST(MDSpanCopy, Mdspan1DHostDevice)
in_left(i) = gen_unique_entry(i);
}

// raft::copy
auto out_right = make_device_vector<float, std::uint32_t, layout_f_contiguous>(res, cols);
static_assert(detail::mdspan_copyable<true,
decltype(out_right.view()),
decltype(in_left.view())>::can_use_raft_copy,
"Current implementation should use raft::copy for this copy");
copy(res, out_right.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < cols; ++i) {
Expand All @@ -78,8 +84,11 @@ TEST(MDSpanCopy, Mdspan1DDeviceHost)
in_left(i) = gen_unique_entry(i);
}

// raft::copy
auto out_right = make_host_vector<float, std::uint32_t, layout_f_contiguous>(res, cols);
static_assert(detail::mdspan_copyable<true,
decltype(out_right.view()),
decltype(in_left.view())>::can_use_raft_copy,
"Current implementation should use raft::copy for this copy");
copy(res, out_right.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < cols; ++i) {
Expand All @@ -95,9 +104,9 @@ TEST(MDSpanCopy, Mdspan3DHostHost)
auto constexpr depth = std::uint32_t{500};
auto constexpr rows = std::uint32_t{300};
auto constexpr cols = std::uint32_t{200};
auto in_left = make_host_mdarray<float, std::uint32_t, layout_c_contiguous, depth, rows, cols>(
auto in_left = make_host_mdarray<float, std::uint32_t, layout_f_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});
auto in_right = make_host_mdarray<float, std::uint32_t, layout_f_contiguous, depth, rows, cols>(
auto in_right = make_host_mdarray<float, std::uint32_t, layout_c_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});
auto gen_unique_entry = [](auto&& x, auto&& y, auto&& z) { return x * 7 + y * 11 + z * 13; };

Expand All @@ -112,10 +121,13 @@ TEST(MDSpanCopy, Mdspan3DHostHost)

auto out_left = make_host_mdarray<double, std::uint32_t, layout_f_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});
auto out_right = make_host_mdarray<double, std::uint32_t, layout_f_contiguous, depth, rows, cols>(
auto out_right = make_host_mdarray<double, std::uint32_t, layout_c_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});

// std::copy
static_assert(detail::mdspan_copyable<true,
decltype(out_right.view()),
decltype(in_right.view())>::can_use_std_copy,
"Current implementation should use std::copy for this copy");
copy(res, out_right.view(), in_right.view());
for (auto i = std::uint32_t{}; i < depth; ++i) {
for (auto j = std::uint32_t{}; j < rows; ++j) {
Expand All @@ -126,7 +138,6 @@ TEST(MDSpanCopy, Mdspan3DHostHost)
}
}

// simd or custom logic
copy(res, out_right.view(), in_left.view());
for (auto i = std::uint32_t{}; i < depth; ++i) {
for (auto j = std::uint32_t{}; j < rows; ++j) {
Expand All @@ -137,7 +148,6 @@ TEST(MDSpanCopy, Mdspan3DHostHost)
}
}

// simd or custom logic
copy(res, out_left.view(), in_right.view());
for (auto i = std::uint32_t{}; i < depth; ++i) {
for (auto j = std::uint32_t{}; j < rows; ++j) {
Expand All @@ -148,7 +158,9 @@ TEST(MDSpanCopy, Mdspan3DHostHost)
}
}

// std::copy
static_assert(detail::mdspan_copyable<true, decltype(out_left.view()), decltype(in_left.view())>::
can_use_std_copy,
"Current implementation should use std::copy for this copy");
copy(res, out_left.view(), in_left.view());
for (auto i = std::uint32_t{}; i < depth; ++i) {
for (auto j = std::uint32_t{}; j < rows; ++j) {
Expand Down Expand Up @@ -190,7 +202,10 @@ TEST(MDSpanCopy, Mdspan3DHostDevice)
make_device_mdarray<float, std::uint32_t, layout_f_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});

// raft::copy
static_assert(detail::mdspan_copyable<true,
decltype(out_right.view()),
decltype(in_right.view())>::can_use_raft_copy,
"Current implementation should use raft::copy for this copy");
copy(res, out_right.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand All @@ -203,7 +218,9 @@ TEST(MDSpanCopy, Mdspan3DHostDevice)
}
}

// raft::copy
static_assert(detail::mdspan_copyable<true, decltype(out_left.view()), decltype(in_left.view())>::
can_use_raft_copy,
"Current implementation should use raft::copy for this copy");
copy(res, out_left.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand Down Expand Up @@ -240,7 +257,10 @@ TEST(MDSpanCopy, Mdspan2DDeviceDevice)
auto out_right = make_device_mdarray<float, std::uint32_t, layout_f_contiguous, rows, cols>(
res, extents<std::uint32_t, rows, cols>{});

// raft::copy
static_assert(detail::mdspan_copyable<true,
decltype(out_right.view()),
decltype(in_right.view())>::can_use_raft_copy,
"Current implementation should use raft::copy for this copy");
copy(res, out_right.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand All @@ -250,7 +270,10 @@ TEST(MDSpanCopy, Mdspan2DDeviceDevice)
}
}

// cublas
static_assert(detail::mdspan_copyable<true,
decltype(out_right.view()),
decltype(in_left.view())>::can_use_cublas,
"Current implementation should use cuBLAS for this copy");
copy(res, out_right.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand All @@ -260,7 +283,10 @@ TEST(MDSpanCopy, Mdspan2DDeviceDevice)
}
}

// cublas
static_assert(detail::mdspan_copyable<true,
decltype(out_left.view()),
decltype(in_right.view())>::can_use_cublas,
"Current implementation should use cuBLAS for this copy");
copy(res, out_left.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand Down
54 changes: 54 additions & 0 deletions cpp/test/core/mdspan_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ TEST(MDSpanCopy, Mdspan3DDeviceDeviceCuda)
auto out_long =
make_device_mdarray<std::int64_t, std::uint32_t, layout_c_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_long.view()), decltype(in_left.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_long.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand All @@ -66,6 +69,9 @@ TEST(MDSpanCopy, Mdspan3DDeviceDeviceCuda)
auto out_right = make_device_mdarray<int, std::uint32_t, layout_f_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});

static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_right.view()), decltype(in_left.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_right.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand All @@ -76,6 +82,9 @@ TEST(MDSpanCopy, Mdspan3DDeviceDeviceCuda)
}
}

static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_left.view()), decltype(in_right.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_left.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand Down Expand Up @@ -113,6 +122,9 @@ TEST(MDSpanCopy, Mdspan2DDeviceDeviceCuda)
res.sync_stream();

// Test dtype conversion without transpose
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_right.view()), decltype(in_right.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_right.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand All @@ -123,6 +135,9 @@ TEST(MDSpanCopy, Mdspan2DDeviceDeviceCuda)
}

// Test dtype conversion with transpose
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_right.view()), decltype(in_left.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_right.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand All @@ -131,6 +146,9 @@ TEST(MDSpanCopy, Mdspan2DDeviceDeviceCuda)
double(out_right(i, j)), double(gen_unique_entry(i, j)), CompareApprox<double>{0.0001}));
}
}
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_left.view()), decltype(in_right.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_left.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand Down Expand Up @@ -167,6 +185,9 @@ TEST(MDSpanCopy, Mdspan3DDeviceHostCuda)
auto out_long =
make_host_mdarray<std::int64_t, std::uint32_t, layout_c_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_long.view()), decltype(in_left.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_long.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand All @@ -183,6 +204,9 @@ TEST(MDSpanCopy, Mdspan3DDeviceHostCuda)
auto out_right = make_host_mdarray<int, std::uint32_t, layout_f_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});

static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_right.view()), decltype(in_left.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_right.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand All @@ -193,6 +217,9 @@ TEST(MDSpanCopy, Mdspan3DDeviceHostCuda)
}
}

static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_left.view()), decltype(in_right.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_left.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand Down Expand Up @@ -230,6 +257,9 @@ TEST(MDSpanCopy, Mdspan2DDeviceHostCuda)
res.sync_stream();

// Test dtype conversion without transpose
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_right.view()), decltype(in_right.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_right.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand All @@ -240,6 +270,9 @@ TEST(MDSpanCopy, Mdspan2DDeviceHostCuda)
}

// Test dtype conversion with transpose
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_right.view()), decltype(in_left.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_right.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand All @@ -248,6 +281,9 @@ TEST(MDSpanCopy, Mdspan2DDeviceHostCuda)
double(out_right(i, j)), double(gen_unique_entry(i, j)), CompareApprox<double>{0.0001}));
}
}
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_left.view()), decltype(in_right.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_left.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand Down Expand Up @@ -285,6 +321,9 @@ TEST(MDSpanCopy, Mdspan3DHostDeviceCuda)
auto out_long =
make_device_mdarray<std::int64_t, std::uint32_t, layout_c_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_long.view()), decltype(in_left.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_long.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand All @@ -301,6 +340,9 @@ TEST(MDSpanCopy, Mdspan3DHostDeviceCuda)
auto out_right = make_device_mdarray<int, std::uint32_t, layout_f_contiguous, depth, rows, cols>(
res, extents<std::uint32_t, depth, rows, cols>{});

static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_right.view()), decltype(in_left.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_right.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand All @@ -311,6 +353,9 @@ TEST(MDSpanCopy, Mdspan3DHostDeviceCuda)
}
}

static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_left.view()), decltype(in_right.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_left.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < depth; ++i) {
Expand Down Expand Up @@ -348,6 +393,9 @@ TEST(MDSpanCopy, Mdspan2DHostDeviceCuda)
res.sync_stream();

// Test dtype conversion without transpose
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_right.view()), decltype(in_right.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_right.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand All @@ -358,6 +406,9 @@ TEST(MDSpanCopy, Mdspan2DHostDeviceCuda)
}

// Test dtype conversion with transpose
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_right.view()), decltype(in_left.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_right.view(), in_left.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand All @@ -366,6 +417,9 @@ TEST(MDSpanCopy, Mdspan2DHostDeviceCuda)
double(out_right(i, j)), double(gen_unique_entry(i, j)), CompareApprox<double>{0.0001}));
}
}
static_assert(
detail::mdspan_copyable_with_kernel_v<decltype(out_left.view()), decltype(in_right.view())>,
"Current implementation should use kernel for this copy");
copy(res, out_left.view(), in_right.view());
res.sync_stream();
for (auto i = std::uint32_t{}; i < rows; ++i) {
Expand Down

0 comments on commit a8b17a8

Please sign in to comment.