diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu
index 4aca753a510b8b..dad88e9737f77e 100644
--- a/aten/src/ATen/native/cuda/SoftMax.cu
+++ b/aten/src/ATen/native/cuda/SoftMax.cu
@@ -89,6 +89,26 @@ struct SoftMaxBackwardEpilogue {
   const AccumT sum;
 };
 
+#ifdef USE_ROCM
+template<typename T, typename AccumT, typename OutT>
+struct SoftMaxForwardWithMulEpilogue {
+  __device__ __forceinline__ SoftMaxForwardWithMulEpilogue(AccumT max_input, AccumT sum)
+    : max_input(max_input)
+    , sum(sum) {}
+
+  __device__ __forceinline__ OutT operator()(T input) const {
+#ifdef PYTORCH_USE_EXPF
+    return static_cast<OutT>(__expf(input - max_input) * sum);
+#else
+    return static_cast<OutT>(std::exp(input - max_input) * sum);
+#endif
+  }
+
+  const AccumT max_input;
+  const AccumT sum;
+};
+#endif
+
@@ -387,6 +407,21 @@ struct SumExpFloat
   const AccumT max_k;
 };
 
+#ifdef USE_ROCM
+template<typename T, typename AccumT>
+struct SumExpfFloat
+{
+  __device__ __forceinline__ SumExpfFloat(AccumT v)
+    : max_k(v) {}
+
+  __device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
+    return sum + __expf(v - max_k);
+  }
+
+  const AccumT max_k;
+};
+#endif
+
 template <template<typename> class Reduction, typename AccumT>
 __device__ __forceinline__ AccumT
 blockReduce(AccumT* smem, AccumT val,
@@ -449,6 +484,18 @@ T blockReduceWarp(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
   return smem_cache[0];
 }
 
+template <template<typename> class Reduction, typename T>
+__device__ __forceinline__
+T blockReduceWarpInverse(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
+{
+  T result = cuda_utils::BlockReduce<T, Reduction<T>>(value, op, defaultVal, smem_cache);
+  if (threadIdx.x == 0) {
+    smem_cache[0] = 1 / result;
+  }
+  __syncthreads();
+  return smem_cache[0];
+}
+
 template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT, typename index_t=int>
 __device__ __forceinline__ AccumT
 ilpReduce(index_t shift,
@@ -694,6 +741,71 @@ cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes)
   }
 }
 
+template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
+  template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
+__global__ void
+cunn_SoftMaxForwardGmem(outscalar_t *output, const scalar_t *input, index_t classes)
+{
+  // Each thread block processes a sample in the batch
+  input += static_cast<int64_t>(blockIdx.x) * classes;
+  output += static_cast<int64_t>(blockIdx.x) * classes;
+
+  accscalar_t threadMax = -at::numeric_limits<accscalar_t>::max();
+  accscalar_t threadExp = static_cast<accscalar_t>(0);
+
+  // The first smem segment is used to cache input values and the last
+  // segment is used for thread block reductions
+  extern __shared__ unsigned char smem[];
+  auto smem_reduction_cache = reinterpret_cast<accscalar_t*>(smem);
+
+  using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>;
+  const LoadT* const input_vec_ptr = reinterpret_cast<const LoadT*>(input);
+
+  // Do the first step in max calculation:
+  MaxFloat<scalar_t, accscalar_t> maxFunc;
+  for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
+    LoadT crnt_vec = input_vec_ptr[offset];
+    #pragma unroll
+    for (int i = 0; i < ILP; ++i) {
+      threadMax = maxFunc(threadMax, crnt_vec.val[i]);
+    }
+  }
+
+  accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(smem_reduction_cache, threadMax,
+    Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
+
+  // Do the second step in sum exp calculation:
+#ifdef PYTORCH_USE_EXPF
+  SumExpfFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
+#else
+  SumExpFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
+#endif
+  for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
+    LoadT crnt_vec = input_vec_ptr[offset];
+    #pragma unroll
+    for (int i = 0; i < ILP; ++i) {
+      threadExp = sumExpFunc(threadExp, crnt_vec.val[i]);
+    }
+  }
+
+  accscalar_t sumAll = blockReduceWarpInverse<Add, accscalar_t>(smem_reduction_cache, threadExp,
+    Add<accscalar_t>(), static_cast<accscalar_t>(0));
+
+  Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);
+
+  using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>;
+  StoreT* output_vec_ptr = reinterpret_cast<StoreT*>(output);
+  for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
+    LoadT crnt_vec = input_vec_ptr[offset];
+    StoreT out_vec;
+    #pragma unroll
+    for (int i = 0; i < ILP; ++i) {
+      out_vec.val[i] = epilogue(crnt_vec.val[i]);
+    }
+    output_vec_ptr[offset] = out_vec;
+  }
+}
+
 template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
   template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
 __global__ void
@@ -858,6 +970,12 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
           }
         } else {
           constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
+#ifdef USE_ROCM
+          dim3 block(512);
+          size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+          cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
+            <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+#else
           dim3 block = SoftMaxForward_getBlockSize(dim_size);
           size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
           auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
@@ -876,6 +994,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
            cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
          }
+#endif
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
@@ -894,6 +1013,12 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
          }
        } else {
          constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
+#ifdef USE_ROCM
+          dim3 block(512);
+          size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
+          cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
+            <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
+#else
          dim3 block = SoftMaxForward_getBlockSize(dim_size);
          size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
          auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
@@ -912,6 +1037,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
            cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
              <<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
          }
+#endif
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }
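
Note (not part of the patch): the ROCm path above reduces the row's sum of exp(x - max), stores its reciprocal once via blockReduceWarpInverse, and the epilogue multiplies each exp(x - max) by that reciprocal instead of dividing per element. A minimal CPU-side sketch of that pass structure, with a hypothetical helper name and plain float data:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Hypothetical reference mirroring cunn_SoftMaxForwardGmem's three passes:
    // max reduction, sum-of-exp reduction, then a scale-by-reciprocal pass.
    std::vector<float> softmax_row_sketch(const std::vector<float>& row) {
      // Pass 1: row maximum (MaxFloat + blockReduceWarp in the kernel).
      float max_k = *std::max_element(row.begin(), row.end());

      // Pass 2: sum of exp(x - max) (SumExpFloat / SumExpfFloat).
      float sum = 0.f;
      for (float v : row) sum += std::exp(v - max_k);

      // Invert once (blockReduceWarpInverse writes 1/result to shared memory).
      float inv_sum = 1.f / sum;

      // Pass 3: epilogue multiplies by the reciprocal instead of dividing
      // (SoftMaxForwardWithMulEpilogue).
      std::vector<float> out(row.size());
      for (std::size_t i = 0; i < row.size(); ++i)
        out[i] = std::exp(row[i] - max_k) * inv_sum;
      return out;
    }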