diff --git a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh index 9ad7c68f87..24207ba6db 100644 --- a/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh +++ b/cpp/include/raft/random/detail/rmat_rectangular_generator.cuh @@ -151,15 +151,16 @@ RAFT_KERNEL rmat_gen_kernel(IdxT* out, raft::random::PCGenerator gen{r.seed, r.base_subsequence + idx, 0}; auto min_scale = min(r_scale, c_scale); IdxT i = 0; - for (; i < min_scale; ++i) { - gen_and_update_bits(src_id, dst_id, a, a + b, a + b + c, r_scale, c_scale, i, gen); - } - for (; i < r_scale; ++i) { - gen_and_update_bits(src_id, dst_id, a + b, a + b, ProbT(1), r_scale, c_scale, i, gen); - } - for (; i < c_scale; ++i) { - gen_and_update_bits(src_id, dst_id, a + c, ProbT(1), ProbT(1), r_scale, c_scale, i, gen); + // Whether we have more rows than columns. + const bool more_rows = r_scale > c_scale; + + for (; i < max_scale; ++i) { + ProbT A = (i < min_scale) ? a : (more_rows ? a + b : a + c); + ProbT AB = (i < min_scale) ? a + b : (more_rows ? a + b : ProbT(1)); + ProbT ABC = (i < min_scale) ? a + b + c : ProbT(1); + gen_and_update_bits(src_id, dst_id, A, AB, ABC, r_scale, c_scale, i, gen); } + store_ids(out, out_src, out_dst, src_id, dst_id, idx, n_edges); }