Skip to content

Commit

Permalink
Update cuco-related code
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Feb 5, 2024
1 parent d16ff4f commit 2db4e84
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 127 deletions.
70 changes: 36 additions & 34 deletions cpp/src/prims/key_store.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,8 @@ namespace cugraph {

namespace detail {

using cuco_storage_type = cuco::storage<1>; ///< cuco window storage type

template <typename KeyIterator>
struct key_binary_search_contains_op_t {
using key_type = typename thrust::iterator_traits<KeyIterator>::value_type;
Expand Down Expand Up @@ -72,7 +74,7 @@ template <typename ViewType>
struct key_cuco_store_contains_device_view_t {
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type =
typename ViewType::cuco_store_type::ref_type<cuco::experimental::contains_tag>;
typename ViewType::cuco_store_type::ref_type<cuco::contains_tag>;

static_assert(!ViewType::binary_search);

Expand All @@ -88,9 +90,8 @@ struct key_cuco_store_contains_device_view_t {

template <typename ViewType>
struct key_cuco_store_insert_device_view_t {
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type =
typename ViewType::cuco_store_type::ref_type<cuco::experimental::insert_tag>;
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type = typename ViewType::cuco_store_type::ref_type<cuco::insert_tag>;

static_assert(!ViewType::binary_search);

Expand Down Expand Up @@ -147,14 +148,15 @@ class key_cuco_store_view_t {

static constexpr bool binary_search = false;

using cuco_store_type = cuco::experimental::static_set<
key_t,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;
using cuco_store_type =
cuco::static_set<key_t,
cuco::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>,
cuco_storage_type>;

key_cuco_store_view_t(cuco_store_type const* store) : cuco_store_(store) {}

Expand All @@ -167,12 +169,9 @@ class key_cuco_store_view_t {
cuco_store_->contains(key_first, key_last, value_first, stream);
}

auto cuco_store_contains_device_ref() const
{
return cuco_store_->ref(cuco::experimental::contains);
}
auto cuco_store_contains_device_ref() const { return cuco_store_->ref(cuco::contains); }

auto cuco_store_insert_device_ref() const { return cuco_store_->ref(cuco::experimental::insert); }
auto cuco_store_insert_device_ref() const { return cuco_store_->ref(cuco::insert); }

key_t invalid_key() const { return cuco_store_->get_empty_key_sentinel(); }

Expand Down Expand Up @@ -240,14 +239,15 @@ class key_cuco_store_t {
public:
using key_type = key_t;

using cuco_store_type = cuco::experimental::static_set<
key_t,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;
using cuco_store_type =
cuco::static_set<key_t,
cuco::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>,
cuco_storage_type>;

key_cuco_store_t(rmm::cuda_stream_view stream) {}

Expand Down Expand Up @@ -324,14 +324,16 @@ class key_cuco_store_t {

auto stream_adapter = rmm::mr::make_stream_allocator_adaptor(
rmm::mr::polymorphic_allocator<std::byte>(rmm::mr::get_current_device_resource()), stream);
cuco_store_ = std::make_unique<cuco_store_type>(
cuco_size,
cuco::sentinel::empty_key<key_t>{invalid_key},
thrust::equal_to<key_t>{},
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>{},
stream_adapter,
stream.value());
cuco_store_ =
std::make_unique<cuco_store_type>(cuco_size,
cuco::sentinel::empty_key<key_t>{invalid_key},
thrust::equal_to<key_t>{},
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>{},
cuco::thread_scope_device,
cuco_storage_type{},
stream_adapter,
stream.value());
}

std::unique_ptr<cuco_store_type> cuco_store_{nullptr};
Expand Down
Loading

0 comments on commit 2db4e84

Please sign in to comment.