Commit c48c6e4

Use reference kernels for reduce op, to avoid two kernel calls and simplify operator setup.

PiperOrigin-RevId: 704921025
dsharletg authored and xnnpack-bot committed Dec 11, 2024
1 parent b8362d1 commit c48c6e4
Showing 8 changed files with 59 additions and 110 deletions.
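
Note: to make the "two kernel calls" claim in the summary concrete, below is a minimal standalone sketch of the dispatch change. The stub types and names (cvt_fn, struct ctx, epilogue) are illustrative assumptions, not XNNPACK's real signatures.

#include <stddef.h>

/* One conversion microkernel pointer; mirrors the shape of the calls in
   src/operator-run.c below. */
typedef void (*cvt_fn)(size_t bytes, const void* input, void* output,
                       const void* params);

struct ctx {
  cvt_fn cvt_ukernel;      /* now always set, possibly to a reference kernel */
  const void* cvt_params;  /* conversion params prepared once at reshape time */
};

/* Before this commit, the quantized paths ran an int32->f32 kernel and then
   an f32->qs8/qu8 kernel (two calls, with params rebuilt on every
   invocation). After it, the epilogue is one unconditional call: */
static void epilogue(const struct ctx* c, size_t bytes,
                     const void* workspace, void* output) {
  c->cvt_ukernel(bytes, workspace, output, c->cvt_params);
}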
11 changes: 3 additions & 8 deletions src/microparams-init.c
@@ -1218,30 +1218,25 @@ size_t xnn_init_qu8_reduce_scalar_params(

size_t xnn_update_f32_reduce_scalar_params(
struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)],
- float scale,
- int32_t num_elements)
+ float scale)
{
params->f32.scale = scale;
return sizeof(params->f32);
}

size_t xnn_update_qs8_reduce_scalar_params(
struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)],
- float scale,
- int32_t num_elements)
+ float scale)
{
params->qs8.scale = params->qs8.input_output_scale * scale;
- params->qs8.num_elements = num_elements;
return sizeof(params->qs8);
}

size_t xnn_update_qu8_reduce_scalar_params(
struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)],
- float scale,
- int32_t num_elements)
+ float scale)
{
params->qu8.scale = params->qu8.input_output_scale * scale;
- params->qu8.num_elements = num_elements;
return sizeof(params->qu8);
}
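
Note: a minimal usage sketch of the slimmed-down update hook follows, assuming the internal headers are on the include path (they are not part of XNNPACK's public API). The helper update_params_for_mean is hypothetical, not a function from this commit.

#include "xnnpack/microparams.h"
#include "xnnpack/microparams-init.h"

/* For a mean reduction the scale is 1/n. The num_elements argument is gone:
   the zero-point term it fed is now folded into the conversion params in
   src/operators/reduce-nd.c instead. */
static void update_params_for_mean(struct xnn_reduce_params* params,
                                   size_t num_reduction_elements) {
  const float scale = 1.0f / (float) num_reduction_elements;
  xnn_update_qs8_reduce_scalar_params(params, scale);
}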

48 changes: 2 additions & 46 deletions src/operator-run.c
@@ -2091,29 +2091,7 @@ void xnn_compute_contiguous_reduce(
void* workspace_ptr = (void*) ((uintptr_t) context->workspace + workspace_offset);
output_ptr = (void*) ((uintptr_t) context->output + output_offset);

- if (context->s32_f32_cvt_ukernel) {
- struct xnn_s32_f32_cvt_params s32_f32_cvt_params;
- s32_f32_cvt_params.scalar.zero_point = context->params.qs8.num_elements * (int32_t) context->params.qs8.input_zero_point;
- context->s32_f32_cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr,
- workspace_ptr, (union xnn_unary_uparams*) &s32_f32_cvt_params);
- struct xnn_f32_qs8_cvt_params cvt_params;
- cvt_params.scalar.scale = context->params.qs8.scale;
- cvt_params.scalar.output_zero_point = context->params.qs8.output_zero_point;
- context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr,
- output_ptr, (union xnn_unary_uparams*) &cvt_params);
- } else if (context->u32_f32_cvt_ukernel) {
- struct xnn_s32_f32_cvt_params s32_f32_cvt_params;
- s32_f32_cvt_params.scalar.zero_point = context->params.qu8.num_elements * (int32_t) context->params.qu8.input_zero_point;
- context->u32_f32_cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr,
- workspace_ptr, (union xnn_unary_uparams*) &s32_f32_cvt_params);
- struct xnn_f32_qu8_cvt_params cvt_params;
- cvt_params.scalar.scale = context->params.qu8.scale;
- cvt_params.scalar.output_zero_point = context->params.qu8.output_zero_point;
- context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr,
- output_ptr, (union xnn_unary_uparams*) &cvt_params);
- } else {
- context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, output_ptr, /*params=*/NULL);
- }
+ context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, output_ptr, &context->cvt_params);
}
}

@@ -2174,29 +2152,7 @@ void xnn_compute_discontiguous_reduce(
void* workspace_ptr = (void*) ((uintptr_t) context->workspace + workspace_offset);
output_ptr = (void*) ((uintptr_t) context->output + output_offset);

- if (context->s32_f32_cvt_ukernel) {
- struct xnn_s32_f32_cvt_params s32_f32_cvt_params;
- s32_f32_cvt_params.scalar.zero_point = context->params.qs8.num_elements * (int32_t) context->params.qs8.input_zero_point;
- context->s32_f32_cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr,
- workspace_ptr, (union xnn_unary_uparams*) &s32_f32_cvt_params);
- struct xnn_f32_qs8_cvt_params cvt_params;
- cvt_params.scalar.scale = context->params.qs8.scale;
- cvt_params.scalar.output_zero_point = context->params.qs8.output_zero_point;
- context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr,
- output_ptr, (union xnn_unary_uparams*) &cvt_params);
- } else if (context->u32_f32_cvt_ukernel) {
- struct xnn_s32_f32_cvt_params s32_f32_cvt_params;
- s32_f32_cvt_params.scalar.zero_point = context->params.qu8.num_elements * (int32_t) context->params.qu8.input_zero_point;
- context->u32_f32_cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr,
- workspace_ptr, (union xnn_unary_uparams*) &s32_f32_cvt_params);
- struct xnn_f32_qu8_cvt_params cvt_params;
- cvt_params.scalar.scale = context->params.qu8.scale;
- cvt_params.scalar.output_zero_point = context->params.qu8.output_zero_point;
- context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr,
- output_ptr, (union xnn_unary_uparams*) &cvt_params);
- } else {
- context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, output_ptr, /*params=*/NULL);
- }
+ context->cvt_ukernel(context->accumulation_element_size * output2_block_size, workspace_ptr, output_ptr, &context->cvt_params);
}
}

96 changes: 50 additions & 46 deletions src/operators/reduce-nd.c
@@ -5,6 +5,7 @@

#include <assert.h>
#include <inttypes.h>
+ #include <math.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
@@ -16,6 +17,7 @@
#include "xnnpack/compute.h"
#include "xnnpack/config-types.h"
#include "xnnpack/config.h"
#include "xnnpack/reference-config.h"
#include "xnnpack/datatype.h"
#include "xnnpack/log.h"
#include "xnnpack/microkernel-type.h"
@@ -34,10 +36,10 @@ static enum xnn_status create_reduce_nd(
const struct xnn_reduce_config* rdsum_config,
const struct xnn_reduce_config* rsum_config,
const struct xnn_unary_elementwise_config* cvt_config,
- const struct xnn_unary_elementwise_config* s32_f32_cvt_config,
- const struct xnn_unary_elementwise_config* u32_f32_cvt_config,
const void* params,
size_t params_size,
+ const void* cvt_params,
+ size_t cvt_params_size,
xnn_operator_t* reduce_op_out)
{
xnn_operator_t reduce_op = NULL;
@@ -64,13 +66,14 @@
reduce_op->rdsum_config = rdsum_config;
reduce_op->rsum_config = rsum_config;
reduce_op->cvt_config = cvt_config;
- reduce_op->s32_f32_cvt_config = s32_f32_cvt_config;
- reduce_op->u32_f32_cvt_config = u32_f32_cvt_config;
reduce_op->reduce.log2_data_element_size = log2_data_element_size;
reduce_op->reduce.log2_accumulator_element_size = log2_accumulator_element_size;
if (params_size != 0) {
memcpy(&reduce_op->params, params, params_size);
}
+ if (cvt_params_size != 0) {
+ memcpy(&reduce_op->params2, cvt_params, cvt_params_size);
+ }

reduce_op->state = xnn_run_state_invalid;

@@ -195,20 +198,22 @@ static enum xnn_status reshape_reduce_nd(
if (workspace_alignment != NULL) {
*workspace_alignment = XNN_ALLOCATION_ALIGNMENT;
}

+ size_t num_reduction_elements;
if (normalized_reduction_axes[num_reduction_axes - 1] == num_input_dims - 1) {
if (workspace_size != NULL) {
const size_t num_output_elements = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4];
*workspace_size = (num_output_elements << log2_accumulator_element_size) + XNN_EXTRA_BYTES;
}
- const size_t scale_dim = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5];
+ num_reduction_elements = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5];
const size_t axis_dim = normalized_input_shape[5];

if (reduce_op->rsum_config->update != NULL) {
float scale = 1.0f;
if (reduce_op->type == xnn_operator_type_mean_nd) {
- scale = 1.0f / scale_dim;
+ scale = 1.0f / num_reduction_elements;
}
- reduce_op->rsum_config->update(&reduce_op->params.reduce, scale, scale_dim);
+ reduce_op->rsum_config->update(&reduce_op->params.reduce, scale);
}

reduce_op->context.reduce = (struct reduce_context) {
Expand All @@ -217,7 +222,6 @@ static enum xnn_status reshape_reduce_nd(
.accumulation_element_size = UINT32_C(1) << log2_accumulator_element_size,
.output_element_size = UINT32_C(1) << log2_data_element_size,
};
- memcpy(&reduce_op->context.reduce.params, &reduce_op->params.reduce, sizeof(reduce_op->params.reduce));

reduce_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_contiguous_reduce;
reduce_op->compute[0].range[0] = normalized_input_shape[0];
@@ -229,29 +233,22 @@
for (int i = XNN_MAX_TENSOR_DIMS / 2 - 2; i >= 0; --i) {
reduce_op->context.reduce.output_stride[i] = (reduce_op->context.reduce.output_stride[i + 1] * normalized_input_shape[(i + 1) * 2]);
}

- if (reduce_op->s32_f32_cvt_config) {
- reduce_op->context.reduce.s32_f32_cvt_ukernel = reduce_op->s32_f32_cvt_config->ukernel;
- }
- if (reduce_op->u32_f32_cvt_config) {
- reduce_op->context.reduce.u32_f32_cvt_ukernel = reduce_op->u32_f32_cvt_config->ukernel;
- }
} else {
// Reduction along the non-innermost dimension
const size_t channel_like_dim = normalized_input_shape[XNN_MAX_TENSOR_DIMS - 1];
if (workspace_size != NULL) {
const size_t num_output_elements = normalized_input_shape[1] * normalized_input_shape[3] * normalized_input_shape[5];
*workspace_size = (num_output_elements << log2_accumulator_element_size) + XNN_EXTRA_BYTES;
}
- const size_t scale_dim = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4];
+ num_reduction_elements = normalized_input_shape[0] * normalized_input_shape[2] * normalized_input_shape[4];
const size_t axis_dim = normalized_input_shape[4];

if (reduce_op->rdsum_config->update != NULL) {
float scale = 1.0f;
if (reduce_op->type == xnn_operator_type_mean_nd) {
- scale = 1.0f / scale_dim;
+ scale = 1.0f / num_reduction_elements;
}
- reduce_op->rdsum_config->update(&reduce_op->params.reduce, scale, scale_dim);
+ reduce_op->rdsum_config->update(&reduce_op->params.reduce, scale);
}
if (reduce_op->channels != channel_like_dim) {
const size_t zero_size = (channel_like_dim << log2_data_element_size) + XNN_EXTRA_BYTES;
@@ -274,7 +271,6 @@
.accumulation_element_size = UINT32_C(1) << log2_accumulator_element_size,
.output_element_size = UINT32_C(1) << log2_data_element_size,
};
- memcpy(&reduce_op->context.reduce.params, &reduce_op->params.reduce, sizeof(reduce_op->params.reduce));
reduce_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_discontiguous_reduce;
reduce_op->compute[0].range[0] = normalized_input_shape[1];
reduce_op->compute[0].range[1] = normalized_input_shape[3];
@@ -286,18 +282,35 @@
reduce_op->context.reduce.output_stride[i] = (reduce_op->context.reduce.output_stride[i + 1] * normalized_input_shape[(i * 2+3)]);
}
}
+ memcpy(&reduce_op->context.reduce.params, &reduce_op->params.reduce, sizeof(reduce_op->params.reduce));
+ memcpy(&reduce_op->context.reduce.cvt_params, &reduce_op->params2.unary, sizeof(reduce_op->params2.unary));
reduce_op->context.reduce.input_stride[XNN_MAX_TENSOR_DIMS - 1] = (1 << log2_data_element_size);
if (reduce_op->cvt_config) {
reduce_op->context.reduce.cvt_ukernel = reduce_op->cvt_config->ukernel;
+ // int32 is not actually a quantized type, so we need to include the input
+ // zero point (multiplied by the number of reduction elements) as part of
+ // the computation of the output zero point.
+ // The conversion normally looks like:
+ //
+ // y = (x - x_zero_point) * x_scale * inv_y_scale + y_zero_point
+ //
+ // Since this conversion ignores x_zero_point and x_scale, rewrite to:
+ //
+ // y = x * x_scale * inv_y_scale - x_zero_point * x_scale * inv_y_scale + y_zero_point
+ //
+ // Now we can say:
+ //
+ // inv_y_scale' = x_scale * inv_y_scale
+ // y_zero_point' = y_zero_point - x_zero_point * x_scale * inv_y_scale
+ reduce_op->context.reduce.cvt_params.reference.inv_y_scale =
+ reduce_op->context.reduce.params.qs8.scale;
+ reduce_op->context.reduce.cvt_params.reference.y_zero_point -=
+ ((int32_t) num_reduction_elements *
+ reduce_op->context.reduce.cvt_params.reference.x_zero_point) *
+ reduce_op->context.reduce.cvt_params.reference.inv_y_scale;
}
- if (reduce_op->s32_f32_cvt_config) {
- reduce_op->context.reduce.s32_f32_cvt_ukernel = reduce_op->s32_f32_cvt_config->ukernel;
- }
- if (reduce_op->u32_f32_cvt_config) {
- reduce_op->context.reduce.u32_f32_cvt_ukernel = reduce_op->u32_f32_cvt_config->ukernel;
- }
- for (int i = XNN_MAX_TENSOR_DIMS - 2; i >= 0; --i) {
- reduce_op->context.reduce.input_stride[i] = (reduce_op->context.reduce.input_stride[i + 1] * normalized_input_shape[i + 1]);
+ for (int i = XNN_MAX_TENSOR_DIMS - 2; i >= 0; --i) {
+ reduce_op->context.reduce.input_stride[i] = (reduce_op->context.reduce.input_stride[i + 1] * normalized_input_shape[i + 1]);
}
memcpy(reduce_op->context.reduce.input_shape, normalized_input_shape, XNN_MAX_TENSOR_DIMS * sizeof(size_t));
reduce_op->state = xnn_run_state_needs_setup;
@@ -363,8 +376,6 @@ enum xnn_status xnn_create_reduce_nd(
const struct xnn_reduce_config* rsum_config = NULL;
const struct xnn_reduce_config* rdsum_config = NULL;
const struct xnn_unary_elementwise_config* cvt_config = NULL;
- const struct xnn_unary_elementwise_config* s32_f32_cvt_config = NULL;
- const struct xnn_unary_elementwise_config* u32_f32_cvt_config = NULL;
uint32_t log2_data_element_size = xnn_datatype_log2_size_bytes(datatype);
uint32_t log2_accumulator_element_size;
switch(datatype) {
@@ -373,37 +384,29 @@
rsum_config = xnn_init_f16_f32acc_rsum_config();
rdsum_config = xnn_init_f16_f32acc_rdsum_config();
cvt_config = xnn_init_f32_to_f16_cvt_config();
- s32_f32_cvt_config = unused;
- u32_f32_cvt_config = unused;
break;
}
case xnn_datatype_fp32: {
log2_accumulator_element_size = 2;
rsum_config = xnn_init_f32_rsum_config();
rdsum_config = xnn_init_f32_rdsum_config();
cvt_config = unused;
- s32_f32_cvt_config = unused;
- u32_f32_cvt_config = unused;
break;
}
case xnn_datatype_qint8: { // qs8
log2_accumulator_element_size = 2;
rsum_config = xnn_init_qs8_rsum_config();
rdsum_config = xnn_init_qs8_rdsum_config();
- cvt_config = xnn_init_f32_to_qs8_cvt_config();
- s32_f32_cvt_config = xnn_init_s32_to_f32_cvt_config();
- u32_f32_cvt_config = unused;
+ cvt_config = xnn_init_unary_reference_config(xnn_unary_convert, xnn_datatype_int32, xnn_datatype_qint8);
break;
}
case xnn_datatype_quint8: { // qu8
log2_accumulator_element_size = 2;
rsum_config = xnn_init_qu8_rsum_config();
rdsum_config = xnn_init_qu8_rdsum_config();
- cvt_config = xnn_init_f32_to_qu8_cvt_config();
- s32_f32_cvt_config = unused;
- // We just use an int32 -> f32 conversion. This means we effectively only
+ // We just use an int32 -> qu8 conversion. This means we effectively only
// have a 31-bit accumulator instead of 32-bit, but that seems insignificant.
- u32_f32_cvt_config = xnn_init_s32_to_f32_cvt_config();
+ cvt_config = xnn_init_unary_reference_config(xnn_unary_convert, xnn_datatype_int32, xnn_datatype_quint8);
break;
}
default:
Expand All @@ -413,16 +416,13 @@ enum xnn_status xnn_create_reduce_nd(
};

// Check configs and restore unused pointers to NULL.
- if (rdsum_config == NULL || rsum_config == NULL || cvt_config == NULL ||
- s32_f32_cvt_config == NULL || u32_f32_cvt_config == NULL) {
+ if (rdsum_config == NULL || rsum_config == NULL || cvt_config == NULL) {
xnn_log_error(
"failed to create %s (%s) operator: unsupported hardware configuration",
xnn_operator_type_to_string(operator_type), xnn_datatype_to_string(datatype));
return xnn_status_unsupported_hardware;
} else {
cvt_config = cvt_config == unused ? NULL : cvt_config;
- s32_f32_cvt_config = s32_f32_cvt_config == unused ? NULL : s32_f32_cvt_config;
- u32_f32_cvt_config = u32_f32_cvt_config == unused ? NULL : u32_f32_cvt_config;
}

struct xnn_reduce_params params;
@@ -431,11 +431,15 @@
if (rsum_config->init) {
params_size = rsum_config->init(&params, input_quantization, output_quantization);
}
+ union xnn_unary_uparams cvt_params;
+ size_t cvt_params_size = 0;
+ if (cvt_config && cvt_config->init) {
+ cvt_params_size = cvt_config->init(&cvt_params, NULL, input_quantization, output_quantization);
+ }

return create_reduce_nd(
flags, log2_data_element_size, log2_accumulator_element_size, operator_type,
- rdsum_config, rsum_config, cvt_config, s32_f32_cvt_config,
- u32_f32_cvt_config, &params, params_size, reduce_op_out);
+ rdsum_config, rsum_config, cvt_config, &params, params_size, &cvt_params, cvt_params_size, reduce_op_out);
}

enum xnn_status xnn_reshape_reduce_nd(
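Note: the comment block added to reshape_reduce_nd above derives how the accumulated input zero point is folded into the conversion parameters. The following self-contained program (not XNNPACK code; quantization values are made up for illustration) checks that the folded form reproduces the old two-step int32 -> f32 -> qs8 arithmetic:

#include <stdint.h>
#include <stdio.h>

int main(void) {
  const int32_t n = 100;            /* number of reduction elements */
  const int32_t x_zero_point = 3;   /* input zero point */
  const float x_scale = 0.5f;       /* input quantization scale */
  const float inv_y_scale = 4.0f;   /* 1 / output quantization scale */
  const float y_zero_point = -5.0f; /* output zero point */
  const int32_t acc = 12345;        /* int32 accumulator: sum of n inputs */

  /* Old path: remove the accumulated zero point during int32 -> f32, then
     scale and shift during f32 -> qs8. */
  const float f = (float) (acc - n * x_zero_point);
  const float old_y = f * x_scale * inv_y_scale + y_zero_point;

  /* New path: one int32 -> qs8 conversion with folded parameters, as in the
     diff above: inv_y_scale' = x_scale * inv_y_scale and
     y_zero_point' = y_zero_point - n * x_zero_point * inv_y_scale'. */
  const float inv_y_scale2 = x_scale * inv_y_scale;
  const float y_zero_point2 =
      y_zero_point - (float) (n * x_zero_point) * inv_y_scale2;
  const float new_y = (float) acc * inv_y_scale2 + y_zero_point2;

  printf("old = %f, new = %f\n", old_y, new_y); /* should agree, up to rounding */
  return 0;
}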
3 changes: 1 addition & 2 deletions src/xnnpack/compute.h
@@ -1302,9 +1302,8 @@ struct reduce_context {
xnn_rdsum_ukernel_fn rdsum;
} ukernel;
xnn_vunary_ukernel_fn cvt_ukernel;
- xnn_vunary_ukernel_fn s32_f32_cvt_ukernel;
- xnn_vunary_ukernel_fn u32_f32_cvt_ukernel;
struct xnn_reduce_params params;
+ union xnn_unary_uparams cvt_params;
};

#ifndef __cplusplus
3 changes: 1 addition & 2 deletions src/xnnpack/microfnptr.h
@@ -2043,8 +2043,7 @@ typedef size_t (*xnn_init_reduce_params_fn)(

typedef size_t (*xnn_update_reduce_params_fn)(
struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)],
- float scale,
- int32_t num_elements);
+ float scale);

typedef size_t (*xnn_init_qs8_qc8w_conv_minmax_params_fn)(
union xnn_qs8_qc8w_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
3 changes: 1 addition & 2 deletions src/xnnpack/microparams-init.h
@@ -163,8 +163,7 @@ DECLARE_INIT_REDUCE_PARAMS_FUNCTION(xnn_init_qu8_reduce_scalar_params);
#define DECLARE_UPDATE_REDUCE_PARAMS_FUNCTION(fn_name) \
XNN_INTERNAL size_t fn_name( \
struct xnn_reduce_params params[XNN_MIN_ELEMENTS(1)], \
- float scale, \
- int32_t num_elements);
+ float scale);

DECLARE_UPDATE_REDUCE_PARAMS_FUNCTION(xnn_update_f32_reduce_scalar_params);
DECLARE_UPDATE_REDUCE_PARAMS_FUNCTION(xnn_update_qs8_reduce_scalar_params);
2 changes: 0 additions & 2 deletions src/xnnpack/microparams.h
@@ -408,15 +408,13 @@ struct xnn_f32_reduce_params {
};

struct xnn_qs8_reduce_params {
- int32_t num_elements;
float scale;
float input_output_scale;
int8_t input_zero_point;
int8_t output_zero_point;
};

struct xnn_qu8_reduce_params {
- int32_t num_elements;
float scale;
float input_output_scale;
uint8_t input_zero_point;