Skip to content

Commit

Permalink
Revert "Enable bfloat16 for micro sdpa kernel (#2344)"
Browse files Browse the repository at this point in the history
This reverts commit f145cbe.
  • Loading branch information
h-sadia authored Jan 11, 2025
1 parent f145cbe commit 4a0e123
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 95 deletions.
78 changes: 36 additions & 42 deletions src/gpu/intel/ocl/micro_sdpa.cl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*******************************************************************************/

#include "gpu/intel/ocl/ocl_types.h"
#include "gpu/intel/ocl/sdpa_utils.h"
#include "gpu/intel/ocl/tile_ops.h"

Expand All @@ -37,31 +38,25 @@
typedef ugemm_kq_c_type s_tile_type;
typedef ugemm_vs_c_type a_tile_type;

#ifdef QRY_DT_F16
#define VEC_TYPE2 half2
#else // data type is bf16
#define VEC_TYPE2 ushort2
#endif

DECLARE_2D_TILE(q_tile_type, uint, SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)

#ifdef BLOCK_Q
DECLARE_2D_TILE_BLOCK_OPS(
q_tile_type, uint, SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)
#elif Q_ALIGN < 4
DECLARE_2D_TILE_LOAD_PACKED_VEC(q_tile_type, QRY_DATA_T, VEC_TYPE2,
SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)
DECLARE_2D_TILE_LOAD_PACKED_HALF(
q_tile_type, SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)
#endif

#ifdef BLOCK_A
DECLARE_2D_TILE(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, ugemm_vs_sg_tile_m,
1, 1, ugemm_vs_sg_tile_n)
DECLARE_2D_TILE(a_tile_type_half, half, SUBGROUP_SIZE, ugemm_vs_sg_tile_m, 1, 1,
ugemm_vs_sg_tile_n)
#else
DECLARE_2D_TILE(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, ugemm_vs_sg_tile_m,
8, 1, ugemm_vs_sg_tile_n / 8)
DECLARE_2D_TILE(a_tile_type_half, half, SUBGROUP_SIZE, ugemm_vs_sg_tile_m, 8, 1,
ugemm_vs_sg_tile_n / 8)
#endif

DECLARE_2D_TILE(s_tile_type_packed, uint, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
DECLARE_2D_TILE(s_tile_type_half2, uint, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
ugemm_kq_c_type_block1 / 2, ugemm_kq_c_type_nblock0,
ugemm_kq_c_type_nblock1)

Expand All @@ -83,34 +78,34 @@ DECLARE_2D_TILE(
#define mask_nbc ugemm_kq_c_type_nblock1
#endif

DECLARE_2D_TILE(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br, mask_bc,
mask_nbr, mask_nbc)
DECLARE_2D_TILE(mask_tile_type, half, SUBGROUP_SIZE, mask_br, mask_bc, mask_nbr,
mask_nbc)
DECLARE_2D_TILE(mask_tile_type_float, float, SUBGROUP_SIZE, mask_br, mask_bc,
mask_nbr, mask_nbc)

#if BROADCAST_MASK_Q
DECLARE_2D_TILE_BLOCK_OPS(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br,
mask_bc, mask_nbr, mask_nbc)
DECLARE_2D_TILE_BLOCK_OPS(mask_tile_type, half, SUBGROUP_SIZE, mask_br, mask_bc,
mask_nbr, mask_nbc)
#endif

#ifdef BLOCK_A
DECLARE_2D_TILE_BLOCK_OPS(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE,
DECLARE_2D_TILE_BLOCK_OPS(a_tile_type_half, half, SUBGROUP_SIZE,
ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n)
#endif
#ifdef BLOCK_2D_A
DECLARE_2D_TILE_BLOCK2D_OPS(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE,
DECLARE_2D_TILE_BLOCK2D_OPS(a_tile_type_half, half, SUBGROUP_SIZE,
ugemm_vs_sg_tile_m, 8, 1, ugemm_vs_sg_tile_n / 8)
#endif

#ifdef BLOCK_A
DECLARE_2D_TILE_COPY_REBLOCK(a_tile_type, SUBGROUP_SIZE, ugemm_vs_c_type_block0,
ugemm_vs_c_type_block1, ugemm_vs_c_type_nblock0,
ugemm_vs_c_type_nblock1, a_tile_type_dst, SUBGROUP_SIZE,
ugemm_vs_c_type_nblock1, a_tile_type_half, SUBGROUP_SIZE,
ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n)
#else
DECLARE_2D_TILE_COPY_REBLOCK(a_tile_type, SUBGROUP_SIZE, ugemm_vs_c_type_block0,
ugemm_vs_c_type_block1, ugemm_vs_c_type_nblock0,
ugemm_vs_c_type_nblock1, a_tile_type_dst, SUBGROUP_SIZE,
ugemm_vs_c_type_nblock1, a_tile_type_half, SUBGROUP_SIZE,
ugemm_vs_sg_tile_m, 8, 1, ugemm_vs_sg_tile_n / 8)
#endif

Expand Down Expand Up @@ -165,10 +160,10 @@ DECLARE_2D_TILE_RSELECT(a_scale_tile_type, SUBGROUP_SIZE, ugemm_vs_sg_tile_n, 1,
#define binary_add(x, y) ((x) + (y))

__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void
micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
const global VAL_DATA_T *V, global DST_DATA_T *A,
const global SCALE_DATA_T *scale_ptr, const global MSK_DATA_T *msk,
int d, int k, int q, const global KEY_ATTR_SCALES_DATA_T *K_scales,
micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
const global VAL_DATA_T *V, global half *A,
const global SCALE_DATA_T *scale_ptr, const global half *msk, int d,
int k, int q, const global KEY_ATTR_SCALES_DATA_T *K_scales,
const global KEY_ATTR_ZP_DATA_T *K_zp,
const global VAL_ATTR_SCALES_DATA_T *V_scales,
const global VAL_ATTR_ZP_DATA_T *V_zp) {
Expand Down Expand Up @@ -200,9 +195,8 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
uint sg_j_vs = sg_ij / ugemm_vs_sg_per_wg_m;

/* SLM allocations -- place in one array to work around compiler bug */
#define Q_slm_size (D_MAX * ugemm_kq_wg_tile_n * sizeof(QRY_DATA_T))
#define S_slm_size \
(ugemm_kq_wg_tile_m * ugemm_kq_wg_tile_n * sizeof(QRY_DATA_T))
#define Q_slm_size (D_MAX * ugemm_kq_wg_tile_n * sizeof(half))
#define S_slm_size (ugemm_kq_wg_tile_m * ugemm_kq_wg_tile_n * sizeof(half))
#define S_sum_slm_size \
(ugemm_kq_wg_tile_n * ugemm_kq_sg_per_wg_m * sizeof(float))
#define S_max_slm_size (ugemm_kq_wg_tile_n * sizeof(float))
Expand All @@ -211,8 +205,8 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
local char slm[Q_slm_size + S_slm_size + S_sum_slm_size + S_max_slm_size
+ ugemm_slm_size];

local QRY_DATA_T *Q_slm = (local QRY_DATA_T *)&slm[0];
local QRY_DATA_T *S_slm = (local QRY_DATA_T *)&slm[Q_slm_size];
local half *Q_slm = (local half *)&slm[0];
local half *S_slm = (local half *)&slm[Q_slm_size];
local float *S_sum_slm = (local float *)&slm[Q_slm_size + S_slm_size];
local float *S_max_slm
= (local float *)&slm[Q_slm_size + S_slm_size + S_sum_slm_size];
Expand Down Expand Up @@ -267,16 +261,16 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
tile_load(&Q_tile, (global uint *)Q, (d + 1) >> 1, q, ldq >> 1, 0,
wg_j0 + q0_copy);
#else
tile_load_packed_vec2(&Q_tile, Q, d, q, ldq, 0, wg_j0 + q0_copy);
tile_load_packed_half(&Q_tile, Q, d, q, ldq, 0, wg_j0 + q0_copy);
#endif

/* Load scale */
#if WITH_ATTN_SCALE
#if INVERT_SCALE
float iscale = CONVERT_FLOAT_T(*scale_ptr);
float iscale = convert_float(*scale_ptr);
float scale = native_recip(iscale);
#else
float scale = CONVERT_FLOAT_T(*scale_ptr);
float scale = convert_float(*scale_ptr);
float iscale = native_recip(scale);
#endif
#else
Expand Down Expand Up @@ -526,12 +520,12 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
tile_fill(S_sum_tile1, 0.0f);
tile_vreduce_add(S_tile, &S_sum_tile1);

/* Convert to half or bf16, VNNI format */
s_tile_type_packed S_tile_packed;
tile_copy_to_vec2(S_tile, S_tile_packed, VEC_TYPE2);
/* Convert to half, VNNI format */
s_tile_type_half2 S_tile_half2;
tile_copy_to_half2(S_tile, S_tile_half2);

/* Store to SLM, in packed format */
tile_store_t_sys_src2(S_tile_packed, (local uint *)S_slm,
tile_store_t_sys_src2(S_tile_half2, (local uint *)S_slm,
ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m / 2, sg_i0_kq / 2,
sg_j0_kq);
intel_work_group_barrier_arrive(CLK_LOCAL_MEM_FENCE);
Expand Down Expand Up @@ -709,17 +703,17 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
tile_hbroadcast_mul(&A_tile, A_scale_tile);

/* Convert to half precision and store */
a_tile_type_dst A_tile_dst;
tile_copy_reblock(A_tile, &A_tile_dst);
a_tile_type_half A_tile_half;
tile_copy_reblock(A_tile, &A_tile_half);

uint sg_i0_vs = sg_i_vs * ugemm_vs_sg_tile_m;
uint sg_j0_vs = sg_j_vs * ugemm_vs_sg_tile_n + wg_j0;

#ifdef BLOCK_2D_A
tile_store_block2d(A_tile_dst, A, d, q, lda, sg_i0_vs, sg_j0_vs);
tile_store_block2d(A_tile_half, A, d, q, lda, sg_i0_vs, sg_j0_vs);
#elif defined(BLOCK_A)
tile_store_block_rem_q(A_tile_dst, A, q, lda, sg_i0_vs, sg_j0_vs);
tile_store_block_rem_q(A_tile_half, A, q, lda, sg_i0_vs, sg_j0_vs);
#else
tile_store(A_tile_dst, A, d, q, lda, sg_i0_vs, sg_j0_vs);
tile_store(A_tile_half, A, d, q, lda, sg_i0_vs, sg_j0_vs);
#endif
}
10 changes: 1 addition & 9 deletions src/gpu/intel/ocl/micro_sdpa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,7 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
GEMMProblem problem;
problem.Ta_ext = jit::convert_dnnl_to_kernel_type(key_md()->data_type);
problem.Tb_ext = jit::convert_dnnl_to_kernel_type(qry_md()->data_type);
if (qry_md()->data_type == data_type::f16) {
problem.Ta = problem.Tb = Type::f16;
} else { // data_type is bf16
problem.Ta = problem.Tb = Type::bf16;
}
problem.Ta = problem.Tb = Type::f16;
problem.Tc = problem.Tc_ext = Type::f32;
problem.Ts = problem.Tc;

Expand Down Expand Up @@ -459,11 +455,7 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
kernel_ctx.define_int("NDIMS", ndims);

def_data_type(kernel_ctx, key_mdw.data_type(), "KEY");
def_data_type(kernel_ctx, qry_mdw.data_type(), "QRY");
def_data_type(kernel_ctx, val_mdw.data_type(), "VAL");
def_data_type(kernel_ctx, dst_mdw.data_type(), "DST");
def_data_type(kernel_ctx,
pd()->with_attn_mask() ? msk_mdw.data_type() : dnnl_f32, "MSK");

def_data_type(kernel_ctx, pd()->key_scales_dt(), "KEY_ATTR_SCALES");
def_data_type(kernel_ctx, pd()->value_scales_dt(), "VAL_ATTR_SCALES");
Expand Down
15 changes: 2 additions & 13 deletions src/gpu/intel/ocl/micro_sdpa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,15 @@ struct micro_sdpa_t : public gpu_primitive_t {
attn_mask_md()->dims[mask_k_index] == desc()->keys(),
VERBOSE_INVALID_BROADCAST, "attn_mask", mask_k_index);
}
VDISPATCH_SDPA(
(utils::everyone_is(data_type::f16, qry_md()->data_type,
dst_md()->data_type)
|| utils::everyone_is(data_type::bf16,
qry_md()->data_type, dst_md()->data_type)),
VDISPATCH_SDPA(utils::everyone_is(data_type::f16,
qry_md()->data_type, dst_md()->data_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_SDPA(
utils::one_of(key_md()->data_type, f16, u8, s8, u4, s4),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_SDPA(
utils::one_of(val_md()->data_type, f16, u8, s8, u4, s4),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_SDPA(
IMPLICATION(qry_md()->data_type == data_type::bf16,
utils::everyone_is(data_type::bf16,
val_md()->data_type, key_md()->data_type,
attn_mask_md()->data_type,
desc()->scale_dt)),
"Key and Value tensors (with scales and mask) should be "
"bf16 if Query is bf16");
VDISPATCH_SDPA(set_default_formats() == status::success,
VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_SDPA(desc()->values() == desc()->head_size(),
Expand Down
43 changes: 12 additions & 31 deletions src/gpu/intel/ocl/tile_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#define GPU_OCL_TILE_OPS_H

#include "gpu/intel/ocl/ocl_generic_vector_ops.h"
#include "gpu/intel/ocl/ocl_types.h"

float __builtin_IB_atomic_max_local_f32(__local float *, float);

Expand All @@ -31,11 +30,6 @@ __attribute__((overloadable)) half local_atomic_max(
return v;
}

__attribute__((overloadable)) ushort local_atomic_max(
local ushort *p, ushort v) { /* not implemented */
return v;
}

__attribute__((overloadable)) uint local_atomic_max(local uint *p, uint v) {
return atomic_max(p, v);
}
Expand Down Expand Up @@ -97,14 +91,6 @@ DEF_BLOCK_LOAD_STORE(half, ushort, _us, 2)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 4)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 8)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 16)

typedef ushort ushort1 __attribute__((ext_vector_type(1)));
DEF_BLOCK_LOAD_STORE1(ushort, ushort, _us)
DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 2)
DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 4)
DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 8)
DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 16)

DEF_BLOCK_LOAD_STORE1(uint, uint, )
DEF_BLOCK_LOAD_STORE(uint, uint, , 2)
DEF_BLOCK_LOAD_STORE(uint, uint, , 4)
Expand Down Expand Up @@ -152,9 +138,6 @@ DEF_BLOCK2D_LOAD_STORE(half, ushort, 8, 16, u16_m8k16v1, 16, 8)
DEF_BLOCK2D_LOAD_STORE(half, ushort, 8, 16, u16_m4k32v1, 32, 4)
DEF_BLOCK2D_LOAD_STORE(half, ushort, 16, 16, u16_m8k32v1, 32, 8)

DEF_BLOCK2D_LOAD_STORE(ushort, ushort, 8, 16, u16_m4k32v1, 32, 4)
DEF_BLOCK2D_LOAD_STORE(ushort, ushort, 16, 16, u16_m8k32v1, 32, 8)

#define tile_fill(t, v) \
do { \
_Pragma("unroll") for (int i = 0; i < sizeof(t.x) / sizeof(t.x[0]); \
Expand Down Expand Up @@ -194,15 +177,14 @@ DEF_BLOCK2D_LOAD_STORE(ushort, ushort, 16, 16, u16_m8k32v1, 32, 8)
= __builtin_convertvector(t.x[i], __typeof__(t_new.x[i])); \
} while (0)

#define tile_copy_to_vec2(t, t_new, type) \
#define tile_copy_to_half2(t, t_new) \
do { \
_Pragma("unroll") for (int i = 0; i < sizeof(t.x) / sizeof(t.x[0]); \
i++) { \
_Pragma("unroll") for (int s = 0; \
s < sizeof(t.x[0]) / sizeof(t.x[0][0]) / 2; \
s++) { \
type v = {CONVERT_DATA_T(t.x[i][2 * s]), \
CONVERT_DATA_T(t.x[i][2 * s + 1])}; \
half2 v = {t.x[i][2 * s], t.x[i][2 * s + 1]}; \
t_new.x[i][s] = as_uint(v); \
} \
} \
Expand Down Expand Up @@ -496,8 +478,8 @@ DEF_BLOCK2D_LOAD_STORE(ushort, ushort, 16, 16, u16_m8k32v1, 32, 8)
tile_type0 t0, tile_type1 *t1) { \
_Pragma("unroll") for (int j = 0; j < bc0 * nbc0; j++) { \
_Pragma("unroll") for (int i0 = 0; i0 < br0 * nbr0; i0 += sg0) { \
tile_access(*t1, i0, j, sg1, br1, bc1, nbr1) = CONVERT_DATA_T( \
tile_access(t0, i0, j, sg0, br0, bc0, nbr0)); \
tile_access(*t1, i0, j, sg1, br1, bc1, nbr1) \
= tile_access(t0, i0, j, sg0, br0, bc0, nbr0); \
} \
} \
}
Expand Down Expand Up @@ -591,28 +573,27 @@ DEF_BLOCK2D_LOAD_STORE(ushort, ushort, 16, 16, u16_m8k32v1, 32, 8)
tile_store_block2d(t, ptr, m, n, m, offset_r, offset_c); \
}

#define DECLARE_2D_TILE_LOAD_PACKED_VEC( \
tile_type, element_type, vec_type, sg, br, bc, nbr, nbc) \
__attribute__((overloadable)) void tile_load_packed_vec2(tile_type *t, \
const global element_type *ptr, int m, int n, int ld, \
int offset_r, int offset_c) { \
#define DECLARE_2D_TILE_LOAD_PACKED_HALF(tile_type, sg, br, bc, nbr, nbc) \
__attribute__((overloadable)) void tile_load_packed_half(tile_type *t, \
const global half *ptr, int m, int n, int ld, int offset_r, \
int offset_c) { \
ptr += ld * offset_c + offset_r; \
_Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \
if (offset_c + j < n) { \
_Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \
int i = 2 * (i0 + get_sub_group_local_id()); \
vec_type loaded = 0; \
half2 loaded = 0; \
if (offset_r + i < m) loaded.s0 = ptr[i]; \
if (offset_r + i + 1 < m) loaded.s1 = ptr[i + 1]; \
tile_access(*t, i0, j, sg, br, bc, nbr) = as_uint(loaded); \
} \
} \
} \
} \
__attribute__((overloadable)) void tile_load_packed_vec2(tile_type *t, \
const global element_type *ptr, int m, int n, int offset_r, \
__attribute__((overloadable)) void tile_load_packed_half(tile_type *t, \
const global half *ptr, int m, int n, int offset_r, \
int offset_c) { \
tile_load_packed_vec2(t, ptr, m, n, m, offset_r, offset_c); \
tile_load_packed_half(t, ptr, m, n, m, offset_r, offset_c); \
}

#define cooperative_prefetch_2d(ptr, r, c, ld, sg_id, n_sg, sg_size, caching) \
Expand Down

0 comments on commit 4a0e123

Please sign in to comment.