gpu: intel: ocl: pass the sum post-op zero point to the kernels

src/gpu/intel/ocl/ocl_post_ops.h
  * FMA_BLOCK and FMA_MIXED take an extra trailing argument `c`, which is
    subtracted from the source value before it is multiplied by `b` and
    added to the accumulator.
  * APPLY_PO_SUM passes CONCAT3(PO_, idx, _SUM_ZP) as that argument.

src/gpu/intel/primitive_conf.cpp
  * post_ops_with_binary_ok() no longer returns false for a sum entry
    whose zero_point is non-zero, and queries is_sum(false, false).
  * def_post_ops_cfg() defines PO_N_SUM_ZP from e.sum.zero_point for sum
    entries and as 0 for every other entry, mirroring how PO_N_SUM_SCALE
    is already defined on both branches.

Note: line breaks and indentation were reconstructed from a
whitespace-collapsed copy of this patch; use whitespace-tolerant apply
options if a context line does not match exactly. The post-image blob id
on the primitive_conf.cpp index line is carried over unchanged.

diff --git a/src/gpu/intel/ocl/ocl_post_ops.h b/src/gpu/intel/ocl/ocl_post_ops.h
index da32a3900a8..4552d4eaa53 100644
--- a/src/gpu/intel/ocl/ocl_post_ops.h
+++ b/src/gpu/intel/ocl/ocl_post_ops.h
@@ -71,24 +71,24 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
 }
 
 #define FMA_BLOCK( \
-        block_size, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b) \
+        block_size, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c) \
     unroll_for(; nof_elems >= block_size; acc_ptr += block_size, \
                a_ptr += block_size, nof_elems -= block_size) { \
         CONCAT2(acc_elem_dt, block_size) \
         a_conv = CONCAT3(convert_, acc_elem_dt, block_size)( \
                 *((CONCAT2(a_elem_dt, block_size) *)a_ptr)); \
-        *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr) = fma( \
-                a_conv, b, *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr)); \
+        *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr) = fma(a_conv - c, b, \
+                *((CONCAT2(acc_elem_dt, block_size) *)acc_ptr)); \
     }
 
-#define FMA_MIXED(acc_nof_elems, a, a_elem_dt, b, acc_ptr, acc_elem_dt) \
+#define FMA_MIXED(acc_nof_elems, a, a_elem_dt, b, acc_ptr, acc_elem_dt, c) \
     { \
         auto nof_elems = acc_nof_elems; \
         a_elem_dt *a_ptr = (a_elem_dt *)(&a); \
-        FMA_BLOCK(8, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
-        FMA_BLOCK(4, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
-        FMA_BLOCK(2, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b); \
-        if (nof_elems == 1) { *acc_ptr += (*a_ptr) * b; } \
+        FMA_BLOCK(8, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c); \
+        FMA_BLOCK(4, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c); \
+        FMA_BLOCK(2, nof_elems, acc_ptr, acc_elem_dt, a_ptr, a_elem_dt, b, c); \
+        if (nof_elems == 1) { *acc_ptr += (*a_ptr - c) * b; } \
     }
 
 #define po_dt(idx) CONCAT3(PO_, idx, _BIN_ARG_ACTUAL_DATA_T)
@@ -227,7 +227,7 @@ float fwd_Xnary(unsigned kind, unsigned algorithm, float x, float y,
 #define APPLY_PO_SUM( \
         idx, accumulator, acc_size, acc_elem_dt, sum_src, sum_elem_dt) \
     FMA_MIXED(acc_size, sum_src, sum_elem_dt, CONCAT3(PO_, idx, _SUM_SCALE), \
-            accumulator, acc_elem_dt);
+            accumulator, acc_elem_dt, CONCAT3(PO_, idx, _SUM_ZP));
 
 #define APPLY_PO_ELTWISE(idx, accumulator, nelems) \
     FWD_XNARY_GENERIC_DT(PO_ELTWISE, CONCAT3(PO_, idx, _ALG), accumulator, \
diff --git a/src/gpu/intel/primitive_conf.cpp b/src/gpu/intel/primitive_conf.cpp
index c0447d17ba5..82c9691c5c5 100644
--- a/src/gpu/intel/primitive_conf.cpp
+++ b/src/gpu/intel/primitive_conf.cpp
@@ -543,7 +543,7 @@ bool post_ops_with_binary_ok(const primitive_attr_t *attr,
     const auto &p = attr->post_ops_;
 
     auto is_eltwise = [&](int idx) { return p.entry_[idx].is_eltwise(false); };
-    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false); };
+    auto is_sum = [&](int idx) { return p.entry_[idx].is_sum(false, false); };
     auto is_binary = [&](int idx) { return p.entry_[idx].is_binary(); };
     auto is_prelu = [&](int idx) { return p.entry_[idx].is_prelu(); };
 
@@ -563,7 +563,6 @@ bool post_ops_with_binary_ok(const primitive_attr_t *attr,
             }
         }
         if (is_sum(po_idx)) {
-            if (p.entry_[po_idx].sum.zero_point != 0) return false;
             if (p.entry_[po_idx].sum.dt != dnnl_data_type_undef
                     && types::data_type_size(p.entry_[po_idx].sum.dt)
                             != types::data_type_size(dst_dt))
@@ -692,7 +691,7 @@ status_t def_post_ops_cfg(compute::kernel_ctx_t &kernel_ctx,
                     ("PO_" + std::to_string(idx) + "_ELTWISE_SCALE").c_str(),
                     1.0f);
         }
-        if (e.is_sum(false)) {
+        if (e.is_sum(false, false)) {
             kernel_ctx.define_int(
                     "PO_" + std::to_string(idx) + "_KIND", po_sum_id);
             kernel_ctx.define_int(
@@ -700,11 +699,19 @@ status_t def_post_ops_cfg(compute::kernel_ctx_t &kernel_ctx,
             kernel_ctx.define_float(
                     ("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(),
                     e.sum.scale);
+            kernel_ctx.define_int(
+                    ("PO_" + std::to_string(idx) + "_SUM_ZP").c_str(),
+                    e.sum.zero_point);
+
         } else {
             kernel_ctx.define_float(
                     ("PO_" + std::to_string(idx) + "_SUM_SCALE").c_str(), 1.0f);
+            // Keep PO_N_SUM_ZP defined for non-sum entries as well,
+            // mirroring the neutral PO_N_SUM_SCALE above.
+            kernel_ctx.define_int(
+                    ("PO_" + std::to_string(idx) + "_SUM_ZP").c_str(), 0);
         }
-        if (!(e.is_binary() || e.is_eltwise(false) || e.is_sum(false)
+        if (!(e.is_binary() || e.is_eltwise(false) || e.is_sum(false, false)
                     || e.is_prelu())) {
             // empty post op
             kernel_ctx.define_int(