Skip to content

Commit

Permalink
xe: ocl: add support for zp in sum post op
Browse files (browse the repository at this point in the history)
  • Loading branch information
dyoussif committed Jan 16, 2025
1 parent b1d384c commit 921efc6
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 13 deletions.
18 changes: 9 additions & 9 deletions src/gpu/intel/ocl/ocl_post_ops.h
Original file line number Diff line number Diff line change
@@ -71,24 +71,24 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
}

#define FMA_BLOCK( \
block_size, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b) \
block_size, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c) \
unroll_for(; nof_elems >= block_size; acc_ptr += block_size, \
a_ptr += block_size, nof_elems -= block_size) { \
CONCAT2(acc_elem_dt, block_size) \
a_conv = CONCAT3(convert_, acc_elem_dt, block_size)( \
*((CONCAT2(a_elem_dt, block_size) *)a_ptr)); \
*((CONCAT2(acc_elem_dt, block_size) *)acc_ptr) = fma( \
a_conv, b, *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr)); \
*((CONCAT2(acc_elem_dt, block_size) *)acc_ptr) = fma(a_conv - c, b, \
*((CONCAT2(acc_elem_dt, block_size) *)acc_ptr)); \
}

#define FMA_MIXED(acc_nof_elems, a, a_elem_dt, b, acc_ptr, acc_elem_dt) \
#define FMA_MIXED(acc_nof_elems, a, a_elem_dt, b, acc_ptr, acc_elem_dt, c) \
{ \
auto nof_elems = acc_nof_elems; \
a_elem_dt *a_ptr = (a_elem_dt *)(&a); \
FMA_BLOCK(8, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
FMA_BLOCK(4, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
FMA_BLOCK(2, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
if (nof_elems == 1) { *acc_ptr += (*a_ptr) * b; } \
FMA_BLOCK(8, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c); \
FMA_BLOCK(4, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c); \
FMA_BLOCK(2, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c); \
if (nof_elems == 1) { *acc_ptr += (*a_ptr - c) * b; } \
}

#define po_dt(idx) CONCAT3(PO_, idx, _BIN_ARG_ACTUAL_DATA_T)
@@ -227,7 +227,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
#define APPLY_PO_SUM( \
idx, accumulator, acc_size, acc_elem_dt, sum_src, sum_elem_dt) \
FMA_MIXED(acc_size, sum_src, sum_elem_dt, CONCAT3(PO_, idx, _SUM_SCALE), \
accumulator, acc_elem_dt);
accumulator, acc_elem_dt, CONCAT3(PO_, idx, _SUM_ZP));

#define APPLY_PO_ELTWISE(idx, accumulator, nelems) \
FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \
Expand Down
11 changes: 7 additions & 4 deletions src/gpu/intel/primitive_conf.cpp
Original file line number Diff line number Diff line change
@@ -543,7 +543,7 @@ bool post_ops_with_binary_ok(const primitive_attr_t *attr,
const auto &p = attr->post_ops_;

auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(false); };
auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false, false); };
auto is_binary = [&](int idx) { return p.entry_[idx].is_binary(); };
auto is_prelu = [&](int idx) { return p.entry_[idx].is_prelu(); };

@@ -563,7 +563,6 @@ bool post_ops_with_binary_ok(const primitive_attr_t *attr,
}
}
if (is_sum(po_idx)) {
if (p.entry_[po_idx].sum.zero_point != 0) return false;
if (p.entry_[po_idx].sum.dt != dnnl_data_type_undef
&& types::data_type_size(p.entry_[po_idx].sum.dt)
!= types::data_type_size(dst_dt))
@@ -692,19 +691,23 @@ status_t def_post_ops_cfg(compute::kernel_ctx_t &kernel_ctx,
("PO_" + std::to_string(idx) + "_ELTWISE_SCALE").c_str(),
1.0f);
}
if (e.is_sum(false)) {
if (e.is_sum(false, false)) {
kernel_ctx.define_int(
"PO_" + std::to_string(idx) + "_KIND", po_sum_id);
kernel_ctx.define_int(
"PO_" + std::to_string(idx) + "_ALG", alg_kind::undef);
kernel_ctx.define_float(
("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(),
e.sum.scale);
kernel_ctx.define_int(
("PO_" + std::to_string(idx) + "_SUM_ZP").c_str(),
e.sum.zero_point);

} else {
kernel_ctx.define_float(
("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(), 1.0f);
}
if (!(e.is_binary() || e.is_eltwise(false) || e.is_sum(false)
if (!(e.is_binary() || e.is_eltwise(false) || e.is_sum(false, false)
|| e.is_prelu())) {
// empty post op
kernel_ctx.define_int(
Expand Down

0 comments on commit 921efc6

Please sign in to comment.