Skip to content

Commit

Permalink
Merge branch 'branch-24.04' of https://github.com/rapidsai/cugraph in…
Browse files Browse the repository at this point in the history
…to fea_prim_edge_masking
  • Loading branch information
seunghwak committed Feb 20, 2024
2 parents c164b4c + f0388bc commit 1089a60
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 139 deletions.
81 changes: 41 additions & 40 deletions cpp/src/prims/key_store.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,8 @@ namespace cugraph {

namespace detail {

using cuco_storage_type = cuco::storage<1>; ///< cuco window storage type

template <typename KeyIterator>
struct key_binary_search_contains_op_t {
using key_type = typename thrust::iterator_traits<KeyIterator>::value_type;
Expand Down Expand Up @@ -70,9 +72,8 @@ struct key_binary_search_store_device_view_t {

template <typename ViewType>
struct key_cuco_store_contains_device_view_t {
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type =
typename ViewType::cuco_store_type::ref_type<cuco::experimental::contains_tag>;
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type = typename ViewType::cuco_set_type::ref_type<cuco::contains_tag>;

static_assert(!ViewType::binary_search);

Expand All @@ -88,9 +89,8 @@ struct key_cuco_store_contains_device_view_t {

template <typename ViewType>
struct key_cuco_store_insert_device_view_t {
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type =
typename ViewType::cuco_store_type::ref_type<cuco::experimental::insert_tag>;
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type = typename ViewType::cuco_set_type::ref_type<cuco::insert_tag>;

static_assert(!ViewType::binary_search);

Expand Down Expand Up @@ -147,16 +147,17 @@ class key_cuco_store_view_t {

static constexpr bool binary_search = false;

using cuco_store_type = cuco::experimental::static_set<
key_t,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;
using cuco_set_type =
cuco::static_set<key_t,
cuco::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>,
cuco_storage_type>;

key_cuco_store_view_t(cuco_store_type const* store) : cuco_store_(store) {}
key_cuco_store_view_t(cuco_set_type const* store) : cuco_store_(store) {}

template <typename QueryKeyIterator, typename ResultValueIterator>
void contains(QueryKeyIterator key_first,
Expand All @@ -167,17 +168,14 @@ class key_cuco_store_view_t {
cuco_store_->contains(key_first, key_last, value_first, stream);
}

auto cuco_store_contains_device_ref() const
{
return cuco_store_->ref(cuco::experimental::contains);
}
auto cuco_store_contains_device_ref() const { return cuco_store_->ref(cuco::contains); }

auto cuco_store_insert_device_ref() const { return cuco_store_->ref(cuco::experimental::insert); }
auto cuco_store_insert_device_ref() const { return cuco_store_->ref(cuco::insert); }

key_t invalid_key() const { return cuco_store_->get_empty_key_sentinel(); }

private:
cuco_store_type const* cuco_store_{};
cuco_set_type const* cuco_store_{};
};

template <typename key_t>
Expand Down Expand Up @@ -240,14 +238,15 @@ class key_cuco_store_t {
public:
using key_type = key_t;

using cuco_store_type = cuco::experimental::static_set<
key_t,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;
using cuco_set_type =
cuco::static_set<key_t,
cuco::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>,
cuco_storage_type>;

key_cuco_store_t(rmm::cuda_stream_view stream) {}

Expand Down Expand Up @@ -306,7 +305,7 @@ class key_cuco_store_t {
return keys;
}

cuco_store_type const* cuco_store_ptr() const { return cuco_store_.get(); }
cuco_set_type const* cuco_store_ptr() const { return cuco_store_.get(); }

key_t invalid_key() const { return cuco_store_->empty_key_sentinel(); }

Expand All @@ -324,17 +323,19 @@ class key_cuco_store_t {

auto stream_adapter = rmm::mr::make_stream_allocator_adaptor(
rmm::mr::polymorphic_allocator<std::byte>(rmm::mr::get_current_device_resource()), stream);
cuco_store_ = std::make_unique<cuco_store_type>(
cuco_size,
cuco::sentinel::empty_key<key_t>{invalid_key},
thrust::equal_to<key_t>{},
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>{},
stream_adapter,
stream.value());
cuco_store_ =
std::make_unique<cuco_set_type>(cuco_size,
cuco::sentinel::empty_key<key_t>{invalid_key},
thrust::equal_to<key_t>{},
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>{},
cuco::thread_scope_device,
cuco_storage_type{},
stream_adapter,
stream.value());
}

std::unique_ptr<cuco_store_type> cuco_store_{nullptr};
std::unique_ptr<cuco_set_type> cuco_store_{nullptr};

size_t capacity_{0};
size_t size_{0}; // caching as cuco_store_->size() is expensive (this scans the entire slots to
Expand Down
Loading

0 comments on commit 1089a60

Please sign in to comment.