diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b31cbe124..3eb0dbc8d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -171,6 +171,7 @@ endif() if(RAJA_ENABLE_HIP) message(STATUS "HIP version: ${hip_VERSION}") + set(RAJA_HIP_WAVESIZE "64" CACHE STRING "Set the wave size for GPU architecture. E.g. MI200/MI300 this is 64.") if("${hip_VERSION}" VERSION_LESS "3.5") message(FATAL_ERROR "Trying to use HIP/ROCm version ${hip_VERSION}. RAJA requires HIP/ROCm version 3.5 or newer. ") endif() diff --git a/include/RAJA/config.hpp.in b/include/RAJA/config.hpp.in index 380418efa1..29d97fed69 100644 --- a/include/RAJA/config.hpp.in +++ b/include/RAJA/config.hpp.in @@ -182,6 +182,8 @@ static_assert(RAJA_HAS_SOME_CXX14, #cmakedefine RAJA_ENABLE_NV_TOOLS_EXT #cmakedefine RAJA_ENABLE_ROCTX +#cmakedefine RAJA_HIP_WAVESIZE @RAJA_HIP_WAVESIZE@ + /*! ****************************************************************************** * diff --git a/include/RAJA/policy/hip/policy.hpp b/include/RAJA/policy/hip/policy.hpp index a9f9027675..040de50f31 100644 --- a/include/RAJA/policy/hip/policy.hpp +++ b/include/RAJA/policy/hip/policy.hpp @@ -324,8 +324,9 @@ struct DeviceConstants // values for HIP warp size and max block size. // #if defined(__HIP_PLATFORM_AMD__) -constexpr DeviceConstants device_constants(64, 1024, 64); // MI300A -// constexpr DeviceConstants device_constants(64, 1024, 128); // MI250X +constexpr DeviceConstants device_constants(RAJA_HIP_WAVESIZE, 1024, 64); // MI300A +// constexpr DeviceConstants device_constants(RAJA_HIP_WAVESIZE, 1024, 128); // MI250X + #elif defined(__HIP_PLATFORM_NVIDIA__) constexpr DeviceConstants device_constants(32, 1024, 32); // V100 #endif diff --git a/include/RAJA/policy/tensor/arch/hip/hip_wave.hpp b/include/RAJA/policy/tensor/arch/hip/hip_wave.hpp index 74bbc2f077..f1810807f9 100644 --- a/include/RAJA/policy/tensor/arch/hip/hip_wave.hpp +++ b/include/RAJA/policy/tensor/arch/hip/hip_wave.hpp @@ -57,7 +57,7 @@ namespace expt public: - static constexpr int s_num_elem = 64; + static constexpr int s_num_elem = policy::hip::device_constants.WARP_SIZE; /*! * @brief Default constructor, zeros register contents @@ -780,8 +780,8 @@ namespace expt // Third: mask off everything but output_segment // this is because all output segments are valid at this point - // (5-segbits), the 5 is since the warp-width is 32 == 1<<5 - int our_output_segment = get_lane()>>(6-segbits); + static constexpr int log2_warp_size = RAJA::log2(RAJA::policy::hip::device_constants.WARP_SIZE); + int our_output_segment = get_lane()>>(log2_warp_size-segbits); bool in_output_segment = our_output_segment == output_segment; if(!in_output_segment){ result.get_raw_value() = 0; @@ -828,8 +828,9 @@ namespace expt // First: tree reduce values within each segment element_type x = m_value; + static constexpr int log2_warp_size = RAJA::log2(RAJA::policy::hip::device_constants.WARP_SIZE); RAJA_UNROLL - for(int i = 0;i < 6-segbits; ++ i){ + for(int i = 0;i < log2_warp_size-segbits; ++ i){ // tree shuffle int delta = s_num_elem >> (i+1); diff --git a/include/RAJA/policy/tensor/arch/hip/traits.hpp b/include/RAJA/policy/tensor/arch/hip/traits.hpp index 4c4d959599..1b8a9679bb 100644 --- a/include/RAJA/policy/tensor/arch/hip/traits.hpp +++ b/include/RAJA/policy/tensor/arch/hip/traits.hpp @@ -29,7 +29,8 @@ namespace expt { struct RegisterTraits{ using element_type = T; using register_policy = RAJA::expt::hip_wave_register; - static constexpr camp::idx_t s_num_elem = 64; + static constexpr camp::idx_t s_num_elem = policy::hip::device_constants.WARP_SIZE; + static constexpr camp::idx_t s_num_bits = sizeof(T) * s_num_elem; using int_element_type = int32_t; }; diff --git a/test/include/RAJA_test-tensor.hpp b/test/include/RAJA_test-tensor.hpp index cf633098a9..d836e1463f 100644 --- a/test/include/RAJA_test-tensor.hpp +++ b/test/include/RAJA_test-tensor.hpp @@ -87,7 +87,9 @@ struct TensorTestHelper void exec(BODY const &body){ hipDeviceSynchronize(); - RAJA::forall>(RAJA::RangeSegment(0,64), + static constexpr int warp_size = RAJA::policy::hip::device_constants.WARP_SIZE; + + RAJA::forall>(RAJA::RangeSegment(0,warp_size), [=] RAJA_HOST_DEVICE (int ){ body(); });