Skip to content

Commit

Permalink
bounce buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
vuule committed Sep 28, 2024
1 parent 4030a89 commit 243e12e
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions cpp/include/cudf/detail/device_scalar.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,33 +43,49 @@ class device_scalar : public rmm::device_scalar<T> {
explicit device_scalar(
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref())
: rmm::device_scalar<T>(stream, mr)
: rmm::device_scalar<T>(stream, mr), bounce_buffer{make_host_vector<T>(1, stream)}
{
}

explicit device_scalar(
T const& initial_value,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref())
: rmm::device_scalar<T>(stream, mr)
: rmm::device_scalar<T>(stream, mr), bounce_buffer{make_host_vector<T>(1, stream)}
{
auto bounce_buffer = make_host_vector<T>(1, stream);
bounce_buffer[0] = initial_value;
// TODO replace with to_device
cuda_memcpy<T>(device_span<T>{this->data(), 1}, bounce_buffer, stream);
bounce_buffer[0] = initial_value;
cuda_memcpy_async<T>(device_span<T>{this->data(), 1}, bounce_buffer, stream);
}

device_scalar(device_scalar const& other,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr = cudf::get_current_device_resource_ref())
: rmm::device_scalar<T>(other, stream, mr)
: rmm::device_scalar<T>(other, stream, mr), bounce_buffer{make_host_vector<T>(1, stream)}
{
}

[[nodiscard]] T value(rmm::cuda_stream_view stream) const
{
return make_host_vector_sync(device_span<T const>{this->data(), 1}, stream)[0];
cuda_memcpy<T>(bounce_buffer, device_span<T const>{this->data(), 1}, stream);
return bounce_buffer[0];
}

void set_value_async(T const& value, rmm::cuda_stream_view stream)
{
bounce_buffer[0] = value;
cuda_memcpy_async<T>(device_span<T>{this->data(), 1}, bounce_buffer, stream);
}

void set_value_async(T&& value, rmm::cuda_stream_view stream)
{
bounce_buffer[0] = std::move(value);
cuda_memcpy_async<T>(device_span<T>{this->data(), 1}, bounce_buffer, stream);
}

void set_value_to_zero_async(rmm::cuda_stream_view stream) { set_value_async(T{}, stream); }

private:
mutable cudf::detail::host_vector<T> bounce_buffer;
};

} // namespace detail
Expand Down

0 comments on commit 243e12e

Please sign in to comment.