GPUAI-1250 - Flash Attention v2.04: fix the two layer_norm modules that cannot be used #52

Open
wants to merge 1 commit into base: flash_attention_for_rocm
13 changes: 7 additions & 6 deletions csrc/layer_norm/ln.h
@@ -1,8 +1,9 @@
#pragma once

#include <unordered_map>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
// #include <cuda_bf16.h>

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
@@ -21,9 +22,9 @@ struct LaunchParams{
size_t workspace_bytes;
size_t barrier_size;

cudaDeviceProp * props;
hipDeviceProp_t* props;

cudaStream_t stream;
hipStream_t stream;

Params params;

@@ -179,8 +180,8 @@ extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS;

using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;

//using bf16 = nv_bfloat16;
using bf16 = hip_bfloat16;
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
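Porting note on the hunk above: the launch-parameter struct now carries HIP handle types (hipDeviceProp_t, hipStream_t) and aliases bf16 to hip_bfloat16. A minimal, self-contained sketch of how a host caller would typically fill the two changed fields on ROCm; the small struct here only mirrors those fields and is not part of the PR.

```cpp
// Illustrative only: mirrors the two LaunchParams fields changed in ln.h and
// shows the usual way to populate them with the HIP runtime.
#include <cstdio>
#include <hip/hip_runtime.h>

struct LaunchHandles {
    hipDeviceProp_t *props;   // was cudaDeviceProp *
    hipStream_t      stream;  // was cudaStream_t
};

int main() {
    static hipDeviceProp_t props;
    int dev = 0;
    if (hipGetDevice(&dev) != hipSuccess) return 1;
    if (hipGetDeviceProperties(&props, dev) != hipSuccess) return 1;

    LaunchHandles h;
    h.props  = &props;
    h.stream = 0;  // default stream; a stream from hipStreamCreate also works
    std::printf("multiProcessorCount = %d\n", h.props->multiProcessorCount);
    return 0;
}
```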
14 files renamed without changes.
6 changes: 4 additions & 2 deletions csrc/layer_norm/ln_bwd_kernels.cuh
@@ -1,5 +1,5 @@
#pragma once

#include "hip/hip_runtime.h"
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
@@ -501,6 +501,7 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params

if( Kernel_traits::SMEM_BYTES >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
//CHECK_CUDA(hipFuncSetAttribute(kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
@@ -511,7 +512,8 @@ void launch_(LaunchParams<BwdParams> &launch_params, const bool configure_params
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
//cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}

using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
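The backward launcher above replaces cudaLaunchCooperativeKernel with hipLaunchCooperativeKernel, which keeps the same argument order: function, grid, block, kernel params, dynamic shared-memory bytes, stream. A minimal sketch of the pattern, assuming the HIP cooperative-groups headers; the kernel and sizes here are placeholders, not the PR's.

```cpp
// Illustrative cooperative launch: grid.sync() inside the kernel is what
// requires hipLaunchCooperativeKernel instead of a plain <<<...>>> launch.
#include <hip/hip_runtime.h>
#include <hip/hip_cooperative_groups.h>

__global__ void coop_kernel(int *out) {
    namespace cg = cooperative_groups;
    out[blockIdx.x * blockDim.x + threadIdx.x] = 1;
    cg::this_grid().sync();   // all CTAs must be resident, hence the cooperative launch
}

hipError_t launch(int *out, dim3 grid, dim3 block, size_t smem_bytes, hipStream_t stream) {
    void *args[] = { (void *)&out };
    // The backward path should pass its own shared-memory size (SMEM_BYTES),
    // matching the cudaFuncSetAttribute call that precedes the launch.
    return hipLaunchCooperativeKernel((void *)coop_kernel, grid, block, args,
                                      smem_bytes, stream);
}
```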
14 files renamed without changes.
23 changes: 15 additions & 8 deletions csrc/layer_norm/ln_fwd_kernels.cuh
@@ -7,8 +7,8 @@
#endif

#include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
#include <curand_kernel.h>

// #include <curand_kernel.h>
#include <hiprand/hiprand_kernel.h>
#include "ln.h"
#include "ln_utils.cuh"
#include "ln_kernel_traits.h"
@@ -72,11 +72,13 @@ void ln_fwd_kernel(FwdParams params) {
const index_t *z_subset = static_cast<index_t *>(params.z_subset);

// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
curandStatePhilox4_32_10_t state;
//curandStatePhilox4_32_10_t state;
hiprandStatePhilox4_32_10_t state;
if (Is_dropout) {
auto seeds = at::cuda::philox::unpack(params.philox_args);
const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
//curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
hiprand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
}

const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
@@ -122,7 +124,8 @@ void ln_fwd_kernel(FwdParams params) {
// the more efficient curand_uniform4.
compute_t x_ij;
if (load_x0) {
mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
//mask_t keep = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
mask_t keep = !Is_dropout ? true : hiprand_uniform(&state) <= params.dropout_keep_p;
if (Is_dropout) { dmask.data.elt[jt] = keep; }
compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val;
x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
@@ -233,8 +236,10 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst, HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
//CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
CHECK_CUDA(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
@@ -253,6 +258,7 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params

if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
// CHECK_CUDA(hipFuncSetAttribute(kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
@@ -263,7 +269,8 @@ void launch_(LaunchParams<FwdParams> &launch_params, const bool configure_params
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
//cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
});
});
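The forward kernel above maps curand's Philox generator onto hiprand one-for-one: curandStatePhilox4_32_10_t becomes hiprandStatePhilox4_32_10_t, curand_init becomes hiprand_init with the same (seed, subsequence, offset) arguments, and curand_uniform becomes hiprand_uniform. A small standalone sketch of that dropout-mask pattern; kernel and parameter names are illustrative, not taken from the PR.

```cpp
// Illustrative only: per-thread Philox init plus one uniform draw per element,
// the same shape as the dropout masking in ln_fwd_kernel.
#include <cstdint>
#include <hip/hip_runtime.h>
#include <hiprand/hiprand_kernel.h>

__global__ void dropout_mask(uint8_t *mask, int n,
                             unsigned long long seed,
                             unsigned long long offset,
                             float keep_p) {
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= n) return;

    hiprandStatePhilox4_32_10_t state;        // was curandStatePhilox4_32_10_t
    hiprand_init(seed, tid, offset, &state);  // was curand_init
    // hiprand_uniform returns a float in (0, 1], like curand_uniform.
    mask[tid] = (hiprand_uniform(&state) <= keep_p) ? 1 : 0;
}
```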
3 changes: 2 additions & 1 deletion csrc/layer_norm/ln_parallel_residual_bwd_kernels.cuh
@@ -515,7 +515,8 @@ void launch_parallel_residual_(LaunchParams<BwdParams> &launch_params, const boo
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
//cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES, stream);
}

using Kernel_traits_f = layer_norm::Kernel_traits_finalize<HIDDEN_SIZE,
25 changes: 17 additions & 8 deletions csrc/layer_norm/ln_parallel_residual_fwd_kernels.cuh
@@ -7,7 +7,8 @@
#endif

#include <ATen/cuda/detail/UnpackRaw.cuh> // For at::cuda::philox::unpack
#include <curand_kernel.h>
// #include <curand_kernel.h>
#include <hiprand/hiprand_kernel.h>

#include "ln.h"
#include "ln_utils.cuh"
@@ -69,11 +70,13 @@ void ln_parallel_residual_fwd_kernel(FwdParams params) {
compute_t *rs_ptr = static_cast<compute_t *>(params.rs);

// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/cuda/Dropout.cu
curandStatePhilox4_32_10_t state;
//curandStatePhilox4_32_10_t state;
hiprandStatePhilox4_32_10_t state;
if (Is_dropout) {
auto seeds = at::cuda::philox::unpack(params.philox_args);
const index_t tidx_global = blockIdx.x * blockDim.x + threadIdx.x;
curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
//curand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
hiprand_init(std::get<0>(seeds), tidx_global, std::get<1>(seeds), &state);
}

const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG;
@@ -124,12 +127,14 @@ void ln_parallel_residual_fwd_kernel(FwdParams params) {
// TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use
// the more efficient curand_uniform4.
compute_t x_ij;
mask_t keep0 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
//mask_t keep0 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
mask_t keep0 = !Is_dropout ? true : hiprand_uniform(&state) <= params.dropout_keep_p;
if (Is_dropout) { dmask0.data.elt[jt] = keep0; }
compute_t x0_ij = compute_t(x0.data.elt[jt]);
x0_ij = keep0 ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f;
if (has_x1) {
mask_t keep1 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
//mask_t keep1 = !Is_dropout ? true : curand_uniform(&state) <= params.dropout_keep_p;
mask_t keep1 = !Is_dropout ? true : hiprand_uniform(&state) <= params.dropout_keep_p;
if (Is_dropout) { dmask1.data.elt[jt] = keep1; }
compute_t x1_ij = compute_t(x1.data.elt[jt]);
x1_ij = keep1 ? (Is_dropout ? x1_ij * params.dropout_scale : x1_ij) : 0.0f;
@@ -243,8 +248,10 @@ void launch_parallel_residual_(LaunchParams<FwdParams> &launch_params, const boo
auto kernel = &ln_parallel_residual_fwd_kernel<Kernel_traits, IsDropoutConst, TiedNormConst, IsEvenColsConst>;
if( configure_params ) {
int ctas_per_sm;
CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
//CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
// &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
CHECK_CUDA(hipOccupancyMaxActiveBlocksPerMultiprocessor(
&ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD));
launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW;
const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA;
launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS;
@@ -263,6 +270,7 @@ void launch_parallel_residual_(LaunchParams<FwdParams> &launch_params, const boo

if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) {
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
//CHECK_CUDA(hipFuncSetAttribute(kernel, hipFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD));
}
auto stream = launch_params.stream;
auto ctas_per_col = launch_params.params.ctas_per_col;
@@ -273,7 +281,8 @@ void launch_parallel_residual_(LaunchParams<FwdParams> &launch_params, const boo
dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col);
dim3 block(Kernel_traits::THREADS_PER_CTA);
void *params_ = (void *)&launch_params.params;
cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
//cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
hipLaunchCooperativeKernel((void *)kernel, grid, block, (void **)&params_, Kernel_traits::SMEM_BYTES_FWD, stream);
}
});
});
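Both forward launchers now size the grid with hipOccupancyMaxActiveBlocksPerMultiprocessor, which mirrors the CUDA call: it reports how many CTAs of a given block size and dynamic shared-memory footprint fit per compute unit, and ctas_per_col is then derived from that and the device's multiProcessorCount. A compact sketch of that sizing logic; the kernel and helper names here are illustrative.

```cpp
// Illustrative occupancy-driven grid sizing, matching the configure_params
// branch in the launchers above.
#include <hip/hip_runtime.h>

__global__ void fwd_like_kernel(float *x) { x[threadIdx.x] += 1.0f; }

int compute_ctas_per_col(int threads_per_cta, size_t smem_bytes_fwd, int ctas_per_row) {
    int ctas_per_sm = 0;
    // Same arguments as the CUDA API: result, kernel, block size, dynamic smem bytes.
    hipOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, fwd_like_kernel,
                                                 threads_per_cta, smem_bytes_fwd);
    hipDeviceProp_t props;
    hipGetDeviceProperties(&props, 0);
    return props.multiProcessorCount * ctas_per_sm / ctas_per_row;
}
```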
78 changes: 50 additions & 28 deletions csrc/layer_norm/ln_utils.cuh
@@ -2,20 +2,24 @@

#include <cassert>

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "ln.h"

// #include <cuda_bf16.h>
// #include <cuda_fp16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_bf16.h>
//#include "ln.h"
#include "ln_hip.h"

typedef __hip_bfloat162 nv_bfloat162;
typedef __hip_bfloat16 nv_bfloat16;
////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr uint32_t THREADS_PER_WARP = 32;

////////////////////////////////////////////////////////////////////////////////////////////////////

inline void check_cuda_(cudaError_t status, const char *file, int line) {
if( status != cudaSuccess ) {
fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line);
inline void check_cuda_(/*cudaError_t*/hipError_t status, const char *file, int line) {
if( status != hipSuccess ) {
fprintf(stderr, "hip Error: %s %s %d\n", /*cudaGetErrorString*/hipGetErrorString(status), file, line);
exit(status);
}
}
@@ -122,7 +126,8 @@ struct Sum {

template<typename T>
inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){
return __shfl_xor_sync(uint32_t(-1), x, idx);
// return __shfl_xor_sync(uint32_t(-1), x, idx);  // replaced with __shfl_xor for now
// HIP's __shfl_xor takes (var, lane_mask[, width]); the CUDA mask argument is dropped.
return __shfl_xor(x, idx);
}

template<>
@@ -132,7 +137,8 @@ inline __device__ float2 warp_shuffle_xor<float2>(const float2 & x, uint32_t idx

template<typename T>
inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){
return __shfl_down_sync(uint32_t(-1), x, idx);
// return __shfl_down_sync(uint32_t(-1), x, idx);  // replaced with __shfl_down for now
// HIP's __shfl_down takes (var, delta[, width]); the CUDA mask argument is dropped.
return __shfl_down(x, idx);
}

template<>
@@ -223,9 +229,12 @@ struct TypeToVec2<half> {
};

template<>
struct TypeToVec2<nv_bfloat16> {
struct TypeToVec2<hip_bfloat16>{
using Type = nv_bfloat162;
};
// struct TypeToVec2<nv_bfloat16> {
// using Type = nv_bfloat162;
// };

////////////////////////////////////////////////////////////////////////////////////////////////////

@@ -275,21 +284,28 @@ struct Converter<float2, half2>{
}
};

//template<>
// struct Converter<float2, nv_bfloat162>{
// static inline __device__ nv_bfloat162 convert(const float2 &x) {
template<>
struct Converter<float2, nv_bfloat162>{
static inline __device__ nv_bfloat162 convert(const float2 &x) {
#if __CUDA_ARCH__ >= 800
return __float22bfloat162_rn(x);
#else
union {
nv_bfloat162 raw;
nv_bfloat16 x;
nv_bfloat16 y;
} tmp;
tmp.x = __float2bfloat16_rn(x.x);
tmp.y = __float2bfloat16_rn(x.y);
return tmp.raw;
#endif
// #if __CUDA_ARCH__ >= 800  // changed here
// return __float22bfloat162_rn(x);
// #else
// union {
// //nv_bfloat162 raw;
// //nv_bfloat16 x;
// //nv_bfloat16 y;
// nv_bfloat162 raw;
// hip_bfloat16 x;
// hip_bfloat16 y;
// } tmp;
// tmp.x = __float2bfloat16_rn(x.x);
// tmp.y = __float2bfloat16_rn(x.y);
// return tmp.raw;
// #endif
return __float22bfloat162_rn(x);
}
};

@@ -372,9 +388,13 @@ struct InterCTASync {
}

inline __device__ void spin_wait_(int *barrier, int step, int expected) {
asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
// asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step));
//*barrier = *barrier + step;
// Release-ordered add at device scope, like the original red.release.gpu.global.add.s32.
__hip_atomic_fetch_add(barrier, step, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
for( int found = -1; found != expected; ) {
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
// asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier));
found=__hip_atomic_load(barrier, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT);
//found=*barrier;
}
}

@@ -394,7 +414,7 @@ struct InterCTASync {
spin_wait_(barrier, step, expected);
}
// CTA waits for thread 0
__syncthreads();
__syncthreads(); // __syncthreads() is a built-in in HIP as well
}

int phase_counter_;
@@ -594,8 +614,10 @@ inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int nu
m2_a = m2_ab;
}
// Intra-warp broadcast (only lane 0 has valid stats).
m_a = __shfl_sync(uint32_t(-1), m_a, 0);
m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
// m_a = __shfl_sync(uint32_t(-1), m_a, 0);   // replaced with __shfl for now
// m2_a = __shfl_sync(uint32_t(-1), m2_a, 0);
// HIP's __shfl takes (var, src_lane[, width]); the CUDA mask argument is dropped.
m_a = __shfl(m_a, 0);
m2_a = __shfl(m2_a, 0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
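The ln_utils.cuh changes port three device-level primitives. The warp shuffles lose the CUDA mask argument, since HIP's __shfl / __shfl_xor / __shfl_down take the value first, and the inter-CTA barrier's PTX release-add / acquire-load pair maps onto HIP's scoped atomic builtins. A self-contained sketch of those pieces, assuming a ROCm clang that provides __hip_atomic_fetch_add and __hip_atomic_load; names here mirror the file but the code is illustrative.

```cpp
// Illustrative ports of the ln_utils.cuh device primitives touched above.
#include <cstdint>
#include <hip/hip_runtime.h>

template<typename T>
__device__ T warp_shuffle_xor(const T &x, uint32_t idx) {
    return __shfl_xor(x, idx);      // was __shfl_xor_sync(0xffffffff, x, idx)
}

template<typename T>
__device__ T warp_shuffle_down(const T &x, uint32_t idx) {
    return __shfl_down(x, idx);     // was __shfl_down_sync(0xffffffff, x, idx)
}

__device__ void spin_wait(int *barrier, int step, int expected) {
    // Release-ordered add at device ("agent") scope,
    // like the original red.release.gpu.global.add.s32.
    __hip_atomic_fetch_add(barrier, step, __ATOMIC_RELEASE, __HIP_MEMORY_SCOPE_AGENT);
    // Acquire-ordered polling load, like ld.global.acquire.gpu.b32.
    for (int found = -1; found != expected; ) {
        found = __hip_atomic_load(barrier, __ATOMIC_ACQUIRE, __HIP_MEMORY_SCOPE_AGENT);
    }
}
```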