gpu: intel: ocl: create load/store defs for bf16 & add cnv macros
h-sadia committed Jan 10, 2025
1 parent fdf93c8 commit 660916e
Showing 2 changed files with 68 additions and 38 deletions.
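Background, not part of the commit: bf16 has no native arithmetic type in these kernels, so it travels as a raw 16-bit pattern in a ushort. A bf16 value is the upper half of an fp32 bit pattern, so a minimal conversion sketch, with hypothetical helper names, looks like:

    /* Illustrative only; the kernels use the CONVERT_* macros from
     * ocl_types.h rather than these helpers. */
    float bf16_to_float(ushort b) {
        return as_float((uint)b << 16); /* bf16 bits fill the fp32 high half */
    }
    ushort float_to_bf16_rtz(float f) {
        return (ushort)(as_uint(f) >> 16); /* truncation, round toward zero */
    }

This is why the diff below swaps half/half2 for ushort/ushort2 wherever QRY_DT_BF16 is defined.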
31 changes: 18 additions & 13 deletions src/gpu/intel/ocl/micro_sdpa.cl
@@ -14,7 +14,6 @@
* limitations under the License.
*******************************************************************************/

#include "gpu/intel/ocl/ocl_types.h"
#include "gpu/intel/ocl/sdpa_utils.h"
#include "gpu/intel/ocl/tile_ops.h"

@@ -38,14 +37,20 @@
typedef ugemm_kq_c_type s_tile_type;
typedef ugemm_vs_c_type a_tile_type;

+#ifdef QRY_DT_F16
+#define VEC_TYPE half2
+#else // data type is bf16
+#define VEC_TYPE ushort2
+#endif

DECLARE_2D_TILE(q_tile_type, uint, SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)

#ifdef BLOCK_Q
DECLARE_2D_TILE_BLOCK_OPS(
q_tile_type, uint, SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)
#elif Q_ALIGN < 4
-DECLARE_2D_TILE_LOAD_PACKED_HALF(
-        q_tile_type, SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)
+DECLARE_2D_TILE_LOAD_PACKED_VEC(q_tile_type, QRY_DATA_T, VEC_TYPE,
+        SUBGROUP_SIZE, D_MAX / 2, 1, 1, q_tile_sg_n)
#endif
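For orientation (assumed expansion, not text from the diff): with QRY_DT_BF16 defined, the declaration above effectively produces

    /* void tile_load_packed_vec(q_tile_type *t, const global ushort *ptr,
     *         int m, int n, int ld, int offset_r, int offset_c); */

the same packed loader DECLARE_2D_TILE_LOAD_PACKED_HALF used to emit, but reading ushort2 bit patterns instead of half2.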

#ifdef BLOCK_A
@@ -56,7 +61,7 @@ DECLARE_2D_TILE(a_tile_type_dst, DST_DATA_T, SUBGROUP_SIZE, ugemm_vs_sg_tile_m,
8, 1, ugemm_vs_sg_tile_n / 8)
#endif

-DECLARE_2D_TILE(s_tile_type_half2, uint, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
+DECLARE_2D_TILE(s_tile_type_packed, uint, SUBGROUP_SIZE, ugemm_kq_c_type_block0,
ugemm_kq_c_type_block1 / 2, ugemm_kq_c_type_nblock0,
ugemm_kq_c_type_nblock1)
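(Observation: the S tile keeps uint elements and halves the second block dimension because each uint slot will hold two packed 16-bit values, for either source type.)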

@@ -233,7 +238,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
K_scales += KEY_OFF(b1, b0_kv, 0, 0) / KEY_GROUP_SIZE;
#endif
#if KEY_SCALES == QUANTIZE_COMMON
-    float k_scale = convert_float(*K_scales);
+    float k_scale = CONVERT_FLOAT_T(*K_scales);
#endif
#if KEY_ZERO_POINTS
K_zp += KEY_OFF(b1, b0_kv, 0, 0) / KEY_GROUP_SIZE
@@ -243,7 +248,7 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
V_scales += VAL_OFF(b1, b0_kv, 0, 0) / VAL_GROUP_SIZE;
#endif
#if VAL_SCALES == QUANTIZE_COMMON
-    float v_scale = convert_float(*V_scales);
+    float v_scale = CONVERT_FLOAT_T(*V_scales);
#endif
#if VAL_ZERO_POINTS
V_zp += VAL_OFF(b1, b0_kv, 0, 0) / VAL_GROUP_SIZE
@@ -260,16 +265,16 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
tile_load(&Q_tile, (global uint *)Q, (d + 1) >> 1, q, ldq >> 1, 0,
wg_j0 + q0_copy);
#else
-        tile_load_packed_half(&Q_tile, Q, d, q, ldq, 0, wg_j0 + q0_copy);
+        tile_load_packed_vec(&Q_tile, Q, d, q, ldq, 0, wg_j0 + q0_copy);
#endif

/* Load scale */
#if WITH_ATTN_SCALE
#if INVERT_SCALE
-    float iscale = convert_float(*scale_ptr);
+    float iscale = CONVERT_FLOAT_T(*scale_ptr);
float scale = native_recip(iscale);
#else
-    float scale = convert_float(*scale_ptr);
+    float scale = CONVERT_FLOAT_T(*scale_ptr);
float iscale = native_recip(scale);
#endif
#else
@@ -446,12 +451,12 @@ micro_sdpa(const global KEY_DATA_T *K, const global QRY_DATA_T *Q,
tile_fill(S_sum_tile1, 0.0f);
tile_vreduce_add(S_tile, &S_sum_tile1);

-        /* Convert to half, VNNI format */
-        s_tile_type_half2 S_tile_half2;
-        tile_copy_to_half2(S_tile, S_tile_half2);
+        /* Convert to half or bf16, VNNI format */
+        s_tile_type_packed S_tile_packed;
+        tile_copy_to_vec(S_tile, S_tile_packed, VEC_TYPE);

/* Store to SLM, in packed format */
-        tile_store_t_sys_src2(S_tile_half2, (local uint *)S_slm,
+        tile_store_t_sys_src2(S_tile_packed, (local uint *)S_slm,
ugemm_vs_sg_tile_n, ugemm_kq_wg_tile_m / 2, sg_i0_kq / 2,
sg_j0_kq);
intel_work_group_barrier_arrive(CLK_LOCAL_MEM_FENCE);
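Per element pair, the convert-and-pack step above amounts to the following sketch (mirroring tile_copy_to_vec in tile_ops.h; s_lo/s_hi are illustrative names for two adjacent float accumulators):

    /* Two converted 16-bit lanes (half2 for f16, ushort2 for bf16) are
     * reinterpreted as one uint, yielding the VNNI layout the VS ugemm
     * reads back from SLM. */
    VEC_TYPE v = {CONVERT_DATA_T(s_lo), CONVERT_DATA_T(s_hi)};
    uint packed = as_uint(v);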
75 changes: 50 additions & 25 deletions src/gpu/intel/ocl/tile_ops.h
@@ -18,6 +18,7 @@
#define GPU_OCL_TILE_OPS_H

#include "gpu/intel/ocl/ocl_generic_vector_ops.h"
#include "gpu/intel/ocl/ocl_types.h"

float __builtin_IB_atomic_max_local_f32(__local float *, float);

@@ -30,6 +31,11 @@ __attribute__((overloadable)) half local_atomic_max(
return v;
}

+__attribute__((overloadable)) ushort local_atomic_max(
+        local ushort *p, half v) { /* not implemented */
+    return v;
+}

__attribute__((overloadable)) uint local_atomic_max(local uint *p, uint v) {
return atomic_max(p, v);
}
@@ -53,8 +59,8 @@ __attribute__((overloadable)) int local_atomic_max(local int *p, int v) {

#define DEF_BLOCK_LOAD_STORE1(type, itype, suffix) \
__attribute__((overloadable)) \
-    type##1 block_load(const global type *p, int vlen) __attribute__( \
-            (enable_if(vlen == 1, "wrong vector length"))) { \
+    type##1 block_load(const global type *p, int vlen) \
+            __attribute__((enable_if(vlen == 1, "wrong vector length"))) { \
type##1 x; \
x[0] = as_##type( \
intel_sub_group_block_read##suffix((global void *)p)); \
@@ -68,8 +74,8 @@ __attribute__((overloadable)) int local_atomic_max(local int *p, int v) {

#define DEF_BLOCK_LOAD_STORE16(type, itype, suffix) \
__attribute__((overloadable)) \
-    type##16 block_load(const global type *p, int vlen) __attribute__( \
-            (enable_if(vlen == 16, "wrong vector length"))) { \
+    type##16 block_load(const global type *p, int vlen) \
+            __attribute__((enable_if(vlen == 16, "wrong vector length"))) { \
type##16 x; \
x.s01234567 = as_##type##8( \
intel_sub_group_block_read##suffix##8((global void *)p)); \
@@ -86,11 +92,23 @@ __attribute__((overloadable)) int local_atomic_max(local int *p, int v) {
as_##itype##8(v.s89abcdef)); \
}

+#ifdef QRY_DT_F16
DEF_BLOCK_LOAD_STORE1(half, ushort, _us)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 2)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 4)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 8)
DEF_BLOCK_LOAD_STORE(half, ushort, _us, 16)
+#endif

+#ifdef QRY_DT_BF16
+typedef ushort ushort1 __attribute__((ext_vector_type(1)));
+DEF_BLOCK_LOAD_STORE1(ushort, ushort, _us)
+DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 2)
+DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 4)
+DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 8)
+DEF_BLOCK_LOAD_STORE(ushort, ushort, _us, 16)
+#endif
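A hypothetical call site (not in the diff): vlen is a compile-time constant and enable_if picks the matching overload; under QRY_DT_BF16 these overloads move raw bf16 bit patterns without numeric conversion:

    global ushort *p = ...;       /* two bf16 values as bit patterns */
    ushort2 v = block_load(p, 2); /* resolves to the vlen == 2 overload */
    block_store(p, v);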

DEF_BLOCK_LOAD_STORE1(uint, uint, )
DEF_BLOCK_LOAD_STORE(uint, uint, , 2)
DEF_BLOCK_LOAD_STORE(uint, uint, , 4)
@@ -103,11 +121,10 @@ DEF_BLOCK_LOAD_STORE16(uint, uint, )
void __builtin_IB_subgroup_block_write_flat_##suffix( \
long, int, int, int, int2, itype##vl); \
__attribute__((overloadable)) type##vl block2d_load(const global type *p, \
-            int w, int h, int ld, int x, int y, int br, int bc, \
-            int sg) __attribute__((enable_if(br == BR, "wrong #rows"))) \
+            int w, int h, int ld, int x, int y, int br, int bc, int sg) \
+            __attribute__((enable_if(br == BR, "wrong #rows"))) \
__attribute__((enable_if(bc == BC, "wrong #columns"))) \
-            __attribute__( \
-                    (enable_if(sg == SG, "wrong subgroup size"))) { \
+            __attribute__((enable_if(sg == SG, "wrong subgroup size"))) { \
ulong pp = as_long(p); \
ulong prem = pp & 0x3F; \
pp &= ~0x3F; \
@@ -119,11 +136,10 @@ DEF_BLOCK_LOAD_STORE16(uint, uint, )
} \
__attribute__((overloadable)) void block2d_store(type##vl v, \
const global type *p, int w, int h, int ld, int x, int y, int br, \
-            int bc, \
-            int sg) __attribute__((enable_if(br == BR, "wrong #rows"))) \
+            int bc, int sg) \
+            __attribute__((enable_if(br == BR, "wrong #rows"))) \
__attribute__((enable_if(bc == BC, "wrong #columns"))) \
-            __attribute__( \
-                    (enable_if(sg == SG, "wrong subgroup size"))) { \
+            __attribute__((enable_if(sg == SG, "wrong subgroup size"))) { \
ulong pp = as_long(p); \
ulong prem = pp & 0x3F; \
pp &= ~0x3F; \
@@ -134,8 +150,15 @@ DEF_BLOCK_LOAD_STORE16(uint, uint, )
pp, w - 1, h - 1, ld - 1, coord, as_##itype##vl(v)); \
}

+#ifdef QRY_DT_F16
DEF_BLOCK2D_LOAD_STORE(half, ushort, 8, 16, u16_m4k32v1, 32, 4)
DEF_BLOCK2D_LOAD_STORE(half, ushort, 16, 16, u16_m8k32v1, 32, 8)
+#endif

+#ifdef QRY_DT_BF16
+DEF_BLOCK2D_LOAD_STORE(ushort, ushort, 8, 16, u16_m4k32v1, 32, 4)
+DEF_BLOCK2D_LOAD_STORE(ushort, ushort, 16, 16, u16_m8k32v1, 32, 8)
+#endif
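One pre-existing detail worth noting in the 2D block path (unchanged by this commit): the base pointer is rounded down to 64 bytes and the remainder folded into the x coordinate, since the 2D block messages expect an aligned base:

    /* Sketch of the address fixup visible inside block2d_load/block2d_store: */
    ulong pp = as_long(p);  /* raw address */
    ulong prem = pp & 0x3F; /* misalignment in bytes */
    pp &= ~0x3F;            /* 64-byte-aligned base for the block message */
    /* prem is then folded into the x offset, in element units. */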

#define tile_fill(t, v) \
do { \
@@ -176,14 +199,15 @@ DEF_BLOCK2D_LOAD_STORE(half, ushort, 16, 16, u16_m8k32v1, 32, 8)
= __builtin_convertvector(t.x[i], __typeof__(t_new.x[i])); \
} while (0)

-#define tile_copy_to_half2(t, t_new) \
+#define tile_copy_to_vec(t, t_new, type) \
do { \
_Pragma("unroll") for (int i = 0; i < sizeof(t.x) / sizeof(t.x[0]); \
i++) { \
_Pragma("unroll") for (int s = 0; \
s < sizeof(t.x[0]) / sizeof(t.x[0][0]) / 2; \
s++) { \
-                half2 v = {t.x[i][2 * s], t.x[i][2 * s + 1]}; \
+                type v = {CONVERT_DATA_T(t.x[i][2 * s]), \
+                        CONVERT_DATA_T(t.x[i][2 * s + 1])}; \
t_new.x[i][s] = as_uint(v); \
} \
} \
@@ -477,15 +501,15 @@
tile_type0 t0, tile_type1 *t1) { \
_Pragma("unroll") for (int j = 0; j < bc0 * nbc0; j++) { \
_Pragma("unroll") for (int i0 = 0; i0 < br0 * nbr0; i0 += sg0) { \
-            tile_access(*t1, i0, j, sg1, br1, bc1, nbr1) \
-                    = tile_access(t0, i0, j, sg0, br0, bc0, nbr0); \
+            tile_access(*t1, i0, j, sg1, br1, bc1, nbr1) = CONVERT_DATA_T( \
+                    tile_access(t0, i0, j, sg0, br0, bc0, nbr0)); \
} \
} \
}

#define DECLARE_2D_TILE(tile_type, element_type, sg, br, bc, nbr, nbc) \
typedef element_type __attribute__((ext_vector_type(br * bc / sg))) \
-            _e_##tile_type; \
+    _e_##tile_type; \
typedef struct { \
_e_##tile_type x[nbr * nbc]; \
} tile_type; \
@@ -572,27 +596,28 @@ DEF_BLOCK2D_LOAD_STORE(half, ushort, 16, 16, u16_m8k32v1, 32, 8)
tile_store_block2d(t, ptr, m, n, m, offset_r, offset_c); \
}

-#define DECLARE_2D_TILE_LOAD_PACKED_HALF(tile_type, sg, br, bc, nbr, nbc) \
-    __attribute__((overloadable)) void tile_load_packed_half(tile_type *t, \
-            const global half *ptr, int m, int n, int ld, int offset_r, \
-            int offset_c) { \
+#define DECLARE_2D_TILE_LOAD_PACKED_VEC( \
+        tile_type, element_type, vec_type, sg, br, bc, nbr, nbc) \
+    __attribute__((overloadable)) void tile_load_packed_vec(tile_type *t, \
+            const global element_type *ptr, int m, int n, int ld, \
+            int offset_r, int offset_c) { \
ptr += ld * offset_c + offset_r; \
_Pragma("unroll") for (int j = 0; j < bc * nbc; j++, ptr += ld) { \
if (offset_c + j < n) { \
_Pragma("unroll") for (int i0 = 0; i0 < br * nbr; i0 += sg) { \
int i = 2 * (i0 + get_sub_group_local_id()); \
-                half2 loaded = 0; \
+                vec_type loaded = 0; \
if (offset_r + i < m) loaded.s0 = ptr[i]; \
if (offset_r + i + 1 < m) loaded.s1 = ptr[i + 1]; \
tile_access(*t, i0, j, sg, br, bc, nbr) = as_uint(loaded); \
} \
} \
} \
} \
-    __attribute__((overloadable)) void tile_load_packed_half(tile_type *t, \
-            const global half *ptr, int m, int n, int offset_r, \
-            int offset_c) { \
+    __attribute__((overloadable)) void tile_load_packed_vec(tile_type *t, \
+            const global element_type *ptr, int m, int n, int offset_r, \
+            int offset_c) { \
-        tile_load_packed_half(t, ptr, m, n, m, offset_r, offset_c); \
+        tile_load_packed_vec(t, ptr, m, n, m, offset_r, offset_c); \

#define cooperative_prefetch_2d(ptr, r, c, ld, sg_id, n_sg, sg_size, caching) \