Fix optimal nc computation in batch-matrix-multiply, `convolution-nhwc`, and `dynamic-fully-connected` ops.

This is the same fix that had already been applied to the `fully-connected` op.

This speeds up the GEMMs for two reasons:
 * Larger tile sizes lead to fewer tiles and thus less overhead, e.g. in `pthreadpool`.
 * Getting the minimum number of tiles right (e.g. when running with 8 threads) provides better load balancing (see the worked example below).

PiperOrigin-RevId: 700327559
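To make the two bullet points concrete, here is a small, self-contained sketch (not part of the commit) that reproduces the arithmetic of the new `xnn_gemm_best_nc` helper shown in the diff below. The sizes m = 256, n = 4096, mr = 8, nr = 16 and the 8-thread count are assumed purely for illustration.

// Illustration only: the nc computation from xnn_gemm_best_nc, worked through
// on made-up sizes to show that the result keeps at least 5 tiles per thread.
#include <assert.h>
#include <stdio.h>

static size_t divide_round_up(size_t num, size_t den) { return (num + den - 1) / den; }
static size_t min_size(size_t a, size_t b) { return a < b ? a : b; }

int main(void) {
  // Assumed example sizes (not taken from the commit).
  const size_t num_groups = 1, m = 256, n = 4096, mr = 8, nr = 16;
  const size_t num_threads = 8, target_tiles_per_thread = 5;

  size_t nc = n;
  const size_t num_tile_rows = divide_round_up(m, mr) * num_groups;  // 32 tile rows
  const size_t max_nc = divide_round_up(
      n * num_tile_rows, num_threads * target_tiles_per_thread);     // 3277
  if (max_nc < nc) {
    // Split n into ceil(n / max_nc) = 2 column chunks and round up to a multiple of nr.
    nc = min_size(nc, divide_round_up(nc, divide_round_up(nc, max_nc) * nr) * nr);  // 2048
  }

  // 32 tile rows * 2 column tiles = 64 tiles >= 5 tiles/thread * 8 threads.
  assert(num_groups * divide_round_up(m, mr) * divide_round_up(n, nc) >=
         target_tiles_per_thread * num_threads);
  printf("nc = %zu\n", nc);  // prints: nc = 2048
  return 0;
}

With these sizes, nc = 2048 splits the 4096 output columns into just two column tiles per tile row: the tiles are as large as possible while still leaving 64 >= 5 * 8 tiles for the threads to balance.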
gonnet authored and xnnpack-bot committed Nov 26, 2024
1 parent a4e966a commit 24eabea
Showing 6 changed files with 48 additions and 45 deletions.
27 changes: 27 additions & 0 deletions src/microkernel-utils.c
@@ -9,6 +9,33 @@
#include "xnnpack/math.h"
#include "xnnpack/microkernel-utils.h"

size_t xnn_gemm_best_nc(size_t num_groups, size_t m, size_t n, size_t mr,
                        size_t nr, size_t num_threads) {
  const size_t target_tiles_per_thread = 5;
  size_t nc = n;
  if (num_threads > 1) {
    const size_t num_tile_rows = divide_round_up(m, mr) * num_groups;
    const size_t max_nc = divide_round_up(
        n * num_tile_rows, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, divide_round_up(nc, max_nc) * nr) * nr);
    }
  }

#ifndef NDEBUG
  // Verify that we indeed have at least `target_tiles_per_thread` tiles per
  // thread.
  if (num_threads > 1 && nr < nc) {
    const size_t num_tiles_m = divide_round_up(m, mr);
    const size_t num_tiles_n = divide_round_up(n, nc);
    const size_t num_tiles = num_groups * num_tiles_m * num_tiles_n;
    assert(target_tiles_per_thread * num_threads <= num_tiles);
  }
#endif  // NDEBUG

  return nc;
}

static size_t dwconv_num_middle_pass(
size_t kernel_size,
size_t first_pass_tile,
13 changes: 3 additions & 10 deletions src/operators/batch-matrix-multiply-nc.c
@@ -19,6 +19,7 @@
#include "xnnpack/math.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microkernel-type.h"
#include "xnnpack/microkernel-utils.h"
#include "xnnpack/microparams-init.h"
#include "xnnpack/microparams.h"
#include "xnnpack/operator-type.h"
@@ -627,17 +628,9 @@ static enum xnn_status reshape_batch_matrix_multiply_nc(
memcpy(&batch_matrix_multiply_op->context.gemm.gemm.gemm.params, params, params_size);
batch_matrix_multiply_op->context.gemm.gemm.gemm.fused_params = &batch_matrix_multiply_op->context.gemm.gemm.gemm.params;

  size_t nc = n;
  if (num_threads > 1) {
    const size_t num_other_tiles = divide_round_up(m, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(n * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
    }
  }
  size_t nc = xnn_gemm_best_nc(batch_size_c, m, n, mr, nr, num_threads);

#if XNN_MAX_UARCH_TYPES > 1
#if XNN_MAX_UARCH_TYPES > 1
if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
gemm_compute->type = xnn_parallelization_type_3d_tile_2d_with_uarch;
gemm_compute->task_3d_tile_2d_with_id =
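The removed block above contains the actual bug: the old rounding, `divide_round_up(nc, max_nc * nr) * nr`, divides by `max_nc * nr` instead of first computing the number of column chunks, so whenever the cap applies it collapses `nc` to roughly `nr` and produces far more (and far smaller) tiles than intended. A sketch comparing the two roundings on the same assumed sizes as in the example above:

// Illustration only: old vs. fixed nc rounding, using the assumed values
// n = 4096, nr = 16, max_nc = 3277 from the earlier example.
#include <stdio.h>

static size_t divide_round_up(size_t num, size_t den) { return (num + den - 1) / den; }

int main(void) {
  const size_t n = 4096, nr = 16, max_nc = 3277;

  // Old (buggy) rounding: max_nc * nr = 52432 >= n, so the result collapses to nr.
  const size_t old_nc = divide_round_up(n, max_nc * nr) * nr;                       // 16
  // Fixed rounding: 2 column chunks of 2048 columns, each a multiple of nr.
  const size_t new_nc = divide_round_up(n, divide_round_up(n, max_nc) * nr) * nr;   // 2048

  printf("old nc = %zu (256 column tiles), fixed nc = %zu (2 column tiles)\n",
         old_nc, new_nc);
  return 0;
}

Both values satisfy the at-least-5-tiles-per-thread target, but the old one does so with 256 column tiles instead of 2, which is exactly the per-tile overhead the commit message is talking about.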
11 changes: 2 additions & 9 deletions src/operators/convolution-nhwc.c
@@ -1946,15 +1946,8 @@ static enum xnn_status reshape_gemm(
memcpy(&convolution_op->context.gemm.gemm.gemm.params, &convolution_op->params, sizeof(convolution_op->context.gemm.gemm.gemm.params));
convolution_op->context.gemm.gemm.gemm.fused_params = &convolution_op->context.gemm.gemm.gemm.params;

  size_t nc = group_output_channels;
  if (num_threads > 1) {
    const size_t num_other_tiles = groups * divide_round_up(batch_output_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(nc, max_nc * nr) * nr);
    }
  }
  size_t nc = xnn_gemm_best_nc(groups, batch_output_size, group_output_channels,
                               mr, nr, num_threads);

if (groups == 1) {
#if XNN_MAX_UARCH_TYPES > 1
19 changes: 6 additions & 13 deletions src/operators/dynamic-fully-connected-nc.c
@@ -17,7 +17,9 @@
#include "xnnpack/config.h"
#include "xnnpack/log.h"
#include "xnnpack/math.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microkernel-type.h"
#include "xnnpack/microkernel-utils.h"
#include "xnnpack/microparams.h"
#include "xnnpack/operator-type.h"
#include "xnnpack/operator.h"
@@ -389,20 +391,11 @@ static enum xnn_status reshape_dynamic_fully_connected_nc(
}
dynamic_fully_connected_op->context.gemm.gemm.gemm.fused_params = &dynamic_fully_connected_op->context.gemm.gemm.gemm.params;

  size_t nc = output_channels;
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
  if (num_threads > 1) {
    const size_t num_other_tiles = divide_round_up(batch_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(output_channels,
                                   divide_round_up(nc, max_nc) * nr) *
                       nr);
    }
  }
  size_t nc =
      xnn_gemm_best_nc(/*num_groups=*/1, batch_size, output_channels, mr, nr,
                       pthreadpool_get_threads_count(threadpool));

#if XNN_MAX_UARCH_TYPES > 1
#if XNN_MAX_UARCH_TYPES > 1
if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d_with_uarch;
dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm;
18 changes: 5 additions & 13 deletions src/operators/fully-connected-nc.c
@@ -24,6 +24,7 @@
#include "xnnpack/math.h"
#include "xnnpack/microfnptr.h"
#include "xnnpack/microkernel-type.h"
#include "xnnpack/microkernel-utils.h"
#include "xnnpack/microparams-init.h"
#include "xnnpack/microparams.h"
#include "xnnpack/operator-type.h"
@@ -1981,20 +1982,11 @@ static enum xnn_status reshape_fully_connected_nc(
memcpy(&fully_connected_op->context.gemm.gemm.gemm.params, params, params_size);
fully_connected_op->context.gemm.gemm.gemm.fused_params = &fully_connected_op->context.gemm.gemm.gemm.params;

  size_t nc = output_channels;
  const size_t num_threads = pthreadpool_get_threads_count(threadpool);
  if (num_threads > 1) {
    const size_t num_other_tiles = divide_round_up(batch_size, mr);
    const size_t target_tiles_per_thread = 5;
    const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread);
    if (max_nc < nc) {
      nc = min(nc, divide_round_up(output_channels,
                                   divide_round_up(nc, max_nc) * nr) *
                       nr);
    }
  }
  size_t nc =
      xnn_gemm_best_nc(/*num_groups=*/1, batch_size, output_channels, mr, nr,
                       pthreadpool_get_threads_count(threadpool));

#if XNN_MAX_UARCH_TYPES > 1
#if XNN_MAX_UARCH_TYPES > 1
if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d_with_uarch;
if (dynamic_quantization) {
5 changes: 5 additions & 0 deletions src/xnnpack/microkernel-utils.h
@@ -13,6 +13,11 @@
extern "C" {
#endif

// Computes the largest `nc`, the largest multiple of `nr` such that there are
// at least five tiles per thread (if `num_threads > 1`).
size_t xnn_gemm_best_nc(size_t num_groups, size_t m, size_t n, size_t mr,
                        size_t nr, size_t num_threads);

// The total tile size needed to cover kernel_size.
XNN_INTERNAL size_t xnn_dwconv_multipass_tile_size(
size_t kernel_size,
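For context on how the returned value is consumed (a sketch with assumed names, not the operators' actual reshape code): the operators pass `nc` as the column tile size of the parallelized GEMM, so the work decomposes into `num_groups * ceil(m/mr) * ceil(n/nc)` independent tiles.

// Sketch only: counting the tiles a given nc produces. process_tile is a
// hypothetical stand-in for dispatching one GEMM tile to the thread pool.
#include <assert.h>
#include <stdio.h>

static size_t divide_round_up(size_t num, size_t den) { return (num + den - 1) / den; }

static size_t count_tiles(size_t num_groups, size_t m, size_t n, size_t mr, size_t nc) {
  size_t tiles = 0;
  for (size_t g = 0; g < num_groups; g++) {
    for (size_t i = 0; i < m; i += mr) {      // one tile row per mr output rows
      for (size_t j = 0; j < n; j += nc) {    // one tile per nc output columns
        // process_tile(g, i, j);             // hypothetical per-tile work
        tiles++;
      }
    }
  }
  return tiles;
}

int main(void) {
  // With the assumed sizes from the earlier example, nc = 2048 yields 64 tiles.
  const size_t num_groups = 1, m = 256, n = 4096, mr = 8, nc = 2048;
  const size_t tiles = count_tiles(num_groups, m, n, mr, nc);
  assert(tiles == num_groups * divide_round_up(m, mr) * divide_round_up(n, nc));
  printf("tiles = %zu\n", tiles);  // prints: tiles = 64
  return 0;
}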
