diff --git a/aten/src/ATen/native/cuda/layer_norm_kernel.cu b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
index b0e0483fe3587..f0a4ee2696d6a 100644
--- a/aten/src/ATen/native/cuda/layer_norm_kernel.cu
+++ b/aten/src/ATen/native/cuda/layer_norm_kernel.cu
@@ -126,7 +126,11 @@ WelfordDataLN cuWelfordOnlineSum(
 {
   U delta = val - curr_sum.mean;
   U new_count = curr_sum.count + 1.f;
+#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
+  U new_mean = curr_sum.mean + delta * __builtin_amdgcn_rcpf(new_count);
+#else
   U new_mean = curr_sum.mean + delta * (1.f/new_count); //proper division is slow, this is less accurate but noticeably faster
+#endif
   return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count};
 }
 
@@ -140,7 +144,11 @@ WelfordDataLN cuWelfordCombine(
   U count = dataA.count + dataB.count;
   U mean, sigma2;
   if (count > decltype(dataB.count){0}) {
+#if defined(USE_ROCM) && defined(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
+    auto coef = __builtin_amdgcn_rcpf(count);
+#else
     auto coef = 1.f/count; //NB we don't use --use_fast_math, but this is emulation, 1./count goes to intrinsic, `* coef` is multiplication, instead of slow fp division
+#endif
     auto nA = dataA.count * coef;
     auto nB = dataB.count * coef;
     mean = nA*dataA.mean + nB*dataB.mean;
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index 504d6ee2243db..673ea502d3afe 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1066,6 +1066,22 @@ if(USE_ROCM)
     list(APPEND HIP_HIPCC_FLAGS -fdebug-info-for-profiling)
   endif(CMAKE_BUILD_TYPE MATCHES Debug)
 
+  # Get EnVar 'PYTORCH_LAYERNORM_FAST_RECIPROCAL' (or default to on).
+  if(DEFINED ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL})
+    set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE $ENV{PYTORCH_LAYERNORM_FAST_RECIPROCAL})
+  else()
+    set(PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE ON)
+  endif()
+
+  set(PYTORCH_LAYERNORM_FAST_RECIPROCAL
+    ${PYTORCH_LAYERNORM_FAST_RECIPROCAL_CMAKE}
+    CACHE BOOL "Enable fast reciprocals within layer normalization." FORCE
+  )
+
+  if(PYTORCH_LAYERNORM_FAST_RECIPROCAL)
+    add_definitions(-DPYTORCH_LAYERNORM_FAST_RECIPROCAL)
+  endif()
+
   # needed for compat with newer versions of hip-clang that introduced C++20 mangling rules
   list(APPEND HIP_HIPCC_FLAGS -fclang-abi-compat=17)
 
diff --git a/setup.py b/setup.py
index ea087c356c152..0fa7c74b41c95 100644
--- a/setup.py
+++ b/setup.py
@@ -135,6 +135,10 @@
 #   USE_ROCM_KERNEL_ASSERT=1
 #     Enable kernel assert in ROCm platform
 #
+#   PYTORCH_LAYERNORM_FAST_RECIPROCAL
+#     If set, enables the use of builtin functions for fast reciprocals (1/x) w.r.t.
+#     layer normalization. Default: enabled.
+#
 # Environment variables we respect (these environment variables are
 # conventional and are often understood/set by other software.)
 #
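
For context, the arithmetic being guarded above is the Welford running-mean/variance update: on ROCm, the fast path swaps the `1.f/new_count` division for the hardware reciprocal `__builtin_amdgcn_rcpf`, trading a little accuracy for speed (the existing comment notes the same trade-off for the emulated fast path), which is why the behavior is gated behind a build flag that defaults to on. The following is a minimal host-side C++ sketch of the same update using the plain-division form, since the AMD intrinsic is device-only; `WelfordState` and `welford_step` are illustrative names, not part of the patch.

    // Host-side illustration of the Welford online update that this patch touches.
    // Uses the plain 1/x division; the ROCm kernel replaces it with
    // __builtin_amdgcn_rcpf when PYTORCH_LAYERNORM_FAST_RECIPROCAL is defined.
    #include <cstdio>

    struct WelfordState {
      float mean;
      float sigma2;  // running sum of squared deviations (M2)
      float count;
    };

    // Mirrors the arithmetic of cuWelfordOnlineSum (plain-division variant).
    WelfordState welford_step(float val, WelfordState s) {
      float delta = val - s.mean;
      float new_count = s.count + 1.f;
      float new_mean = s.mean + delta * (1.f / new_count);
      return {new_mean, s.sigma2 + delta * (val - new_mean), new_count};
    }

    int main() {
      WelfordState s{0.f, 0.f, 0.f};
      const float vals[] = {1.f, 2.f, 3.f, 4.f};
      for (float v : vals) {
        s = welford_step(v, s);
      }
      // Expected: mean = 2.5, variance = sigma2 / count = 1.25
      std::printf("mean=%f var=%f\n", s.mean, s.sigma2 / s.count);
      return 0;
    }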