Skip to content

Commit

Permalink
Fix merge conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
PointKernel committed Jan 24, 2024
1 parent 8a74bd3 commit 8f5e165
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 127 deletions.
64 changes: 30 additions & 34 deletions cpp/src/prims/key_store.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -72,7 +72,7 @@ template <typename ViewType>
struct key_cuco_store_contains_device_view_t {
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type =
typename ViewType::cuco_store_type::ref_type<cuco::experimental::contains_tag>;
typename ViewType::cuco_store_type::ref_type<cuco::contains_tag>;

static_assert(!ViewType::binary_search);

Expand All @@ -88,9 +88,8 @@ struct key_cuco_store_contains_device_view_t {

template <typename ViewType>
struct key_cuco_store_insert_device_view_t {
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type =
typename ViewType::cuco_store_type::ref_type<cuco::experimental::insert_tag>;
using key_type = typename ViewType::key_type;
using cuco_store_device_ref_type = typename ViewType::cuco_store_type::ref_type<cuco::insert_tag>;

static_assert(!ViewType::binary_search);

Expand Down Expand Up @@ -147,14 +146,14 @@ class key_cuco_store_view_t {

static constexpr bool binary_search = false;

using cuco_store_type = cuco::experimental::static_set<
key_t,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;
using cuco_store_type =
cuco::static_set<key_t,
cuco::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;

key_cuco_store_view_t(cuco_store_type const* store) : cuco_store_(store) {}

Expand All @@ -167,12 +166,9 @@ class key_cuco_store_view_t {
cuco_store_->contains(key_first, key_last, value_first, stream);
}

auto cuco_store_contains_device_ref() const
{
return cuco_store_->ref(cuco::experimental::contains);
}
auto cuco_store_contains_device_ref() const { return cuco_store_->ref(cuco::contains); }

auto cuco_store_insert_device_ref() const { return cuco_store_->ref(cuco::experimental::insert); }
auto cuco_store_insert_device_ref() const { return cuco_store_->ref(cuco::insert); }

key_t invalid_key() const { return cuco_store_->get_empty_key_sentinel(); }

Expand Down Expand Up @@ -240,14 +236,14 @@ class key_cuco_store_t {
public:
using key_type = key_t;

using cuco_store_type = cuco::experimental::static_set<
key_t,
cuco::experimental::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;
using cuco_store_type =
cuco::static_set<key_t,
cuco::extent<std::size_t>,
cuda::thread_scope_device,
thrust::equal_to<key_t>,
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>,
rmm::mr::stream_allocator_adaptor<rmm::mr::polymorphic_allocator<std::byte>>>;

key_cuco_store_t(rmm::cuda_stream_view stream) {}

Expand Down Expand Up @@ -324,14 +320,14 @@ class key_cuco_store_t {

auto stream_adapter = rmm::mr::make_stream_allocator_adaptor(
rmm::mr::polymorphic_allocator<std::byte>(rmm::mr::get_current_device_resource()), stream);
cuco_store_ = std::make_unique<cuco_store_type>(
cuco_size,
cuco::sentinel::empty_key<key_t>{invalid_key},
thrust::equal_to<key_t>{},
cuco::experimental::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>{},
stream_adapter,
stream.value());
cuco_store_ =
std::make_unique<cuco_store_type>(cuco_size,
cuco::sentinel::empty_key<key_t>{invalid_key},
thrust::equal_to<key_t>{},
cuco::linear_probing<1, // CG size
cuco::murmurhash3_32<key_t>>{},
stream_adapter,
stream.value());
}

std::unique_ptr<cuco_store_type> cuco_store_{nullptr};
Expand Down
Loading

0 comments on commit 8f5e165

Please sign in to comment.