From 1e5030d1b4f85a9f306c36f8a030494fa59aaaa4 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Wed, 11 Dec 2024 23:39:15 +0100 Subject: [PATCH] Fix rnd bit generation in rmat_rectangular_kernel (#2524) For certain architectures, the compiler always generates zero destination bit in the following loop https://github.com/rapidsai/raft/blob/ee45ce786686b54d1972408b927d7fcd8ce0cf20/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh#L160-L162 irrespective of the random value that shall determine which bit to use for `dst_id`. This PR refactors the loop. This way the `dst_id` number has the desired random distribution for all bits. Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2524 --- .../detail/rmat_rectangular_generator.cuh | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh index 9ad7c68f87..24207ba6db 100644 --- a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh +++ b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh @@ -151,15 +151,16 @@ RAFT_KERNEL rmat_gen_kernel(IdxT* out, raft::random::PCGenerator gen{r.seed, r.base_subsequence + idx, 0}; auto min_scale = min(r_scale, c_scale); IdxT i = 0; - for (; i < min_scale; ++i) { - gen_and_update_bits(src_id, dst_id, a, a + b, a + b + c, r_scale, c_scale, i, gen); - } - for (; i < r_scale; ++i) { - gen_and_update_bits(src_id, dst_id, a + b, a + b, ProbT(1), r_scale, c_scale, i, gen); - } - for (; i < c_scale; ++i) { - gen_and_update_bits(src_id, dst_id, a + c, ProbT(1), ProbT(1), r_scale, c_scale, i, gen); + // Whether we have more rows than columns. + const bool more_rows = r_scale > c_scale; + + for (; i < max_scale; ++i) { + ProbT A = (i < min_scale) ? a : (more_rows ? a + b : a + c); + ProbT AB = (i < min_scale) ? a + b : (more_rows ? a + b : ProbT(1)); + ProbT ABC = (i < min_scale) ? a + b + c : ProbT(1); + gen_and_update_bits(src_id, dst_id, A, AB, ABC, r_scale, c_scale, i, gen); } + store_ids(out, out_src, out_dst, src_id, dst_id, idx, n_edges); }