Improve softmax performance.
doru1004 committed Dec 3, 2024
1 parent 5672206 commit 5a4ae3d
Showing 1 changed file with 161 additions and 35 deletions.
196 changes: 161 additions & 35 deletions aten/src/ATen/native/cuda/SoftMax.cu
@@ -89,6 +89,24 @@ struct SoftMaxBackwardEpilogue {
const AccumT sum;
};

template<typename T, typename AccumT, typename OutT>
struct SoftMaxForwardWithMulEpilogue {
__device__ __forceinline__ SoftMaxForwardWithMulEpilogue(AccumT max_input, AccumT sum)
: max_input(max_input)
, sum(sum) {}

__device__ __forceinline__ OutT operator()(T input) const {
#ifdef PYTORCH_USE_EXPF
return static_cast<OutT>(__expf(input - max_input) * sum);
#else
return static_cast<OutT>(std::exp(input - max_input) * sum);
#endif
}

const AccumT max_input;
const AccumT sum;
};
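
Unlike SoftMaxForwardEpilogue, which divides each exp(x - max) by the sum, this epilogue expects the reciprocal of the sum and multiplies, trading one division per element for a single reciprocal per row. A minimal host-side sketch of the equivalence (illustration only, not part of the commit; hypothetical names, float path, assumes <cmath>):

#include <cmath>

// Illustration only: softmax over one row, computing 1/sum once and
// multiplying, which is what the epilogue above expects to receive.
inline void softmax_row_sketch(const float* in, float* out, int n) {
  float max_val = in[0];
  for (int i = 1; i < n; ++i) max_val = std::fmax(max_val, in[i]);
  float sum = 0.f;
  for (int i = 0; i < n; ++i) sum += std::exp(in[i] - max_val);
  const float inv_sum = 1.f / sum;                  // one reciprocal per row
  for (int i = 0; i < n; ++i)
    out[i] = std::exp(in[i] - max_val) * inv_sum;   // multiply instead of divide
}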




@@ -387,6 +405,19 @@ struct SumExpFloat
const AccumT max_k;
};

template<typename T, typename AccumT>
struct SumExpfFloat
{
__device__ __forceinline__ SumExpfFloat(AccumT v)
: max_k(v) {}

__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + __expf(v - max_k);
}

const AccumT max_k;
};
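
SumExpfFloat is the fast-math counterpart of SumExpFloat above: __expf is the CUDA single-precision exponential intrinsic, which has lower latency than the standard exponential at the cost of a few ULP of accuracy. The call sites below choose between the two functors at compile time; a standalone sketch of that guard (illustration only, float path assumed):

// Illustration only: the PYTORCH_USE_EXPF switch used by the new code paths,
// pulled out as a hypothetical device helper.
__device__ __forceinline__ float guarded_exp_sketch(float x) {
#ifdef PYTORCH_USE_EXPF
  return __expf(x);   // fast-math intrinsic: lower latency, slightly reduced accuracy
#else
  return expf(x);     // standard single-precision exponential
#endif
}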

template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
@@ -449,6 +480,18 @@ T blockReduceWarp(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
return smem_cache[0];
}

template <template<typename> class Reduction, typename T>
__device__ __forceinline__
T blockReduceWarpInverse(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
{
T result = cuda_utils::BlockReduce<T, Reduction<T>>(value, op, defaultVal, smem_cache);
if (threadIdx.x == 0) {
smem_cache[0] = 1 / result;
}
__syncthreads();
return smem_cache[0];
}
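
blockReduceWarpInverse differs from blockReduceWarp only in that thread 0 stores the reciprocal of the reduced value, so every thread reads back 1/sum and the epilogue can multiply rather than divide. A sketch of the intended pairing (illustration only; a plain strided loop without vectorized loads, using the names defined in this file):

// Sketch: how the inverse reduction feeds the multiplying epilogue.
// cunn_SoftMaxForwardGmem below does this per thread block with vector loads.
template <typename scalar_t, typename accscalar_t, typename outscalar_t>
__device__ void apply_mul_epilogue_sketch(outscalar_t* out, const scalar_t* in,
                                          int classes, accscalar_t max_k,
                                          accscalar_t thread_exp_sum,
                                          accscalar_t* smem_cache) {
  // Every thread receives 1 / sum(exp(x - max_k)) back from shared memory.
  accscalar_t inv_sum = blockReduceWarpInverse<Add, accscalar_t>(
      smem_cache, thread_exp_sum, Add<accscalar_t>(), static_cast<accscalar_t>(0));
  SoftMaxForwardWithMulEpilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, inv_sum);
  for (int i = threadIdx.x; i < classes; i += blockDim.x) {
    out[i] = epilogue(in[i]);  // exp(in[i] - max_k) * inv_sum
  }
}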

template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT, typename index_t=int>
__device__ __forceinline__ AccumT
ilpReduce(index_t shift,
@@ -694,6 +737,71 @@ cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes)
}
}

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
__global__ void
cunn_SoftMaxForwardGmem(outscalar_t *output, const scalar_t *input, index_t classes)
{
// Each thread block processes a sample in the batch
input += static_cast<int64_t>(blockIdx.x) * classes;
output += static_cast<int64_t>(blockIdx.x) * classes;

accscalar_t threadMax = -at::numeric_limits<accscalar_t>::max();
accscalar_t threadExp = static_cast<accscalar_t>(0);

// Shared memory holds only the thread block reduction cache; the input row
// is re-read from global memory on each pass rather than cached in smem.
extern __shared__ unsigned char smem[];
auto smem_reduction_cache = reinterpret_cast<accscalar_t*>(smem);

using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>;
const LoadT* const input_vec_ptr = reinterpret_cast<const LoadT*>(input);

// Do the first step in max calculation:
MaxFloat<scalar_t, accscalar_t> maxFunc;
for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
LoadT crnt_vec = input_vec_ptr[offset];
#pragma unroll
for (int i = 0; i < ILP; ++i) {
threadMax = maxFunc(threadMax, crnt_vec.val[i]);
}
}

accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(smem_reduction_cache, threadMax,
Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());

// Do the second step in sum exp calculation:
#ifdef PYTORCH_USE_EXPF
SumExpfFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
#else
SumExpFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
#endif
for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
LoadT crnt_vec = input_vec_ptr[offset];
#pragma unroll
for (int i = 0; i < ILP; ++i) {
threadExp = sumExpFunc(threadExp, crnt_vec.val[i]);
}
}

accscalar_t sumAll = blockReduceWarpInverse<Add, accscalar_t>(smem_reduction_cache, threadExp,
Add<accscalar_t>(), static_cast<accscalar_t>(0));

Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>;
StoreT* output_vec_ptr = reinterpret_cast<StoreT*>(output);
for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
LoadT crnt_vec = input_vec_ptr[offset];
StoreT out_vec;
#pragma unroll
for (int i = 0; i < ILP; ++i) {
out_vec.val[i] = epilogue(crnt_vec.val[i]);
}
output_vec_ptr[offset] = out_vec;
}
}
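
Unlike cunn_SoftMaxForwardSmem, this kernel does not cache the row in shared memory: it re-reads the input from global memory on each of its passes (max, sum of exponentials, write-out) and only needs the small warp-reduction scratchpad. A hedged sketch of how the host side launches it, simplified from the host_softmax changes further down (assumes float tensors, contiguous rows, and a row length divisible by ILP so the vectorized loads are valid):

// Hedged launch sketch (assumptions as noted above; not the commit's host code verbatim).
void launch_softmax_gmem_sketch(float* out, const float* in, int64_t rows,
                                int32_t classes, cudaStream_t stream) {
  constexpr int ILP = sizeof(float4) / sizeof(float);       // 4 elements per vector load
  dim3 grid(static_cast<unsigned int>(rows));               // one thread block per row
  dim3 block(512);                                          // fixed block size on this path
  size_t smem = block.x / C10_WARP_SIZE * sizeof(float);    // reduction cache only
  cunn_SoftMaxForwardGmem<ILP, float, float, float, SoftMaxForwardWithMulEpilogue>
      <<<grid, block, smem, stream>>>(out, in, classes);
}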

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
__global__ void
@@ -816,7 +924,7 @@ cunn_SoftMaxBackward(scalar_t *gradInput, const outscalar_t *output, const outsc
}
}

template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
template<template<typename, typename, typename> class Epilogue, template<typename, typename, typename> class EpilogueWithMul, bool is_log_softmax, bool uses_rocm>
Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float, const Tensor& output){
if (half_to_float) {
TORCH_CHECK(input_.scalar_type() == ScalarType::Half, "conversion is supported for Half type only");
@@ -858,23 +966,30 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
}
} else {
constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
smem_reduction_sz) / sizeof(scalar_t);

bool can_use_smem = dim_size < max_elements_per_smem;
can_use_smem &= !(reinterpret_cast<const uintptr_t>(input_ptr) % ALIGN_BYTES);
can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
can_use_smem &= !(dim_size % ILP);

if (can_use_smem) {
size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
<<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
if (uses_rocm && !is_log_softmax) {
dim3 block(512);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, scalar_t, EpilogueWithMul>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
smem_reduction_sz) / sizeof(scalar_t);

bool can_use_smem = (size_t) dim_size < max_elements_per_smem;
can_use_smem &= !(reinterpret_cast<uintptr_t>(input_ptr) % ALIGN_BYTES);
can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
can_use_smem &= !(dim_size % ILP);

if (can_use_smem) {
size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
<<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
}
}

C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -894,23 +1009,30 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
}
} else {
constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
smem_reduction_sz) / sizeof(scalar_t);

bool can_use_smem = dim_size < max_elements_per_smem;
can_use_smem &= !(reinterpret_cast<const uintptr_t>(input_ptr) % ALIGN_BYTES);
can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
can_use_smem &= !(dim_size % ILP);

if (can_use_smem) {
size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
if (uses_rocm && !is_log_softmax) {
dim3 block(512);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, accscalar_t, EpilogueWithMul>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
smem_reduction_sz) / sizeof(scalar_t);

bool can_use_smem = (size_t) dim_size < max_elements_per_smem;
can_use_smem &= !(reinterpret_cast<uintptr_t>(input_ptr) % ALIGN_BYTES);
can_use_smem &= (!(reinterpret_cast<uintptr_t>(output_ptr) % ALIGN_BYTES));
can_use_smem &= !(dim_size % ILP);

if (can_use_smem) {
size_t smem_sz = dim_size * sizeof(scalar_t) + smem_reduction_sz;
cunn_SoftMaxForwardSmem<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, smem_sz, stream>>>(output_ptr, input_ptr, dim_size);
} else {
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
}
}

C10_CUDA_KERNEL_LAUNCH_CHECK();
@@ -1069,7 +1191,7 @@ TORCH_IMPL_FUNC(log_softmax_cuda_out) (
const int64_t dim,
const bool half_to_float,
const Tensor &output) {
host_softmax<LogSoftMaxForwardEpilogue,true>(input, dim, half_to_float, output);
host_softmax<LogSoftMaxForwardEpilogue, LogSoftMaxForwardEpilogue, true, false>(input, dim, half_to_float, output);
}

TORCH_IMPL_FUNC(log_softmax_backward_cuda_out) (
@@ -1093,7 +1215,11 @@ TORCH_IMPL_FUNC(softmax_cuda_out) (
const int64_t dim,
const bool half_to_float,
const Tensor &output) {
host_softmax<SoftMaxForwardEpilogue,false>(input, dim, half_to_float, output);
#ifdef USE_ROCM
host_softmax<SoftMaxForwardEpilogue, SoftMaxForwardWithMulEpilogue, false, true>(input, dim, half_to_float, output);
#else
host_softmax<SoftMaxForwardEpilogue, SoftMaxForwardWithMulEpilogue, false, false>(input, dim, half_to_float, output);
#endif
}

TORCH_IMPL_FUNC(softmax_backward_cuda_out)