From 466b5b9e50c07902d576167770857014d1c30fde Mon Sep 17 00:00:00 2001 From: Chang Liu Date: Mon, 2 Dec 2024 09:13:15 -0800 Subject: [PATCH] [Bugfix] Add stream synchronization before the scatter operation (#73) This is to address the issue from this PR: https://github.com/rapidsai/wholegraph/pull/229, and it's only for the last scatter operation before the Python interface (not for all internal `scatter_func` calls) Since the output of the scatter operation could be on the host (e.g., when emb_device = 'cpu'), it is necessary to perform synchronization internally. This ensures users do not need to explicitly synchronize the compute stream before accessing the host memory. Unlike the gather operation, where the output is always in device memory, host side synchronization is unnecessary. Authors: - Chang Liu (https://github.com/chang-l) - Alex Barghi (https://github.com/alexbarghi-nv) Approvers: - https://github.com/linhu-nv URL: https://github.com/rapidsai/cugraph-gnn/pull/73 --- cpp/src/wholememory_ops/scatter_op_impl_mapped.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/wholememory_ops/scatter_op_impl_mapped.cu b/cpp/src/wholememory_ops/scatter_op_impl_mapped.cu index 2d7d497..dbe68b1 100644 --- a/cpp/src/wholememory_ops/scatter_op_impl_mapped.cu +++ b/cpp/src/wholememory_ops/scatter_op_impl_mapped.cu @@ -18,6 +18,7 @@ #include #include +#include "cuda_macros.hpp" #include "wholememory_ops/functions/gather_scatter_func.h" namespace wholememory_ops { @@ -41,6 +42,7 @@ wholememory_error_code_t wholememory_scatter_mapped( wholememory_desc, stream, scatter_sms); + WM_CUDA_CHECK(cudaStreamSynchronize(stream)); } } // namespace wholememory_ops