Enable bfloat16 for micro sdpa kernel #2344

Merged (3 commits, Jan 11, 2025)
78 changes: 42 additions & 36 deletions src/gpu/intel/ocl/micro_sdpa.cl
@@ -14,7 +14,6 @@
* limitations under the License.
*******************************************************************************/

#include "gpu/intel/ocl/ocl_types.h"
#include "gpu/intel/ocl/sdpa_utils.h"
#include "gpu/intel/ocl/tile_ops.h"

@@ -38,25 +37,31 @@
typedef ugemm_kq_c_type s_tile_type;
typedef ugemm_vs_c_type a_tile_type;

+#ifdef QRY_DT_F16
+#define VEC_TYPE2 half2
+#else // data type is bf16
+#define VEC_TYPE2 ushort2
Comment on lines +42 to +43 (Contributor):

Suggested change:

-#else // data type is bf16
-#define VEC_TYPE2 ushort2
+#elif defined(QRY_DT_BF16)
+#define VEC_TYPE2 ushort2
+#else
+#error "Not supported data type"

+#endif
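
For context on the ushort2 fallback above: OpenCL C has no native bfloat16 type, so the kernel carries bf16 values as raw 16-bit integers and never interprets the bits in transit. bf16 is simply the upper half of an IEEE-754 f32 bit pattern, which the following host-side sketch makes concrete (the helper names are hypothetical and not part of this PR):

#include <cstdint>
#include <cstdio>
#include <cstring>

// bf16 keeps the sign bit, all 8 exponent bits, and the top 7 mantissa bits
// of an f32, i.e. the upper 16 bits of its bit pattern.
static float bf16_bits_to_float(uint16_t bits) {
    uint32_t wide = (uint32_t)bits << 16; // re-widen to the f32 layout
    float f;
    std::memcpy(&f, &wide, sizeof f);
    return f;
}

static uint16_t float_to_bf16_bits(float f) {
    uint32_t wide;
    std::memcpy(&wide, &f, sizeof wide);
    return (uint16_t)(wide >> 16); // truncation; hardware usually rounds
}

int main() {
    float x = 1.5f;
    uint16_t b = float_to_bf16_bits(x);
    std::printf("%g -> 0x%04x -> %g\n", x, (unsigned)b, bf16_bits_to_float(b));
    return 0;
}

Since 1.5f is exactly representable in bf16 (bits 0x3fc0), it survives the round trip unchanged; values needing more mantissa bits would lose precision.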

DECLARE_2D_TILE(q_tile_type, uint, SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)

#ifdef BLOCK_Q
DECLARE_2D_TILE_BLOCK_OPS(
q_tile_type, uint, SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)
#elif Q_ALIGN < 4
-DECLARE_2D_TILE_LOAD_PACKED_HALF(
-q_tile_type, SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)
+DECLARE_2D_TILE_LOAD_PACKED_VEC(q_tile_type, QRY_DATA_T, VEC_TYPE2,
+SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)
#endif

#ifdef BLOCK_A
-DECLARE_2D_TILE(a_tile_type_half, half, SUBGROUP_SIZE, ugemm_vs_sg_tile_m, 1, 1,
-ugemm_vs_sg_tile_n)
+DECLARE_2D_TILE(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, ugemm_vs_sg_tile_m,
+1, 1, ugemm_vs_sg_tile_n)
#else
-DECLARE_2D_TILE(a_tile_type_half, half, SUBGROUP_SIZE, ugemm_vs_sg_tile_m, 8, 1,
-ugemm_vs_sg_tile_n / 8)
+DECLARE_2D_TILE(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, ugemm_vs_sg_tile_m,
+8, 1, ugemm_vs_sg_tile_n / 8)
#endif

-DECLARE_2D_TILE(s_tile_type_half2, uint, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
+DECLARE_2D_TILE(s_tile_type_packed, uint, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
ugemm_kq_c_type_block1 / 2, ugemm_kq_c_type_nblock0,
ugemm_kq_c_type_nblock1)

Expand All @@ -78,34 +83,34 @@ DECLARE_2D_TILE(
#define mask_nbc ugemm_kq_c_type_nblock1
#endif

-DECLARE_2D_TILE(mask_tile_type, half, SUBGROUP_SIZE, mask_br, mask_bc, mask_nbr,
-mask_nbc)
+DECLARE_2D_TILE(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br, mask_bc,
+mask_nbr, mask_nbc)
DECLARE_2D_TILE(mask_tile_type_float, float, SUBGROUP_SIZE, mask_br, mask_bc,
mask_nbr, mask_nbc)

#if BROADCAST_MASK_Q
-DECLARE_2D_TILE_BLOCK_OPS(mask_tile_type, half, SUBGROUP_SIZE, mask_br, mask_bc,
-mask_nbr, mask_nbc)
+DECLARE_2D_TILE_BLOCK_OPS(mask_tile_type, MSK_DATA_T, SUBGROUP_SIZE, mask_br,
+mask_bc, mask_nbr, mask_nbc)
#endif

#ifdef BLOCK_A
-DECLARE_2D_TILE_BLOCK_OPS(a_tile_type_half, half, SUBGROUP_SIZE,
+DECLARE_2D_TILE_BLOCK_OPS(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE,
ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n)
#endif
#ifdef BLOCK_2D_A
-DECLARE_2D_TILE_BLOCK2D_OPS(a_tile_type_half, half, SUBGROUP_SIZE,
+DECLARE_2D_TILE_BLOCK2D_OPS(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE,
ugemm_vs_sg_tile_m, 8, 1, ugemm_vs_sg_tile_n / 8)
#endif

#ifdef BLOCK_A
DECLARE_2D_TILE_COPY_REBLOCK(a_tile_type, SUBGROUP_SIZE, ugemm_vs_c_type_block0,
ugemm_vs_c_type_block1, ugemm_vs_c_type_nblock0,
-ugemm_vs_c_type_nblock1, a_tile_type_half, SUBGROUP_SIZE,
+ugemm_vs_c_type_nblock1, a_tile_type_dst, SUBGROUP_SIZE,
ugemm_vs_sg_tile_m, 1, 1, ugemm_vs_sg_tile_n)
#else
DECLARE_2D_TILE_COPY_REBLOCK(a_tile_type, SUBGROUP_SIZE, ugemm_vs_c_type_block0,
ugemm_vs_c_type_block1, ugemm_vs_c_type_nblock0,
-ugemm_vs_c_type_nblock1, a_tile_type_half, SUBGROUP_SIZE,
+ugemm_vs_c_type_nblock1, a_tile_type_dst, SUBGROUP_SIZE,
ugemm_vs_sg_tile_m, 8, 1, ugemm_vs_sg_tile_n / 8)
#endif

@@ -160,10 +165,10 @@ DECLARE_2D_TILE_RSELECT(a_scale_tile_type, SUBGROUP_SIZE, ugemm_vs_sg_tile_n, 1,
#define binary_add(x, y) ((x) + (y))

__attribute__((intel_reqd_sub_group_size(SUBGROUP_SIZE))) kernel void
-micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
-const global VAL_DATA_T *V, global half *A,
-const global SCALE_DATA_T *scale_ptr, const global half *msk, int d,
-int k, int q, const global KEY_ATTR_SCALES_DATA_T *K_scales,
+micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
+const global VAL_DATA_T *V, global DST_DATA_T *A,
+const global SCALE_DATA_T *scale_ptr, const global MSK_DATA_T *msk,
+int d, int k, int q, const global KEY_ATTR_SCALES_DATA_T *K_scales,
const global KEY_ATTR_ZP_DATA_T *K_zp,
const global VAL_ATTR_SCALES_DATA_T *V_scales,
const global VAL_ATTR_ZP_DATA_T *V_zp) {
Expand Down Expand Up @@ -195,8 +200,9 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
uint sg_j_vs = sg_ij / ugemm_vs_sg_per_wg_m;

/* SLM allocations -- place in one array to work around compiler bug */
-#define Q_slm_size (D_MAX * ugemm_kq_wg_tile_n * sizeof(half))
-#define S_slm_size (ugemm_kq_wg_tile_m * ugemm_kq_wg_tile_n * sizeof(half))
+#define Q_slm_size (D_MAX * ugemm_kq_wg_tile_n * sizeof(QRY_DATA_T))
+#define S_slm_size \
+(ugemm_kq_wg_tile_m * ugemm_kq_wg_tile_n * sizeof(QRY_DATA_T))
#define S_sum_slm_size \
(ugemm_kq_wg_tile_n * ugemm_kq_sg_per_wg_m * sizeof(float))
#define S_max_slm_size (ugemm_kq_wg_tile_n * sizeof(float))
@@ -205,8 +211,8 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
local char slm[Q_slm_size + S_slm_size + S_sum_slm_size + S_max_slm_size
+ ugemm_slm_size];

-local half *Q_slm = (local half *)&slm[0];
-local half *S_slm = (local half *)&slm[Q_slm_size];
+local QRY_DATA_T *Q_slm = (local QRY_DATA_T *)&slm[0];
+local QRY_DATA_T *S_slm = (local QRY_DATA_T *)&slm[Q_slm_size];
local float *S_sum_slm = (local float *)&slm[Q_slm_size + S_slm_size];
local float *S_max_slm
= (local float *)&slm[Q_slm_size + S_slm_size + S_sum_slm_size];
@@ -259,16 +265,16 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
tile_load(&Q_tile, (global uint *)Q, (d + 1) >> 1, q, ldq >> 1, 0,
wg_j0 + q0_copy);
#else
-tile_load_packed_half(&Q_tile, Q, d, q, ldq, 0, wg_j0 + q0_copy);
+tile_load_packed_vec2(&Q_tile, Q, d, q, ldq, 0, wg_j0 + q0_copy);
#endif

/* Load scale */
#if WITH_ATTN_SCALE
#if INVERT_SCALE
-float iscale = convert_float(*scale_ptr);
+float iscale = CONVERT_FLOAT_T(*scale_ptr);
float scale = native_recip(iscale);
#else
-float scale = convert_float(*scale_ptr);
+float scale = CONVERT_FLOAT_T(*scale_ptr);
float iscale = native_recip(scale);
#endif
#else
@@ -445,12 +451,12 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
tile_fill(S_sum_tile1, 0.0f);
tile_vreduce_add(S_tile, &S_sum_tile1);

-/* Convert to half, VNNI format */
-s_tile_type_half2 S_tile_half2;
-tile_copy_to_half2(S_tile, S_tile_half2);
+/* Convert to half or bf16, VNNI format */
+s_tile_type_packed S_tile_packed;
+tile_copy_to_vec2(S_tile, S_tile_packed, VEC_TYPE2);

/* Store to SLM, in packed format */
-tile_store_t_sys_src2(S_tile_half2, (local uint *)S_slm,
+tile_store_t_sys_src2(S_tile_packed, (local uint *)S_slm,
ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m / 2, sg_i0_kq / 2,
sg_j0_kq);
intel_work_group_barrier_arrive(CLK_LOCAL_MEM_FENCE);
@@ -580,17 +586,17 @@ micro_sdpa(const global KEY_DATA_T *K, const global half *Q,
tile_hbroadcast_mul(&A_tile, A_scale_tile);

/* Convert to half precision and store */
-a_tile_type_half A_tile_half;
-tile_copy_reblock(A_tile, &A_tile_half);
+a_tile_type_dst A_tile_dst;
+tile_copy_reblock(A_tile, &A_tile_dst);

uint sg_i0_vs = sg_i_vs * ugemm_vs_sg_tile_m;
uint sg_j0_vs = sg_j_vs * ugemm_vs_sg_tile_n + wg_j0;

#ifdef BLOCK_2D_A
-tile_store_block2d(A_tile_half, A, d, q, lda, sg_i0_vs, sg_j0_vs);
+tile_store_block2d(A_tile_dst, A, d, q, lda, sg_i0_vs, sg_j0_vs);
#elif defined(BLOCK_A)
-tile_store_block_rem_q(A_tile_half, A, q, lda, sg_i0_vs, sg_j0_vs);
+tile_store_block_rem_q(A_tile_dst, A, q, lda, sg_i0_vs, sg_j0_vs);
#else
-tile_store(A_tile_half, A, d, q, lda, sg_i0_vs, sg_j0_vs);
+tile_store(A_tile_dst, A, d, q, lda, sg_i0_vs, sg_j0_vs);
#endif
}
10 changes: 9 additions & 1 deletion src/gpu/intel/ocl/micro_sdpa.cpp
@@ -227,7 +227,11 @@ status_t micro_sdpa_t::pd_t::init_microkernels(impl::engine_t *engine) {
GEMMProblem problem;
problem.Ta_ext = jit::convert_dnnl_to_kernel_type(key_md()->data_type);
problem.Tb_ext = jit::convert_dnnl_to_kernel_type(qry_md()->data_type);
-problem.Ta = problem.Tb = Type::f16;
+if (qry_md()->data_type == data_type::f16) {
+problem.Ta = problem.Tb = Type::f16;
+} else { // data_type is bf16
Contributor: Use an "else if" here, and an "else" for the error/unimplemented case.

Author: It's not needed, since the init function returns early if the type is not f16 or bf16.

Contributor: When the next data type starts being supported, it will be harder to find what goes wrong at this particular spot than if it were caught nicely here. This is about being nice to others in the future, not about the current state in the present.
+problem.Ta = problem.Tb = Type::bf16;
+}
problem.Tc = problem.Tc_ext = Type::f32;
problem.Ts = problem.Tc;
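
For readers following the thread above, here is a standalone sketch of the defensive dispatch the reviewer is suggesting; the enum and function are hypothetical stand-ins for oneDNN's data_type and status machinery, not the PR's actual code:

#include <cstdio>

enum class data_type { f16, bf16, f8 };
enum class status { success, unimplemented };

static status pick_gemm_types(data_type query_dt) {
    if (query_dt == data_type::f16) {
        // problem.Ta = problem.Tb = Type::f16;
        return status::success;
    } else if (query_dt == data_type::bf16) {
        // problem.Ta = problem.Tb = Type::bf16;
        return status::success;
    }
    // A future type (f8 here) fails loudly at the decision point instead of
    // silently taking the bf16 branch -- the reviewer's maintainability point.
    return status::unimplemented;
}

int main() {
    bool caught = pick_gemm_types(data_type::f8) == status::unimplemented;
    std::printf("unsupported type caught: %d\n", (int)caught);
    return 0;
}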

@@ -398,7 +402,11 @@ status_t micro_sdpa_t::init(impl::engine_t *engine) {
kernel_ctx.define_int("NDIMS", ndims);

def_data_type(kernel_ctx, key_mdw.data_type(), "KEY");
+def_data_type(kernel_ctx, qry_mdw.data_type(), "QRY");
def_data_type(kernel_ctx, val_mdw.data_type(), "VAL");
+def_data_type(kernel_ctx, dst_mdw.data_type(), "DST");
+def_data_type(kernel_ctx,
+pd()->with_attn_mask() ? msk_mdw.data_type() : dnnl_f32, "MSK");
Contributor: If there is no mask, shouldn't the data type be undefined, to catch errors earlier? Also, maybe use the normal name "MASK"; I don't see the value of shaving off one letter.

Author (h-sadia, Jan 11, 2025): The kernel is written in a way that the mask type has to be defined in any case. It's unavoidable.

Contributor: Mind putting a comment to that effect next to this line for other developers, in a separate PR?

Author (h-sadia, Jan 13, 2025): I can follow up with that, since I still have to create a PR for enabling quantization for bf16.

def_data_type(kernel_ctx, pd()->key_scales_dt(), "KEY_ATTR_SCALES");
def_data_type(kernel_ctx, pd()->value_scales_dt(), "VAL_ATTR_SCALES");
15 changes: 13 additions & 2 deletions src/gpu/intel/ocl/micro_sdpa.hpp
@@ -65,15 +65,26 @@ struct micro_sdpa_t : public gpu_primitive_t {
attn_mask_md()->dims[mask_k_index] == desc()->keys(),
VERBOSE_INVALID_BROADCAST, "attn_mask", mask_k_index);
}
-VDISPATCH_SDPA(utils::everyone_is(data_type::f16,
-qry_md()->data_type, dst_md()->data_type),
+VDISPATCH_SDPA(
+(utils::everyone_is(data_type::f16, qry_md()->data_type,
+dst_md()->data_type)
+|| utils::everyone_is(data_type::bf16,
+qry_md()->data_type, dst_md()->data_type)),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_SDPA(
utils::one_of(key_md()->data_type, f16, u8, s8, u4, s4),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_SDPA(
utils::one_of(val_md()->data_type, f16, u8, s8, u4, s4),
VERBOSE_UNSUPPORTED_DT);
+VDISPATCH_SDPA(
+IMPLICATION(qry_md()->data_type == data_type::bf16,
+utils::everyone_is(data_type::bf16,
+val_md()->data_type, key_md()->data_type,
+attn_mask_md()->data_type,
+desc()->scale_dt)),
+"Key and Value tensors (with scales and mask) should be "
+"bf16 if Query is bf16");
VDISPATCH_SDPA(set_default_formats() == status::success,
VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_SDPA(desc()->values() == desc()->head_size(),
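
The IMPLICATION check above encodes "if the query is bf16, then key, value, mask, and scale must all be bf16 as well". A minimal model of the macro's logic, assuming the usual (!cause || result) definition from oneDNN's common utilities (the function below is an illustrative stand-in, not the real macro):

#include <cassert>

// Logical implication: only false when `cause` holds and `result` fails.
static bool implication(bool cause, bool result) {
    return !cause || result;
}

int main() {
    // f16 queries (cause == false) pass regardless of the other types...
    assert(implication(false, false));
    // ...while bf16 queries require the whole bf16 ensemble.
    assert(!implication(true, false));
    assert(implication(true, true));
    return 0;
}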
43 changes: 31 additions & 12 deletions src/gpu/intel/ocl/tile_ops.h
@@ -18,6 +18,7 @@
#define GPU_OCL_TILE_OPS_H

#include "gpu/intel/ocl/ocl_generic_vector_ops.h"
#include "gpu/intel/ocl/ocl_types.h"

float __builtin_IB_atomic_max_local_f32(__local float *, float);

Expand All @@ -30,6 +31,11 @@ __attribute__((overloadable)) half local_atomic_max(
return v;
}

+__attribute__((overloadable)) ushort local_atomic_max(
+local ushort *p, ushort v) { /* not implemented */
+return v;
+}

__attribute__((overloadable)) uint local_atomic_max(local uint *p, uint v) {
return atomic_max(p, v);
}
@@ -91,6 +97,14 @@ DEF_BLOCK_LOAD_STORE(half, ushort, _us, 2)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 4)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 8)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 16)

+typedef ushort ushort1 __attribute__((ext_vector_type(1)));
+DEF_BLOCK_LOAD_STORE1(ushort, ushort, _us)
+DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 2)
+DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 4)
+DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 8)
+DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 16)

DEF_BLOCK_LOAD_STORE1(uint, uint, )
DEF_BLOCK_LOAD_STORE(uint, uint, , 2)
DEF_BLOCK_LOAD_STORE(uint, uint, , 4)
@@ -137,6 +151,9 @@ DEF_BLOCK_LOAD_STORE16(uint, uint, )
DEF_BLOCK2D_LOAD_STORE(half, ushort, 8, 16, u16_m4k32v1, 32, 4)
DEF_BLOCK2D_LOAD_STORE(half, ushort, 16, 16, u16_m8k32v1, 32, 8)

+DEF_BLOCK2D_LOAD_STORE(ushort, ushort, 8, 16, u16_m4k32v1, 32, 4)
+DEF_BLOCK2D_LOAD_STORE(ushort, ushort, 16, 16, u16_m8k32v1, 32, 8)

#define tile_fill(t, v) \
do { \
_Pragma("unroll") for (int i = 0; i < sizeof(t.x) / sizeof(t.x[0]); \
@@ -176,14 +193,15 @@ DEF_BLOCK2D_LOAD_STORE(half, ushort, 16, 16, u16_m8k32v1, 32, 8)
= __builtin_convertvector(t.x[i], __typeof__(t_new.x[i])); \
} while (0)

-#define tile_copy_to_half2(t, t_new) \
+#define tile_copy_to_vec2(t, t_new, type) \
do { \
_Pragma("unroll") for (int i = 0; i < sizeof(t.x) / sizeof(t.x[0]); \
i++) { \
_Pragma("unroll") for (int s = 0; \
s < sizeof(t.x[0]) / sizeof(t.x[0][0]) / 2; \
s++) { \
-half2 v = {t.x[i][2 * s], t.x[i][2 * s + 1]}; \
+type v = {CONVERT_DATA_T(t.x[i][2 * s]), \
+CONVERT_DATA_T(t.x[i][2 * s + 1])}; \
t_new.x[i][s] = as_uint(v); \
} \
} \
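
To make the packed store concrete: per pair of lanes, tile_copy_to_vec2 converts two adjacent tile elements and fuses their 16-bit patterns into one 32-bit word via as_uint, the layout the VNNI-format SLM store expects. A host-side model, assuming a little-endian target so that vector element 0 lands in the low half (pack2x16 is a hypothetical helper):

#include <cstdint>
#include <cstdio>

// Fuse two 16-bit bit patterns (f16 or bf16) into one 32-bit word.
static uint32_t pack2x16(uint16_t lo, uint16_t hi) {
    return (uint32_t)lo | ((uint32_t)hi << 16);
}

int main() {
    uint16_t s0 = 0x3fc0; // bf16 bits of  1.5f
    uint16_t s1 = 0xbf80; // bf16 bits of -1.0f
    std::printf("packed: 0x%08x\n", (unsigned)pack2x16(s0, s1)); // 0xbf803fc0
    return 0;
}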
@@ -477,8 +495,8 @@ DEF_BLOCK2D_LOAD_STORE(half, ushort, 16, 16, u16_m8k32v1, 32, 8)
tile_type0 t0, tile_type1 *t1) { \
_Pragma("unroll") for (int j = 0; j < bc0 * nbc0; j++) { \
_Pragma("unroll") for (int i0 = 0; i0 < br0 * nbr0; i0 += sg0) { \
-tile_access(*t1, i0, j, sg1, br1, bc1, nbr1) \
-= tile_access(t0, i0, j, sg0, br0, bc0, nbr0); \
+tile_access(*t1, i0, j, sg1, br1, bc1, nbr1) = CONVERT_DATA_T( \
+tile_access(t0, i0, j, sg0, br0, bc0, nbr0)); \
} \
} \
}
@@ -572,27 +590,28 @@ DEF_BLOCK2D_LOAD_STORE(half, ushort, 16, 16, u16_m8k32v1, 32, 8)
tile_store_block2d(t, ptr, m, n, m, offset_r, offset_c); \
}

-#define DECLARE_2D_TILE_LOAD_PACKED_HALF(tile_type, sg, br, bc, nbr, nbc) \
-__attribute__((overloadable)) void tile_load_packed_half(tile_type *t, \
-const global half *ptr, int m, int n, int ld, int offset_r, \
-int offset_c) { \
+#define DECLARE_2D_TILE_LOAD_PACKED_VEC( \
+tile_type, element_type, vec_type, sg, br, bc, nbr, nbc) \
+__attribute__((overloadable)) void tile_load_packed_vec2(tile_type *t, \
+const global element_type *ptr, int m, int n, int ld, \
+int offset_r, int offset_c) { \
ptr += ld * offset_c + offset_r; \
_Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \
if (offset_c + j < n) { \
_Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \
int i = 2 * (i0 + get_sub_group_local_id()); \
-half2 loaded = 0; \
+vec_type loaded = 0; \
if (offset_r + i < m) loaded.s0 = ptr[i]; \
if (offset_r + i + 1 < m) loaded.s1 = ptr[i + 1]; \
tile_access(*t, i0, j, sg, br, bc, nbr) = as_uint(loaded); \
} \
} \
} \
} \
-__attribute__((overloadable)) void tile_load_packed_half(tile_type *t, \
-const global half *ptr, int m, int n, int offset_r, \
+__attribute__((overloadable)) void tile_load_packed_vec2(tile_type *t, \
+const global element_type *ptr, int m, int n, int offset_r, \
int offset_c) { \
-tile_load_packed_half(t, ptr, m, n, m, offset_r, offset_c); \
+tile_load_packed_vec2(t, ptr, m, n, m, offset_r, offset_c); \
}

#define cooperative_prefetch_2d(ptr, r, c, ld, sg_id, n_sg, sg_size, caching) \
Expand Down