From c8459d199ddcea909f6ccd18ae4945cb19d3eb9e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 22 Jul 2023 16:06:35 +0200 Subject: [PATCH] Add the layer norm files. (#222) --- candle-kernels/README.md | 4 + candle-kernels/build.rs | 1 + candle-kernels/src/lib.rs | 1 + candle-kernels/src/ln.h | 274 +++++++++ candle-kernels/src/ln_fwd_256.cu | 15 + candle-kernels/src/ln_fwd_kernels.cuh | 257 +++++++++ candle-kernels/src/ln_kernel_traits.h | 172 ++++++ candle-kernels/src/ln_utils.cuh | 783 ++++++++++++++++++++++++++ candle-kernels/src/static_switch.h | 25 + 9 files changed, 1532 insertions(+) create mode 100644 candle-kernels/src/ln.h create mode 100644 candle-kernels/src/ln_fwd_256.cu create mode 100644 candle-kernels/src/ln_fwd_kernels.cuh create mode 100644 candle-kernels/src/ln_kernel_traits.h create mode 100644 candle-kernels/src/ln_utils.cuh create mode 100644 candle-kernels/src/static_switch.h diff --git a/candle-kernels/README.md b/candle-kernels/README.md index 1043f31ff6..a527dde6fc 100644 --- a/candle-kernels/README.md +++ b/candle-kernels/README.md @@ -2,3 +2,7 @@ This crate contains CUDA kernels used from candle. Some of these implementations come from the [dfdx crate](https://github.com/coreylowman/dfdx). + +The `ln*` files come from the [flash-attention +repo](https://github.com/Dao-AILab/flash-attention) and have been edited so as +to compile without including the PyTorch codebase. diff --git a/candle-kernels/build.rs b/candle-kernels/build.rs index 3c8e96a929..585184123b 100644 --- a/candle-kernels/build.rs +++ b/candle-kernels/build.rs @@ -184,6 +184,7 @@ mod cuda { let mut command = std::process::Command::new("nvcc"); command.arg(format!("--gpu-architecture=sm_{compute_cap}")) .arg("--ptx") + .arg("--expt-relaxed-constexpr") .args(["--default-stream", "per-thread"]) .args(["--output-directory", &out_dir]) // Flash attention only diff --git a/candle-kernels/src/lib.rs b/candle-kernels/src/lib.rs index b9d12b7ba3..848daee5eb 100644 --- a/candle-kernels/src/lib.rs +++ b/candle-kernels/src/lib.rs @@ -4,6 +4,7 @@ pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx")); pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx")); pub const EMBEDDINGS: &str = include_str!(concat!(env!("OUT_DIR"), "/embeddings.ptx")); pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx")); +pub const LN_FWD_256: &str = include_str!(concat!(env!("OUT_DIR"), "/ln_fwd_256.ptx")); pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx")); pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx")); pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx")); diff --git a/candle-kernels/src/ln.h b/candle-kernels/src/ln.h new file mode 100644 index 0000000000..3acf18ec9c --- /dev/null +++ b/candle-kernels/src/ln.h @@ -0,0 +1,274 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace layer_norm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LaunchParams{ + + size_t elts_per_thread; + size_t workspace_bytes; + size_t barrier_size; + + cudaDeviceProp * props; + + cudaStream_t stream; + + Params params; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ParamsBase { + ParamsBase() + : ctas_per_col(0) + , rows(0) + , cols(0) + , x(nullptr) + , mu(nullptr) + , rs(nullptr) + , gamma(nullptr) + , gamma1(nullptr) + , 
rowscale(nullptr) + , colscale(nullptr) + , dropout_keep_p(1.f) + , dropout_scale(1.f) + , is_rms_norm(false) + , workspace(nullptr) + , barrier(nullptr) + { + } + + // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x. + int ctas_per_col; + + // Input is interpreted as matrix. We normalize across columns. + int rows; + int cols; + + // Common data pointers. + void *x0; + void *x1; + void *residual; + void *x; + void *dmask; + void *dmask1; + void *mu; + void *rs; + void *gamma; + void *gamma1; + void *rowscale; + void *colscale; + void *x0_subset; + void *z_subset; + + float inverse_cols; + + float dropout_keep_p; + float dropout_scale; + float rowscale_const; + + bool is_rms_norm; + + // Multi-CTA workspace in gmem. + void *workspace; + + // Multi-CTA sync barriers in gmem. + int *barrier; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct FwdParams : public ParamsBase { + FwdParams() + : ParamsBase() + , z(nullptr) + , z1(nullptr) + , beta(nullptr) + , beta1(nullptr) + , epsilon(0.f) + { + } + + // Output of LN FWD. + void *z; + void *z1; + void *beta; + void *beta1; + float epsilon; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct BwdParams : public ParamsBase { + BwdParams() + : ParamsBase() + , dz(nullptr) + , dz1(nullptr) + , dx(nullptr) + , dbeta_part(nullptr) + , dgamma_part(nullptr) + , dbeta1_part(nullptr) + , dgamma1_part(nullptr) + , dcolscale_part(nullptr) + , dx0(nullptr) + , dx1(nullptr) + , dresidual(nullptr) + , dbeta(nullptr) + , dgamma(nullptr) + , dbeta1(nullptr) + , dgamma1(nullptr) + , dcolscale(nullptr) + { + } + + // Input: gradient wrt. LN FWD output. + void *dz; + void *dz1; + // Input: gradient wrt residual. + void *dx; + + // Workspace for Wgrad pre-reduction. + void *dbeta_part; + void *dgamma_part; + void *dbeta1_part; + void *dgamma1_part; + void *dcolscale_part; + + // Output: Dgrad. + void *dx0; + void *dx1; + void *dresidual; + // Output: Wgrad. 
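+    // (dbeta/dgamma below are the fully reduced weight gradients; the *_part
+    // buffers above hold per-CTA partial sums that a separate finalize kernel
+    // is expected to reduce into them.)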
+ void *dbeta; + void *dgamma; + void *dbeta1; + void *dgamma1; + void *dcolscale; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using FwdFunction = std::function&, const bool)>; +using BwdFunction = std::function&, const bool)>; +using FunctionKey = uint64_t; +using FwdRegistry = std::unordered_map; +using BwdRegistry = std::unordered_map; + +extern FwdRegistry FWD_FUNCS, PARALLEL_FWD_FUNCS; +extern BwdRegistry BWD_FUNCS, PARALLEL_BWD_FUNCS; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using fp32 = float; +using fp16 = half; +using bf16 = nv_bfloat16; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TypeId{}; + +template<> +struct TypeId{ + constexpr static uint32_t Value = 0; +}; + +template<> +struct TypeId{ + constexpr static uint32_t Value = 1; +}; + +template<> +struct TypeId{ + constexpr static uint32_t Value = 2; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Type2Key{ + constexpr static uint32_t Value = TypeId::Value << S; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct WeightType2Key : public Type2Key{}; + +template +struct InputType2Key : public Type2Key{}; + +template +struct ResidualType2Key : public Type2Key{}; + +template +struct OutputType2Key : public Type2Key{}; + +template +struct ComputeType2Key : public Type2Key{}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Types2Key{ + constexpr static uint32_t Value = WeightType2Key::Value | InputType2Key::Value | ResidualType2Key::Value | OutputType2Key::Value | ComputeType2Key::Value; + constexpr static inline uint64_t get(const uint64_t hidden_size){ + constexpr uint64_t type_key = Value; + return (type_key << 32) | hidden_size; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FwdRegistrar{ + FwdRegistrar(FwdFunction f){ + uint64_t key = Types2Key::get(HIDDEN_SIZE); + FWD_FUNCS.insert({ key, f }); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BwdRegistrar{ + BwdRegistrar(BwdFunction f){ + uint64_t key = Types2Key::get(HIDDEN_SIZE); + BWD_FUNCS.insert({ key, f }); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FwdParallelRegistrar{ + FwdParallelRegistrar(FwdFunction f){ + uint64_t key = Types2Key::get(HIDDEN_SIZE); + PARALLEL_FWD_FUNCS.insert({ key, f }); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BwdParallelRegistrar{ + BwdParallelRegistrar(BwdFunction f){ + uint64_t key = Types2Key::get(HIDDEN_SIZE); + PARALLEL_BWD_FUNCS.insert({ key, f }); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/candle-kernels/src/ln_fwd_256.cu b/candle-kernels/src/ln_fwd_256.cu new file mode 100644 index 0000000000..f3a541c6db --- /dev/null +++ b/candle-kernels/src/ln_fwd_256.cu @@ -0,0 +1,15 @@ +#include "ln_fwd_kernels.cuh" + +// Create forward launch function and register. 
Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, RYTPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG + +REGISTER_FWD_LAUNCHER( 256, fp32, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp16, fp32, fp32, fp32, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp32, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, bf16, bf16, fp32, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp32, bf16, bf16, bf16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, fp16, fp16, fp16, fp16, fp32, 1, 4, 1, 16); +REGISTER_FWD_LAUNCHER( 256, bf16, bf16, bf16, bf16, fp32, 1, 4, 1, 16); diff --git a/candle-kernels/src/ln_fwd_kernels.cuh b/candle-kernels/src/ln_fwd_kernels.cuh new file mode 100644 index 0000000000..faa64d05ed --- /dev/null +++ b/candle-kernels/src/ln_fwd_kernels.cuh @@ -0,0 +1,257 @@ +#pragma once + +#include + +#include "ln.h" +#include "ln_utils.cuh" +#include "ln_kernel_traits.h" +#include "static_switch.h" + +namespace layer_norm { + +template +__global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) +void ln_fwd_kernel(FwdParams params) { + + enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; + enum { WARPS_N = Ktraits::WARPS_N }; + enum { WARPS_M = Ktraits::WARPS_M }; + enum { THREADS_PER_ROW = Ktraits::THREADS_PER_ROW }; + enum { VEC_COLS_PER_LDG = Ktraits::VEC_COLS_PER_LDG }; + enum { BYTES_PER_ROW = Ktraits::BYTES_PER_ROW }; + enum { LDGS = Ktraits::LDGS }; + enum { NUM_ELTS = Ktraits::NUM_ELTS }; + enum { CTAS_PER_ROW = Ktraits::CTAS_PER_ROW }; + + using input_t = typename Ktraits::input_t; + using residual_t = typename Ktraits::residual_t; + using output_t = typename Ktraits::output_t; + using index_t = typename Ktraits::index_t; + using compute_t = typename Ktraits::compute_t; + using mask_t = typename Ktraits::mask_t; + using Ivec = typename Ktraits::Ivec; + using Rvec = typename Ktraits::Rvec; + using Ovec = typename Ktraits::Ovec; + using Wvec = typename Ktraits::Wvec; + using Cvec = typename Ktraits::Cvec; + using Mvec = typename Ktraits::Mvec; + + using Stats = typename Ktraits::Stats; + using stats_t = typename Stats::stats_t; + + const bool has_residual = params.residual != nullptr; + const bool save_x = has_residual || Is_dropout || Has_colscale || (params.rowscale != nullptr) || Has_subset || !(std::is_same::value); + + extern __shared__ char smem_[]; + + const index_t tidx = threadIdx.x; + const index_t bidn = blockIdx.x % CTAS_PER_ROW; + const index_t bidm = blockIdx.x / CTAS_PER_ROW; + const index_t lane = tidx % THREADS_PER_WARP; + const index_t warp = tidx / THREADS_PER_WARP; + const index_t warp_m = warp / WARPS_N; + const index_t warp_n = warp % WARPS_N; + + const index_t r = bidm * ROWS_PER_CTA + warp_m; + const index_t c = bidn * THREADS_PER_ROW + warp_n * THREADS_PER_WARP + lane; + + Stats stats(params, bidm, bidn, warp_m, warp_n, lane, smem_); + + compute_t *mu_ptr = static_cast(params.mu); + compute_t *rs_ptr = static_cast(params.rs); + + const input_t *rowscale = static_cast(params.rowscale); + const index_t *x0_subset = static_cast(params.x0_subset); + const index_t *z_subset = static_cast(params.z_subset); + + const index_t num_valid_ldgs = ((params.cols / Ktraits::ELTS_PER_LDG) - 1 - c + VEC_COLS_PER_LDG) / VEC_COLS_PER_LDG; + + Wvec gamma[LDGS]; + Wvec beta[LDGS]; + Wvec colscale[LDGS]; + 
index_t idx = c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + gamma[it].load_from(params.gamma, idx); + if (params.beta != nullptr) { + beta[it].load_from(params.beta, idx); + } else { + beta[it].zero_(); + } + if (Has_colscale) { colscale[it].load_from(params.colscale, idx); } + idx += VEC_COLS_PER_LDG; + } + } + + for( int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA ) { + const compute_t rowscale_val = !Has_subset ? (params.rowscale == nullptr ? 1.0f : compute_t(rowscale[row])) : params.rowscale_const; + const int row_x0 = !Has_subset ? row + 1 : x0_subset[row]; + const int row_z = !Has_subset ? row + 1 : z_subset[row]; + const bool load_x0 = !Has_subset || row_x0 > 0; + index_t idx_x = row * params.cols / Ktraits::ELTS_PER_LDG + c; + index_t idx_x0 = !Has_subset ? idx_x : (load_x0 ? (row_x0 - 1) * params.cols / Ktraits::ELTS_PER_LDG + c : 0); + compute_t xf[LDGS * NUM_ELTS]; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ivec x0; + Rvec residual; + Rvec x; + Mvec dmask; + if (load_x0) { x0.load_from(params.x0, !Has_subset ? idx_x : idx_x0); } + if (has_residual) { residual.load_from(params.residual, idx_x); } + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + // TD [2022-04-22]: We're memory bound, not compute bound, so we don't need to use + // the more efficient curand_uniform4. + compute_t x_ij; + if (load_x0) { + mask_t keep = true; + if (Is_dropout) { dmask.data.elt[jt] = keep; } + compute_t x0_ij = compute_t(x0.data.elt[jt]) * rowscale_val; + x0_ij = keep ? (Is_dropout ? x0_ij * params.dropout_scale : x0_ij) : 0.0f; + if (Has_colscale) { x0_ij *= compute_t(colscale[it].data.elt[jt]); } + x_ij = has_residual ? x0_ij + compute_t(residual.data.elt[jt]) : x0_ij; + } else { + x_ij = has_residual ? compute_t(residual.data.elt[jt]) : 0.f; + } + if (save_x) { x.data.elt[jt] = x_ij; } + xf[it * NUM_ELTS + jt] = x_ij; + } + if (save_x) { x.store_to(params.x, idx_x); } + if (Is_dropout && load_x0) { dmask.store_to(params.dmask, !Has_subset ? idx_x : idx_x0); } + idx_x += VEC_COLS_PER_LDG; + idx_x0 += VEC_COLS_PER_LDG; + } + } + + static_assert(CTAS_PER_ROW == 1, "Don't support multiple CTAs per row for now"); + const index_t num_vecs = params.cols / Ktraits::ELTS_PER_LDG; + const index_t num_full_ldgs = num_vecs / Ktraits::VEC_COLS_PER_LDG; + const index_t remaining_vecs = num_vecs % Ktraits::VEC_COLS_PER_LDG; + auto valid_elts_in_warp_fn = [num_full_ldgs, remaining_vecs] (int warp_n) -> int { + // Need to convert to int, otherwise the subtraction will wrap around. + const index_t valid_partial_vecs_in_warp = + std::min(std::max(int(remaining_vecs) - int(warp_n * THREADS_PER_WARP), int(0)), + int(THREADS_PER_WARP)); + return (num_full_ldgs * THREADS_PER_WARP + valid_partial_vecs_in_warp) * NUM_ELTS; + }; + stats_t s = stats.template compute( + xf, params.inverse_cols, valid_elts_in_warp_fn, num_valid_ldgs * NUM_ELTS + ); + + compute_t mu = layer_norm::Get<0>::of(s); + compute_t m2 = layer_norm::Get<1>::of(s); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + mu_ptr[row] = mu; + } + + compute_t rs = rsqrtf(m2 * params.inverse_cols + params.epsilon + (!params.is_rms_norm ? 0.f : mu * mu)); + + if( bidn == 0 && warp_n == 0 && lane == 0 ) { + rs_ptr[row] = rs; + } + + const bool save_z = !Has_subset || row_z > 0; + if (save_z) { + index_t idx_z = (!Has_subset ? 
row : (row_z - 1)) * params.cols / Ktraits::ELTS_PER_LDG + c; + #pragma unroll + for( int it = 0; it < LDGS; it++ ) { + if (Is_even_cols || (it < num_valid_ldgs)) { + Ovec z; + #pragma unroll + for( int jt = 0; jt < NUM_ELTS; jt++ ) { + compute_t y_ij = compute_t(rs * (xf[it * NUM_ELTS + jt] - (!params.is_rms_norm ? mu : 0.f))); + compute_t g_ij = gamma[it].data.elt[jt]; + compute_t b_ij = beta[it].data.elt[jt]; + z.data.elt[jt] = output_t(g_ij * y_ij + b_ij); + } + z.store_to(params.z, idx_z); + idx_z += VEC_COLS_PER_LDG; + } + } + } + + } +} + +} // namespace layer_norm + +using namespace layer_norm; + +template< + typename weight_t, + typename input_t, + typename residual_t, + typename output_t, + typename compute_t, + typename index_t, + int HIDDEN_SIZE, + int CTAS_PER_ROW, + int WARPS_M, + int WARPS_N, + int BYTES_PER_LDG +> +void launch_(LaunchParams &launch_params, const bool configure_params){ + + using Kernel_traits = Kernel_traits; + bool has_colscale = launch_params.params.colscale != nullptr; + bool has_subset = launch_params.params.x0_subset != nullptr; + bool is_even_cols = launch_params.params.cols == HIDDEN_SIZE; + BOOL_SWITCH(launch_params.params.dropout_keep_p < 1.f, IsDropoutConst, [&] { + BOOL_SWITCH(has_colscale, HasColscaleConst, [&] { + BOOL_SWITCH(has_subset, HasSubsetConst, [&] { + BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] { + auto kernel = &ln_fwd_kernel; + if( configure_params ) { + int ctas_per_sm; + CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); + launch_params.params.ctas_per_col = launch_params.props->multiProcessorCount * ctas_per_sm / Kernel_traits::CTAS_PER_ROW; + const size_t rows_per_loop = launch_params.params.ctas_per_col * Kernel_traits::ROWS_PER_CTA; + launch_params.elts_per_thread = (launch_params.params.rows + rows_per_loop - 1) / rows_per_loop * Kernel_traits::LDGS * Kernel_traits::NUM_ELTS; + launch_params.barrier_size = 0; + launch_params.workspace_bytes = 0; + if(Kernel_traits::CTAS_PER_ROW > 1) { + launch_params.barrier_size = 2 * launch_params.params.ctas_per_col; + launch_params.workspace_bytes = launch_params.params.ctas_per_col + * Kernel_traits::WARPS_M + * Kernel_traits::CTAS_PER_ROW + * sizeof(typename Kernel_traits::Stats::stats_t) + * 2; + } + return; + } + + if( Kernel_traits::SMEM_BYTES_FWD >= 48 * 1024 ) { + CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::SMEM_BYTES_FWD)); + } + auto stream = launch_params.stream; + auto ctas_per_col = launch_params.params.ctas_per_col; + + if( Kernel_traits::CTAS_PER_ROW == 1 ) { + kernel<<>>(launch_params.params); + } else { + dim3 grid(Kernel_traits::CTAS_PER_ROW * ctas_per_col); + dim3 block(Kernel_traits::THREADS_PER_CTA); + void *params_ = (void *)&launch_params.params; + cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, Kernel_traits::SMEM_BYTES_FWD, stream); + } + }); + }); + }); + }); +} diff --git a/candle-kernels/src/ln_kernel_traits.h b/candle-kernels/src/ln_kernel_traits.h new file mode 100644 index 0000000000..77de6bf9af --- /dev/null +++ b/candle-kernels/src/ln_kernel_traits.h @@ -0,0 +1,172 @@ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace layer_norm { +template< + uint32_t HIDDEN_SIZE_, + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + 
uint32_t THREADS_PER_CTA_ +> +struct Kernel_traits_base { + + using weight_t = weight_t_; + using input_t = input_t_; + using residual_t = residual_t_; + using output_t = output_t_; + using compute_t = compute_t_; + using index_t = index_t_; + + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { THREADS_PER_CTA = THREADS_PER_CTA_ }; + enum { THREADS_PER_WARP = 32 }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + uint32_t HIDDEN_SIZE_, + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + bool Has_colscale, + uint32_t THREADS_PER_CTA_, + uint32_t BYTES_PER_LDG_, + typename Base = Kernel_traits_base +> +struct Kernel_traits_finalize : public Base { + enum { ROWS_PER_CTA = Base::THREADS_PER_CTA / Base::THREADS_PER_WARP }; + static_assert((int) ROWS_PER_CTA <= (int) Base::THREADS_PER_WARP); + // Bytes per global load from the input. + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + // Number of elements fetched by a global load. + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(compute_t_) }; + // Bytes per global store of the weights. + enum { BYTES_PER_STG = ELTS_PER_LDG * sizeof(weight_t_) }; + static_assert(sizeof(BYTES_PER_LDG) == 4, "Conflict-free smem transpose only implemented for 4B compute type!"); + static_assert(Base::THREADS_PER_CTA == ROWS_PER_CTA * Base::THREADS_PER_WARP, "We assume one warp per row!"); + // The total number of BYTES_PER_LDG-wide words in a hidden vector. + enum { COLS = HIDDEN_SIZE_ * sizeof(compute_t_) / BYTES_PER_LDG }; + static_assert(COLS * BYTES_PER_LDG == HIDDEN_SIZE_ * sizeof(compute_t_)); + + // Shared memory size to transpose the CTA result. + enum { SMEM_BYTES_TRANSPOSE = Base::THREADS_PER_CTA * BYTES_PER_LDG }; + // Shared memory size to coalsece the CTA result. + enum { SMEM_BYTES_OUTPUT = Base::THREADS_PER_WARP * BYTES_PER_LDG }; + // Shared memory requirement per CTA. + static constexpr int NUM_FACTORS = Has_colscale ? 3 : 2; + enum { SMEM_BYTES_PER_CTA = NUM_FACTORS * SMEM_BYTES_TRANSPOSE + NUM_FACTORS * SMEM_BYTES_OUTPUT }; + + // The type of the reducer. + using Reducer = layer_norm::Reducer; + + // Condition for the whole CTA to participate in syncthreads. 
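+    // (i.e. COLS must split evenly across the warp so that every thread runs
+    // the same number of column iterations and none can skip a __syncthreads.)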
+ static_assert(COLS % Base::THREADS_PER_WARP == 0); + enum { CTAS = COLS / Base::THREADS_PER_WARP }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template< + typename weight_t_, + typename input_t_, + typename residual_t_, + typename output_t_, + typename compute_t_, + typename index_t_, + uint32_t HIDDEN_SIZE_, + uint32_t CTAS_PER_ROW_, + uint32_t WARPS_M_, + uint32_t WARPS_N_, + uint32_t BYTES_PER_LDG_ = 16, + typename Base = Kernel_traits_base< + HIDDEN_SIZE_, + weight_t_, + input_t_, + residual_t_, + output_t_, + compute_t_, + index_t_, + WARPS_M_*WARPS_N_*THREADS_PER_WARP + > +> +struct Kernel_traits : public Base { + + using input_t = typename Base::input_t; + using residual_t = typename Base::residual_t; + using weight_t = typename Base::weight_t; + using compute_t = typename Base::compute_t; + using output_t = typename Base::output_t; + using index_t = typename Base::index_t; + // using mask_t = unsigned char; + using mask_t = bool; + + enum { CTAS_PER_ROW = CTAS_PER_ROW_ }; + enum { WARPS_M = WARPS_M_ }; + enum { WARPS_N = WARPS_N_ }; + enum { COLS = HIDDEN_SIZE_ }; + enum { HIDDEN_SIZE = HIDDEN_SIZE_ }; + enum { BYTES_PER_LDG = BYTES_PER_LDG_ }; + enum { NUM_ELTS = BYTES_PER_LDG / sizeof(input_t) }; + + enum { THREADS_PER_ROW = WARPS_N * THREADS_PER_WARP }; + enum { THREADS_PER_CTA = WARPS_M * THREADS_PER_ROW }; + enum { ROWS_PER_CTA = WARPS_M }; + + enum { BYTES_PER_ROW = COLS * sizeof(input_t) }; + enum { BYTES_PER_ROW_PER_CTA = THREADS_PER_ROW * BYTES_PER_LDG }; + // Multi-row per CTA not supported for multi-CTA => no smem for WGRAD needed + enum { SMEM_BYTES_WGRAD = CTAS_PER_ROW > 1 ? 0 : ROWS_PER_CTA * COLS * sizeof(compute_t) }; + static_assert(WARPS_M == 1 || CTAS_PER_ROW == 1); + + using reduce_t = typename layer_norm::TypeToVec2::Type; + using Reducer = layer_norm::Reducer; + + enum { SMEM_BYTES_DGRAD = Reducer::SMEM_BYTES }; + enum { SMEM_BYTES = SMEM_BYTES_DGRAD + SMEM_BYTES_WGRAD }; + + using Ivec = layer_norm::Vec; + using Rvec = layer_norm::Vec; + using Ovec = layer_norm::Vec; + using Wvec = layer_norm::Vec; + using Cvec = layer_norm::Vec; + using Mvec = layer_norm::Vec; + enum { ELTS_PER_LDG = BYTES_PER_LDG / sizeof(input_t) }; + + // Assume that each thread can handle the same number of elements in the output and weights as in the input. + static_assert(sizeof(input_t) == sizeof(output_t)); + static_assert(sizeof(input_t) <= sizeof(residual_t)); + // The number of columns fetched per load from input: one per thread. + enum { VEC_COLS_PER_LDG = CTAS_PER_ROW * THREADS_PER_ROW }; + // The total number of vectorized loads/stores per hidden vector. + enum { VEC_COLS = COLS / ELTS_PER_LDG }; + // The number of loads per thread for the input. 
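+    // E.g. with the HIDDEN_SIZE=256, fp16-input, BYTES_PER_LDG=16, WARPS_N=1,
+    // CTAS_PER_ROW=1 configurations registered in ln_fwd_256.cu:
+    // ELTS_PER_LDG = 16 / 2 = 8, VEC_COLS = 256 / 8 = 32,
+    // VEC_COLS_PER_LDG = 1 * 32 = 32, hence LDGS = 1 (LDGS = 2 for fp32 input).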
+ enum { LDGS = VEC_COLS / VEC_COLS_PER_LDG }; + static_assert(LDGS * VEC_COLS_PER_LDG == VEC_COLS); + //static_assert(LDGS * BYTES_PER_ROW_PER_CTA * CTAS_PER_ROW == BYTES_PER_ROW, ""); + + using Stats = layer_norm::Stats; + enum { SMEM_BYTES_FWD = Stats::SMEM_BYTES }; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/candle-kernels/src/ln_utils.cuh b/candle-kernels/src/ln_utils.cuh new file mode 100644 index 0000000000..178d6fda89 --- /dev/null +++ b/candle-kernels/src/ln_utils.cuh @@ -0,0 +1,783 @@ +#pragma once + +#include + +#include +#include + +#include "ln.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +constexpr uint32_t THREADS_PER_WARP = 32; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void check_cuda_(cudaError_t status, const char *file, int line) { + if( status != cudaSuccess ) { + fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(status), file, line); + exit(status); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(ans) \ + { check_cuda_((ans), __FILE__, __LINE__); } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define DIVUP(x, y) (((x) + ((y)-1)) / (y)) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_( \ + launch_params, configure_params); \ + } \ + static FwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_BWD_LAUNCHER( \ + HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_(launch_params, configure_params); \ + } \ + static BwdRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_PARALLEL_FWD_LAUNCHER(HIDDEN_SIZE, WTYPE, ITYPE, RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG) \ + void ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_parallel_residual_( \ + launch_params, configure_params); \ + } \ + static FwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_parallel_residual_fwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define REGISTER_PARALLEL_BWD_LAUNCHER( \ + HIDDEN_SIZE, WTYPE, ITYPE, 
RTYPE, OTYPE, CTYPE, CTAS_PER_ROW, WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINALIZE) \ + void ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE(LaunchParams &launch_params, \ + const bool configure_params) { \ + launch_parallel_residual_(launch_params, configure_params); \ + } \ + static BwdParallelRegistrar reg_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE( \ + ln_parallel_residual_bwd_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##RTYPE##_##OTYPE##_##CTYPE) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void operator+=(float2 & a, const float2 & b){ + a.x += b.x; + a.y += b.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sum { + inline __device__ Sum(){} + inline __device__ T operator()(const T &a, const T &b){ + return a + b; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ T warp_shuffle_xor(const T & x, uint32_t idx){ + return __shfl_xor_sync(uint32_t(-1), x, idx); +} + +template<> +inline __device__ float2 warp_shuffle_xor(const float2 & x, uint32_t idx){ + return { warp_shuffle_xor(x.x, idx), warp_shuffle_xor(x.y, idx) }; +} + +template +inline __device__ T warp_shuffle_down(const T & x, uint32_t idx){ + return __shfl_down_sync(uint32_t(-1), x, idx); +} + +template<> +inline __device__ float2 warp_shuffle_down(const float2 & x, uint32_t idx){ + return { warp_shuffle_down(x.x, idx), warp_shuffle_down(x.y, idx) }; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace layer_norm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint16 { + uint4 u; + uint4 v; + uint4 s; + uint4 t; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct uint8 { + uint4 u; + uint4 v; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BytesToType {}; + +template<> +struct BytesToType<64> { + using Type = uint16; + static_assert(sizeof(Type) == 64); +}; + +template<> +struct BytesToType<32> { + using Type = uint8; + static_assert(sizeof(Type) == 32); +}; + +template<> +struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> +struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> +struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> +struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> +struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TypeToVec2 {}; + +template<> +struct TypeToVec2 { + using Type = float2; +}; + +template<> +struct TypeToVec2 { + using Type = half2; +}; + +template<> +struct TypeToVec2 { + using Type = nv_bfloat162; +}; + 
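+// (TypeToVec2 applied to the compute type is what the Stats helpers further
+// down use as their packed (mean, m2) pair, e.g. float2 for fp32 compute;
+// Get<0> / Get<1> below read the two fields back out.)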
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Get { + template + static inline __device__ R of(const T &vec); +}; + +template<> +template +inline __device__ R Get<0>::of(const T &vec) { + return vec.x; +} + +template<> +template +inline __device__ R Get<1>::of(const T &vec) { + return vec.y; +} + +template<> +template +inline __device__ R Get<2>::of(const T &vec) { + return vec.z; +} + +template<> +template +inline __device__ R Get<3>::of(const T &vec) { + return vec.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ Dst convert(const Src &from) { + return Dst(from); + } +}; + +template<> +struct Converter{ + static inline __device__ half2 convert(const float2 &x) { + return __float22half2_rn(x); + } +}; + +template<> +struct Converter{ + static inline __device__ nv_bfloat162 convert(const float2 &x) { +#if __CUDA_ARCH__ >= 800 + return __float22bfloat162_rn(x); +#else + union { + nv_bfloat162 raw; + nv_bfloat16 x; + nv_bfloat16 y; + } tmp; + tmp.x = __float2bfloat16_rn(x.x); + tmp.y = __float2bfloat16_rn(x.y); + return tmp.raw; +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Zeros{ + static inline __device__ T get() { + return T(0.f); + } +}; + +template<> +struct Zeros{ + static inline __device__ float2 get() { + return make_float2(0.f, 0.f); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Vec { + + enum { BYTES = NUM_ELT * sizeof(Elt_type) }; + + using Vec_type = typename BytesToType::Type; + + using Alias_type = union { + Vec_type vec; + Elt_type elt[NUM_ELT]; + }; + + Alias_type data; + + template + inline __device__ void to(Vec &other) { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + other.data.elt[it] = S(this->data.elt[it]); + } + } + + template + inline __device__ void assign(const Op &op) { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + this->data.elt[it] = op(it); + } + } + + inline __device__ void zero_() { + #pragma unroll + for( int it = 0; it < NUM_ELT; it++ ) { + this->data.elt[it] = Elt_type(0.f); + } + } + + inline __device__ void load_from(const void *base_ptr, const size_t idx) { + this->data.vec = static_cast(base_ptr)[idx]; + } + + inline __device__ void store_to(void *base_ptr, const size_t idx) { + static_cast(base_ptr)[idx] = this->data.vec; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct InterCTASync { + + template + inline __device__ InterCTASync(Params & params, uint32_t bidm, uint32_t bidn) + : phase_counter_(0) + , b0_(params.barrier + bidm) // The barrier for this group of CTAs. + , b1_(params.barrier + bidm + params.ctas_per_col) // The barrier for this group of CTAs. + { + // BARRIERS ARE ASSUMED TO BE INITIALIZED TO 0! + } + + inline __device__ void spin_wait_(int *barrier, int step, int expected) { + asm volatile("red.release.gpu.global.add.s32 [%0], %1;" ::"l"(barrier), "r"(step)); + for( int found = -1; found != expected; ) { + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];" : "=r"(found) : "l"(barrier)); + } + } + + inline __device__ void sync(){ + // ALL THREADS MUST ENTER! + + // We switch barrier every iteration. + int *barrier = phase_counter_ & 0x1 ? 
b1_ : b0_; + // We decrement every other iteration. + bool dec = phase_counter_ & 0x2; + int step = dec ? -1 : 1; + int expected = dec ? 0 : CTAS_PER_ROW; + // There are only 4 phases: up/down for b0/b1. + phase_counter_ = (phase_counter_ + 1) & 0x3; + + if( threadIdx.x == 0 ) { + spin_wait_(barrier, step, expected); + } + // CTA waits for thread 0 + __syncthreads(); + } + + int phase_counter_; + int * b0_; + int * b1_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer : public Reducer { + + using InterCTASync = InterCTASync; + using Base = Reducer; + using Type = typename Base::Type; + + enum { SMEM_BYTES = Base::SMEM_BYTES }; + + enum { WS_BARRIER_BYTES = 2 * sizeof(int) }; + enum { WS_DATA_BYTES = WARPS_M * CTAS_PER_ROW * sizeof(T) }; + + // size of the barriers + temporary result per CTA (multiply with CTAS_PER_ROW to get total) + enum { WORKSPACE_BYTES_PER_GROUP = Base::WORKSPACE_BYTES_PER_GROUP + WS_BARRIER_BYTES + WS_DATA_BYTES }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) + , inter_cta_(params, bidm, bidn) + , bidn_(bidn) // CTA id within the group. + , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) + , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) + { + } + + template + inline __device__ T allreduce(T data, Op &op) { + data = Base::reduce(data, op); + // We switch workspace every iteration. + T *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + // Warp leaders 0 hold the CTA-local results. + if( this->warp_n_ == 0 && this->lane_ == 0 ) { + workspace[bidn_] = data; + } + inter_cta_.sync(); + static_assert(CTAS_PER_ROW <= 32); + T total = Zeros::get(); + if(this->lane_ < CTAS_PER_ROW){ + total = workspace[this->lane_]; + } + total = Reducer::allreduce_(total, op); + + return total; + } + + InterCTASync inter_cta_; + + T *w0_; + T *w1_; + int bidn_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer { + + using Type = T; + enum { SMEM_BYTES = 0 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : warp_n_(warp_n) + , lane_(lane) + { + } + + template + static inline __device__ T allreduce_(T data, Op &op) { + #pragma unroll + for( int it = 1; it < THREADS_PER_WARP; it *= 2 ) { + data = op(data, warp_shuffle_xor(data, it)); + } + return data; + } + + template + inline __device__ T allreduce(T data, Op &op) { + return allreduce_(data, op); + } + + template + inline __device__ T reduce(T data, Op &op){ + // only lane 0 holds the result! 
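+        // Shuffle-down tree over the warp (offsets 16, 8, 4, 2, 1): after
+        // log2(32) = 5 steps, lane 0 holds the combined result of all 32 lanes.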
+ #pragma unroll + for( int it = THREADS_PER_WARP / 2; it > 0; it /= 2 ) { + data = op(data, warp_shuffle_down(data, it)); + } + return data; + } + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Reducer : public Reducer { + + using Base = Reducer; + + using Type = T; + + enum { SMEM_BYTES = Base::SMEM_BYTES + WARPS_M * WARPS_N * sizeof(T) * 2 }; + enum { WORKSPACE_BYTES_PER_GROUP = 0 }; + + enum { THREADS_PER_WARP = 32 }; + + template + inline __device__ Reducer(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : Base(params, bidm, bidn, warp_m, warp_n, lane, smem) + , use0_(true) + { + smem0_ = &static_cast(smem)[warp_m * WARPS_N]; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ T allreduce(T data, Op & op) { + T * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + data = Base::reduce(data, op); + if( this->lane_ == 0 ) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); + #pragma unroll + for( int it = 0; it < WARPS_N; it++ ) { + out = op(out, smem[it]); + } + return out; + } + + template + inline __device__ T reduce(T data, Op &op) { + T * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // only intra-CTA group leader holds the result! + data = Base::reduce(data, op); + if( this->lane_ == 0 ) { + smem[this->warp_n_] = data; + } + __syncthreads(); + T out = Zeros::get(); + if( this->warp_n_ == 0 && this->lane_ == 0 ) { + #pragma unroll + for( int it = 0; it < WARPS_N; it++ ) { + out = op(out, smem[it]); + } + } + return out; + } + + T * smem0_; + T * smem1_; + bool use0_; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void warp_chan_upd_dynamic(T &m_a, T &m2_a, int_t &n_a, int num_active){ + //Assume at least leftmost is valid and init: step = next_pow2(num_active) / 2 (might get NaN otherwise) + const int highest_bit_set = (8 * sizeof(num_active)) - __clz(num_active - 1); + + #pragma unroll + for( int step = (1 << (highest_bit_set - 1)); step > 0; step /= 2 ) { + // Exchange + int_t n_b = warp_shuffle_down(n_a, step); + T m_b = warp_shuffle_down(m_a, step); + T m2_b = warp_shuffle_down(m2_a, step); + + // Update + const int_t n_ab = n_a + n_b; // We can handle one of them being 0, not both. + const T rn_ab = 1.f / n_ab; // Might have different n per thread, otherwise this would simplify :( + const T delta = m_a - m_b; + const float m2_ab = m2_a + m2_b + delta * delta * n_a * n_b * rn_ab; + const float m_ab = (n_a * m_a + n_b * m_b) * rn_ab; + + n_a = n_ab; + m_a = m_ab; + m2_a = m2_ab; + } + // Intra-warp broadcast (only lane 0 has valid stats). + m_a = __shfl_sync(uint32_t(-1), m_a, 0); + m2_a = __shfl_sync(uint32_t(-1), m2_a, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + // This could be done generically with the Reducer. But then we would have to exchange 3 instead of 2 fields. 
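+    // (Only (mean, m2) are exchanged between CTAs; the count is implied, since
+    // each CTA of the group covers ELTS_PER_ROW_PER_CTA elements, and the
+    // partials are merged with warp_chan_upd_dynamic above.)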
+ + using InterCTASync = InterCTASync; + using BlockStats = Stats; + using stats_t = typename BlockStats::stats_t; + + enum { SMEM_BYTES = BlockStats::SMEM_BYTES }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : inter_cta_(params, bidm, bidn) + , block_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) + , bidn_(bidn) // CTA id within the group. + , w0_(static_cast(params.workspace) + (bidm * WARPS_M + warp_m) * CTAS_PER_ROW) + , w1_(w0_ + params.ctas_per_col * WARPS_M * CTAS_PER_ROW) + , warp_n_(warp_n) + , lane_(lane) + { + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T rn) { + constexpr T ELTS_PER_ROW_PER_CTA = N * WARPS_N * THREADS_PER_WARP; + // TODO rn is not really needed here.. + constexpr T block_rn = 1.f / T(ELTS_PER_ROW_PER_CTA); + stats_t block_stats = block_stats_.compute(elts, block_rn); + + stats_t *workspace = inter_cta_.phase_counter_ & 0x1 ? w1_ : w0_; + + if( warp_n_ == 0 && lane_ == 0 ) { + workspace[bidn_] = block_stats; + } + + // Wait for all CTAS_PER_ROW CTAS in the group to have written their result. + inter_cta_.sync(); + + T n = Zeros::get(); + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume CTA group size in N less than 32, such that we can finalize with a single warp. + static_assert(CTAS_PER_ROW <= 32); + + // Every warp does the final reduction locally. + if( lane_ < CTAS_PER_ROW ) { + stats_t result = workspace[lane_]; + n = ELTS_PER_ROW_PER_CTA; + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, CTAS_PER_ROW); + + return { m, m2 }; + } + + InterCTASync inter_cta_; + BlockStats block_stats_; + + stats_t *w0_; + stats_t *w1_; + int bidn_; + int warp_n_; + int lane_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + + using WarpStats = Stats; + using stats_t = typename WarpStats::stats_t; + + enum { SMEM_BYTES = WARPS_M * WARPS_N * sizeof(stats_t) * 2 }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : warp_stats_(params, bidm, bidn, warp_m, warp_n, lane, smem) + , use0_(true) + { + smem0_ = static_cast(smem) + warp_m * WARPS_N; + smem1_ = smem0_ + WARPS_M * WARPS_N; + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, + function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { + stats_t * smem = use0_ ? smem0_ : smem1_; + use0_ = !use0_; + // Compute warp local for all WARPS_N + const auto warp_n = warp_stats_.reducer_.warp_n_; + const T warp_norm_factor = 1.f / T(Is_even_cols ? N * THREADS_PER_WARP : valid_elts_in_warp_fn(warp_n)); + stats_t warp_stats = warp_stats_.template compute( + elts, warp_norm_factor, valid_elts_in_warp_fn, num_valid_elts + ); + + //Each warp warp leader stores its stats + const auto lane = warp_stats_.reducer_.lane_; + if( lane == 0 ) { + smem[warp_n] = warp_stats; + } + __syncthreads(); + + int n = 0;; + T m = Zeros::get(); + T m2 = Zeros::get(); + + // Assume that there are less than 32 warps, such that we can finalize with a single warp + static_assert(WARPS_N <= 32); + if(lane < WARPS_N){ + stats_t result = smem[lane]; + n = Is_even_cols ? 
N * THREADS_PER_WARP : valid_elts_in_warp_fn(lane); + m = layer_norm::Get<0>::of(result); + m2 = layer_norm::Get<1>::of(result); + } + + warp_chan_upd_dynamic(m, m2, n, WARPS_N); + + return { m, m2 }; + } + WarpStats warp_stats_; + stats_t * smem0_; + stats_t * smem1_; + bool use0_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Stats { + + using stats_t = typename TypeToVec2::Type; + // The simple Warp reducer. + using Reducer = Reducer; + + enum { SMEM_BYTES = 0 }; + + template + inline __device__ Stats(Params & params, uint32_t bidm, uint32_t bidn, uint32_t warp_m, uint32_t warp_n, uint32_t lane, void * smem) + : reducer_(params, bidm, bidn, warp_m, warp_n, lane, smem) + { + } + + template + inline __device__ stats_t compute(const T (&elts)[N], const T row_norm_factor, + // const int valid_elts_in_warp_ignored_, const int num_valid_elts = N) { + function_t valid_elts_in_warp_fn, const int num_valid_elts = N) { + + auto sum = Sum(); + + T m = Zeros::get(); + #pragma unroll + for( int it = 0; it < N; it++ ) { + if (Is_even_cols || (it < num_valid_elts)) { + m += elts[it]; + } + } + m = reducer_.allreduce(m, sum) * row_norm_factor; + + T m2 = Zeros::get(); + #pragma unroll + for( int it = 0; it < N; it++ ) { + if (Is_even_cols || (it < num_valid_elts)) { + T diff = (elts[it] - m); + m2 += diff * diff; + } + } + m2 = reducer_.allreduce(m2, sum); + + return {m, m2}; + } + + Reducer reducer_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layer_norm diff --git a/candle-kernels/src/static_switch.h b/candle-kernels/src/static_switch.h new file mode 100644 index 0000000000..7920ac045d --- /dev/null +++ b/candle-kernels/src/static_switch.h @@ -0,0 +1,25 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }()
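+
+// The forward launcher in ln_fwd_kernels.cuh nests several of these switches
+// to turn runtime flags into template parameters of the kernel, roughly:
+//
+//   BOOL_SWITCH(params.dropout_keep_p < 1.f, IsDropoutConst, [&] {
+//     BOOL_SWITCH(has_colscale, HasColscaleConst, [&] {
+//       BOOL_SWITCH(has_subset, HasSubsetConst, [&] {
+//         BOOL_SWITCH(is_even_cols, IsEvenColsConst, [&] {
+//           auto kernel = &ln_fwd_kernel<Kernel_traits, IsDropoutConst,
+//               HasColscaleConst, HasSubsetConst, IsEvenColsConst>;
+//           ...
+//         });
+//       });
+//     });
+//   });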