From cf45304240254abd98482e472e943454e8a94527 Mon Sep 17 00:00:00 2001 From: Andrey Bokovoy Date: Tue, 13 Aug 2024 16:27:45 +0200 Subject: [PATCH] Wrap rocmprim header with #ifdef --- .../src/split_embeddings_utils/transpose_embedding_input.cu | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu index c0bbf2492b..25f9d0688d 100644 --- a/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu +++ b/fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu @@ -9,7 +9,9 @@ #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" // @manual #include "fbgemm_gpu/ops_utils.h" // @manual #include "fbgemm_gpu/split_embeddings_utils.cuh" // @manual +#ifdef USE_ROCM #include +#endif // clang-format off #include "fbgemm_gpu/cub_namespace_prefix.cuh" // @manual #include @@ -297,7 +299,7 @@ transpose_embedding_input( } { size_t temp_storage_bytes = 0; -#ifdef __HIP_PLATFORM_NVIDIA__ +#ifndef USE_ROCM AT_CUDA_CHECK( FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs( nullptr,