From ce2cb5740e7e689ac622b4b6c6055dff385d3cf3 Mon Sep 17 00:00:00 2001 From: Muhammad Haseeb <14217455+mhaseeb123@users.noreply.github.com> Date: Tue, 10 Sep 2024 03:51:06 +0000 Subject: [PATCH] Fix for Java tests and replicate the test in C++ Signed-off-by: Muhammad Haseeb <14217455+mhaseeb123@users.noreply.github.com> --- cpp/src/join/mixed_join_kernels_semi.cu | 4 +++- cpp/src/join/mixed_join_semi.cu | 19 ++++++++----------- cpp/tests/join/mixed_join_tests.cu | 15 +++++++++++++++ 3 files changed, 26 insertions(+), 12 deletions(-) diff --git a/cpp/src/join/mixed_join_kernels_semi.cu b/cpp/src/join/mixed_join_kernels_semi.cu index d2c76d11340..9910bfdf2e9 100644 --- a/cpp/src/join/mixed_join_kernels_semi.cu +++ b/cpp/src/join/mixed_join_kernels_semi.cu @@ -65,9 +65,11 @@ CUDF_KERNEL void __launch_bounds__(block_size) left_table, right_table, device_expression_data); if (outer_row_index < outer_num_rows) { + // Make sure to swap_tables here as hash_set will use probe table as the left one. + auto constexpr swap_tables = true; // Figure out the number of elements for this key. auto equality = single_expression_equality{ - evaluator, thread_intermediate_storage, false, equality_probe}; + evaluator, thread_intermediate_storage, swap_tables, equality_probe}; auto const set_ref_equality = set_ref.with_key_eq(equality); auto const result = set_ref_equality.contains(tile, outer_row_index); diff --git a/cpp/src/join/mixed_join_semi.cu b/cpp/src/join/mixed_join_semi.cu index 32df6a7a266..6400cea7dba 100644 --- a/cpp/src/join/mixed_join_semi.cu +++ b/cpp/src/join/mixed_join_semi.cu @@ -151,20 +151,16 @@ std::unique_ptr> mixed_join_semi( preprocessed_build_condtional}; auto const equality_build_conditional = row_comparator_conditional_build.equal_to(build_nulls, compare_nulls); - double_row_equality_comparator equality_build{equality_build_equality, - equality_build_conditional}; - - auto const build_num_rows = compute_hash_table_size(build.num_rows()); hash_set_type row_set{ - build_num_rows, + {compute_hash_table_size(build.num_rows())}, cuco::empty_key{JoinNoneValue}, - equality_build, + {equality_build_equality, equality_build_conditional}, {row_hash_build.device_hasher(build_nulls)}, {}, {}, cudf::detail::cuco_allocator{rmm::mr::polymorphic_allocator{}, stream}, - stream.value()}; + {stream.value()}}; auto iter = thrust::make_counting_iterator(0); @@ -183,12 +179,13 @@ std::unique_ptr> mixed_join_semi( detail::grid_1d const config(outer_num_rows, DEFAULT_JOIN_BLOCK_SIZE); auto const shmem_size_per_block = - (parser.shmem_per_thread / hash_set_type::cg_size) * config.num_threads_per_block; + parser.shmem_per_thread * + cuco::detail::int_div_ceil(config.num_threads_per_block, hash_set_type::cg_size); - auto const row_hash = cudf::experimental::row::hash::row_hasher{preprocessed_probe}; + auto const row_hash = cudf::experimental::row::hash::row_hasher{preprocessed_probe}; + auto const hash_probe = row_hash.device_hasher(has_nulls); - hash_set_ref_type const row_set_ref = - row_set.ref(cuco::contains).with_hash_function(row_hash.device_hasher(has_nulls)); + hash_set_ref_type const row_set_ref = row_set.ref(cuco::contains).with_hash_function(hash_probe); // Vector used to indicate indices from left/probe table which are present in output auto left_table_keep_mask = rmm::device_uvector(probe.num_rows(), stream); diff --git a/cpp/tests/join/mixed_join_tests.cu b/cpp/tests/join/mixed_join_tests.cu index 6c147c8a128..c02349feab1 100644 --- a/cpp/tests/join/mixed_join_tests.cu +++ b/cpp/tests/join/mixed_join_tests.cu @@ -778,6 +778,21 @@ TYPED_TEST(MixedLeftSemiJoinTest, BasicEquality) {1}); } +TYPED_TEST(MixedLeftSemiJoinTest, MixedLeftSemiJoinGatherMap) +{ + auto const col_ref_left_1 = cudf::ast::column_reference(0, cudf::ast::table_reference::LEFT); + auto const col_ref_right_1 = cudf::ast::column_reference(0, cudf::ast::table_reference::RIGHT); + auto left_one_greater_right_one = + cudf::ast::operation(cudf::ast::ast_operator::GREATER, col_ref_left_1, col_ref_right_1); + + this->test({{2, 3, 9, 0, 1, 7, 4, 6, 5, 8}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 0}}, + {{6, 5, 9, 8, 10, 32}, {0, 1, 2, 3, 4, 5}, {7, 8, 9, 0, 1, 2}}, + {0}, + {1}, + left_one_greater_right_one, + {2, 7, 8}); +} + TYPED_TEST(MixedLeftSemiJoinTest, BasicEqualityDuplicates) { this->test({{0, 1, 2, 1}, {3, 4, 5, 6}, {10, 20, 30, 40}},