Skip to content

Commit

Permalink
Fix for Java tests and replicate the test in C++
Browse files Browse the repository at this point in the history
Signed-off-by: Muhammad Haseeb <[email protected]>
  • Loading branch information
mhaseeb123 committed Sep 10, 2024
1 parent ee0e606 commit ce2cb57
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
4 changes: 3 additions & 1 deletion cpp/src/join/mixed_join_kernels_semi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ CUDF_KERNEL void __launch_bounds__(block_size)
left_table, right_table, device_expression_data);

if (outer_row_index < outer_num_rows) {
// Make sure to swap_tables here as hash_set will use probe table as the left one.
auto constexpr swap_tables = true;
// Build the row-equality comparator used to probe the hash set for this outer row.
auto equality = single_expression_equality<has_nulls>{
evaluator, thread_intermediate_storage, false, equality_probe};
evaluator, thread_intermediate_storage, swap_tables, equality_probe};

auto const set_ref_equality = set_ref.with_key_eq(equality);
auto const result = set_ref_equality.contains(tile, outer_row_index);
Expand Down
19 changes: 8 additions & 11 deletions cpp/src/join/mixed_join_semi.cu
Original file line number Diff line number Diff line change
Expand Up @@ -151,20 +151,16 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(
preprocessed_build_condtional};
auto const equality_build_conditional =
row_comparator_conditional_build.equal_to<false>(build_nulls, compare_nulls);
double_row_equality_comparator equality_build{equality_build_equality,
equality_build_conditional};

auto const build_num_rows = compute_hash_table_size(build.num_rows());

hash_set_type row_set{
build_num_rows,
{compute_hash_table_size(build.num_rows())},
cuco::empty_key{JoinNoneValue},
equality_build,
{equality_build_equality, equality_build_conditional},
{row_hash_build.device_hasher(build_nulls)},
{},
{},
cudf::detail::cuco_allocator<char>{rmm::mr::polymorphic_allocator<char>{}, stream},
stream.value()};
{stream.value()}};

auto iter = thrust::make_counting_iterator(0);

Expand All @@ -183,12 +179,13 @@ std::unique_ptr<rmm::device_uvector<size_type>> mixed_join_semi(

detail::grid_1d const config(outer_num_rows, DEFAULT_JOIN_BLOCK_SIZE);
auto const shmem_size_per_block =
(parser.shmem_per_thread / hash_set_type::cg_size) * config.num_threads_per_block;
parser.shmem_per_thread *
cuco::detail::int_div_ceil(config.num_threads_per_block, hash_set_type::cg_size);

auto const row_hash = cudf::experimental::row::hash::row_hasher{preprocessed_probe};
auto const row_hash = cudf::experimental::row::hash::row_hasher{preprocessed_probe};
auto const hash_probe = row_hash.device_hasher(has_nulls);

hash_set_ref_type const row_set_ref =
row_set.ref(cuco::contains).with_hash_function(row_hash.device_hasher(has_nulls));
hash_set_ref_type const row_set_ref = row_set.ref(cuco::contains).with_hash_function(hash_probe);

// Vector used to indicate indices from left/probe table which are present in output
auto left_table_keep_mask = rmm::device_uvector<bool>(probe.num_rows(), stream);
Expand Down
15 changes: 15 additions & 0 deletions cpp/tests/join/mixed_join_tests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,21 @@ TYPED_TEST(MixedLeftSemiJoinTest, BasicEquality)
{1});
}

// Mixed left semi join with a GREATER-than AST predicate between column 0 of the
// LEFT and RIGHT table references; the expected result passed to the fixture is {2, 7, 8}.
TYPED_TEST(MixedLeftSemiJoinTest, MixedLeftSemiJoinGatherMap)
{
  // AST operands: column 0 as seen through the LEFT and RIGHT table references.
  auto const lhs_column = cudf::ast::column_reference(0, cudf::ast::table_reference::LEFT);
  auto const rhs_column = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT);

  // Predicate: lhs_column > rhs_column.
  auto lhs_greater_than_rhs =
    cudf::ast::operation(cudf::ast::ast_operator::GREATER, lhs_column, rhs_column);

  // NOTE(review): {0} and {1} are presumably the equality-key and conditional column
  // indices expected by the fixture's test() helper -- confirm against the fixture.
  this->test(
    // Left input: two columns of ten rows.
    {{2, 3, 9, 0, 1, 7, 4, 6, 5, 8}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}},
    // Right input: three columns of six rows.
    {{6, 5, 9, 8, 10, 32}, {0, 1, 2, 3, 4, 5}, {7, 8, 9, 0, 1, 2}},
    {0},
    {1},
    lhs_greater_than_rhs,
    // Expected output.
    {2, 7, 8});
}

TYPED_TEST(MixedLeftSemiJoinTest, BasicEqualityDuplicates)
{
this->test({{0, 1, 2, 1}, {3, 4, 5, 6}, {10, 20, 30, 40}},
Expand Down

0 comments on commit ce2cb57

Please sign in to comment.