diff --git a/cpp/src/prims/key_store.cuh b/cpp/src/prims/key_store.cuh index 7498d25301f..907ca36ef4a 100644 --- a/cpp/src/prims/key_store.cuh +++ b/cpp/src/prims/key_store.cuh @@ -72,9 +72,8 @@ struct key_binary_search_store_device_view_t { template struct key_cuco_store_contains_device_view_t { - using key_type = typename ViewType::key_type; - using cuco_store_device_ref_type = - typename ViewType::cuco_store_type::ref_type; + using key_type = typename ViewType::key_type; + using cuco_store_device_ref_type = typename ViewType::cuco_set_type::ref_type; static_assert(!ViewType::binary_search); @@ -91,7 +90,7 @@ struct key_cuco_store_contains_device_view_t { template struct key_cuco_store_insert_device_view_t { using key_type = typename ViewType::key_type; - using cuco_store_device_ref_type = typename ViewType::cuco_store_type::ref_type; + using cuco_store_device_ref_type = typename ViewType::cuco_set_type::ref_type; static_assert(!ViewType::binary_search); @@ -148,7 +147,7 @@ class key_cuco_store_view_t { static constexpr bool binary_search = false; - using cuco_store_type = + using cuco_set_type = cuco::static_set, cuda::thread_scope_device, @@ -158,7 +157,7 @@ class key_cuco_store_view_t { rmm::mr::stream_allocator_adaptor>, cuco_storage_type>; - key_cuco_store_view_t(cuco_store_type const* store) : cuco_store_(store) {} + key_cuco_store_view_t(cuco_set_type const* store) : cuco_store_(store) {} template void contains(QueryKeyIterator key_first, @@ -176,7 +175,7 @@ class key_cuco_store_view_t { key_t invalid_key() const { return cuco_store_->get_empty_key_sentinel(); } private: - cuco_store_type const* cuco_store_{}; + cuco_set_type const* cuco_store_{}; }; template @@ -239,7 +238,7 @@ class key_cuco_store_t { public: using key_type = key_t; - using cuco_store_type = + using cuco_set_type = cuco::static_set, cuda::thread_scope_device, @@ -306,7 +305,7 @@ class key_cuco_store_t { return keys; } - cuco_store_type const* cuco_store_ptr() const { return cuco_store_.get(); } + cuco_set_type const* cuco_store_ptr() const { return cuco_store_.get(); } key_t invalid_key() const { return cuco_store_->empty_key_sentinel(); } @@ -325,18 +324,18 @@ class key_cuco_store_t { auto stream_adapter = rmm::mr::make_stream_allocator_adaptor( rmm::mr::polymorphic_allocator(rmm::mr::get_current_device_resource()), stream); cuco_store_ = - std::make_unique(cuco_size, - cuco::sentinel::empty_key{invalid_key}, - thrust::equal_to{}, - cuco::linear_probing<1, // CG size - cuco::murmurhash3_32>{}, - cuco::thread_scope_device, - cuco_storage_type{}, - stream_adapter, - stream.value()); + std::make_unique(cuco_size, + cuco::sentinel::empty_key{invalid_key}, + thrust::equal_to{}, + cuco::linear_probing<1, // CG size + cuco::murmurhash3_32>{}, + cuco::thread_scope_device, + cuco_storage_type{}, + stream_adapter, + stream.value()); } - std::unique_ptr cuco_store_{nullptr}; + std::unique_ptr cuco_store_{nullptr}; size_t capacity_{0}; size_t size_{0}; // caching as cuco_store_->size() is expensive (this scans the entire slots to diff --git a/cpp/src/prims/kv_store.cuh b/cpp/src/prims/kv_store.cuh index 734f10d3788..be4fde2fbff 100644 --- a/cpp/src/prims/kv_store.cuh +++ b/cpp/src/prims/kv_store.cuh @@ -89,7 +89,7 @@ struct kv_binary_search_contains_op_t { template struct kv_cuco_insert_and_increment_t { using key_type = typename thrust::iterator_traits::value_type; - using cuco_store_type = + using cuco_set_type = cuco::static_map, @@ -100,7 +100,7 @@ struct kv_cuco_insert_and_increment_t { rmm::mr::stream_allocator_adaptor>, cuco_storage_type>; - typename cuco_store_type::ref_type device_ref{}; + typename cuco_set_type::ref_type device_ref{}; KeyIterator key_first{}; size_t* counter{nullptr}; size_t invalid_idx{}; @@ -112,7 +112,7 @@ struct kv_cuco_insert_and_increment_t { if (inserted) { cuda::atomic_ref atomic_counter(*counter); auto idx = atomic_counter.fetch_add(size_t{1}, cuda::std::memory_order_relaxed); - using ref_type = typename cuco_store_type::ref_type; + using ref_type = typename cuco_set_type::ref_type; cuda::atomic_ref ref( (*iter).second); ref.store(idx, cuda::std::memory_order_relaxed); @@ -126,7 +126,7 @@ struct kv_cuco_insert_and_increment_t { template struct kv_cuco_insert_if_and_increment_t { using key_type = typename thrust::iterator_traits::value_type; - using cuco_store_type = + using cuco_set_type = cuco::static_map, @@ -137,7 +137,7 @@ struct kv_cuco_insert_if_and_increment_t { rmm::mr::stream_allocator_adaptor>, cuco_storage_type>; - typename cuco_store_type::ref_type device_ref{}; + typename cuco_set_type::ref_type device_ref{}; KeyIterator key_first{}; StencilIterator stencil_first{}; PredOp pred_op{}; @@ -153,7 +153,7 @@ struct kv_cuco_insert_if_and_increment_t { if (inserted) { cuda::atomic_ref atomic_counter(*counter); auto idx = atomic_counter.fetch_add(size_t{1}, cuda::std::memory_order_relaxed); - using ref_type = typename cuco_store_type::ref_type; + using ref_type = typename cuco_set_type::ref_type; cuda::atomic_ref ref( (*iter).second); ref.store(idx, cuda::std::memory_order_relaxed); @@ -166,7 +166,7 @@ struct kv_cuco_insert_if_and_increment_t { template struct kv_cuco_insert_and_assign_t { - using cuco_store_type = + using cuco_set_type = cuco::static_map, value_t, size_t>, cuco::extent, @@ -177,13 +177,13 @@ struct kv_cuco_insert_and_assign_t { rmm::mr::stream_allocator_adaptor>, cuco_storage_type>; - typename cuco_store_type::ref_type device_ref{}; + typename cuco_set_type::ref_type device_ref{}; __device__ void operator()(thrust::tuple pair) { auto [iter, inserted] = device_ref.insert_and_find(pair); if (!inserted) { - using ref_type = typename cuco_store_type::ref_type; + using ref_type = typename cuco_set_type::ref_type; cuda::atomic_ref ref( (*iter).second); ref.store(thrust::get<1>(pair), cuda::std::memory_order_relaxed); @@ -227,7 +227,7 @@ template struct kv_cuco_store_find_device_view_t { using key_type = typename ViewType::key_type; using value_type = typename ViewType::value_type; - using cuco_store_device_ref_type = typename ViewType::cuco_store_type::ref_type; + using cuco_store_device_ref_type = typename ViewType::cuco_set_type::ref_type; static_assert(!ViewType::binary_search); @@ -340,7 +340,7 @@ class kv_cuco_store_view_t { static constexpr bool binary_search = false; - using cuco_store_type = + using cuco_set_type = cuco::static_map, value_type, size_t>, cuco::extent, @@ -352,14 +352,14 @@ class kv_cuco_store_view_t { cuco_storage_type>; template - kv_cuco_store_view_t(cuco_store_type const* store, + kv_cuco_store_view_t(cuco_set_type const* store, std::enable_if_t, int32_t> = 0) : cuco_store_(store) { } template - kv_cuco_store_view_t(cuco_store_type const* store, + kv_cuco_store_view_t(cuco_set_type const* store, ValueIterator value_first, type invalid_value, std::enable_if_t, int32_t> = 0) @@ -417,7 +417,7 @@ class kv_cuco_store_view_t { } private: - cuco_store_type const* cuco_store_{}; + cuco_set_type const* cuco_store_{}; std::conditional_t, ValueIterator, std::byte /* dummy */> store_value_first_{}; @@ -536,7 +536,7 @@ class kv_cuco_store_t { std::invoke_result_t), value_buffer_type&>; - using cuco_store_type = + using cuco_set_type = cuco::static_map, value_t, size_t>, cuco::extent, @@ -788,7 +788,7 @@ class kv_cuco_store_t { return std::make_tuple(std::move(retrieved_keys), std::move(retrieved_values)); } - cuco_store_type const* cuco_store_ptr() const { return cuco_store_.get(); } + cuco_set_type const* cuco_store_ptr() const { return cuco_store_.get(); } template std::enable_if_t, const_value_iterator> store_value_first() const @@ -827,18 +827,18 @@ class kv_cuco_store_t { rmm::mr::polymorphic_allocator(rmm::mr::get_current_device_resource()), stream); if constexpr (std::is_arithmetic_v) { cuco_store_ = - std::make_unique(cuco_size, - cuco::sentinel::empty_key{invalid_key}, - cuco::sentinel::empty_value{invalid_value}, - thrust::equal_to{}, - cuco::linear_probing<1, // CG size - cuco::murmurhash3_32>{}, - cuco::thread_scope_device, - cuco_storage_type{}, - stream_adapter, - stream.value()); + std::make_unique(cuco_size, + cuco::sentinel::empty_key{invalid_key}, + cuco::sentinel::empty_value{invalid_value}, + thrust::equal_to{}, + cuco::linear_probing<1, // CG size + cuco::murmurhash3_32>{}, + cuco::thread_scope_device, + cuco_storage_type{}, + stream_adapter, + stream.value()); } else { - cuco_store_ = std::make_unique( + cuco_store_ = std::make_unique( cuco_size, cuco::sentinel::empty_key{invalid_key}, cuco::sentinel::empty_value{std::numeric_limits::max()}, @@ -852,7 +852,7 @@ class kv_cuco_store_t { } } - std::unique_ptr cuco_store_{nullptr}; + std::unique_ptr cuco_store_{nullptr}; std::conditional_t, decltype(allocate_dataframe_buffer(0, rmm::cuda_stream_view{})), std::byte /* dummy */>