diff --git a/src/microkernel-utils.c b/src/microkernel-utils.c index bf81a8db233..e40bd61b8a8 100644 --- a/src/microkernel-utils.c +++ b/src/microkernel-utils.c @@ -9,6 +9,33 @@ #include "xnnpack/math.h" #include "xnnpack/microkernel-utils.h" +size_t xnn_gemm_best_nc(size_t num_groups, size_t m, size_t n, size_t mr, + size_t nr, size_t num_threads) { + const size_t target_tiles_per_thread = 5; + size_t nc = n; + if (num_threads > 1) { + const size_t num_tile_rows = divide_round_up(m, mr) * num_groups; + const size_t max_nc = divide_round_up( + n * num_tile_rows, num_threads * target_tiles_per_thread); + if (max_nc < nc) { + nc = min(nc, divide_round_up(nc, divide_round_up(nc, max_nc) * nr) * nr); + } + } + +#ifndef NDEBUG + // Verify that we indeed have at least `target_tiles_per_thread` tiles per + // thread. + if (num_threads > 1 && nr < nc) { + const size_t num_tiles_m = divide_round_up(m, mr); + const size_t num_tiles_n = divide_round_up(n, nc); + const size_t num_tiles = num_groups * num_tiles_m * num_tiles_n; + assert(target_tiles_per_thread * num_threads <= num_tiles); + } +#endif // NDEBUG + + return nc; +} + static size_t dwconv_num_middle_pass( size_t kernel_size, size_t first_pass_tile, diff --git a/src/operators/batch-matrix-multiply-nc.c b/src/operators/batch-matrix-multiply-nc.c index c828d6a4ae5..ccf49c4f8c5 100644 --- a/src/operators/batch-matrix-multiply-nc.c +++ b/src/operators/batch-matrix-multiply-nc.c @@ -19,6 +19,7 @@ #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microkernel-type.h" +#include "xnnpack/microkernel-utils.h" #include "xnnpack/microparams-init.h" #include "xnnpack/microparams.h" #include "xnnpack/operator-type.h" @@ -627,17 +628,9 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( memcpy(&batch_matrix_multiply_op->context.gemm.gemm.gemm.params, params, params_size); batch_matrix_multiply_op->context.gemm.gemm.gemm.fused_params = &batch_matrix_multiply_op->context.gemm.gemm.gemm.params; - size_t nc = n; - if (num_threads > 1) { - const size_t num_other_tiles = divide_round_up(m, mr); - const size_t target_tiles_per_thread = 5; - const size_t max_nc = divide_round_up(n * num_other_tiles, num_threads * target_tiles_per_thread); - if (max_nc < nc) { - nc = min(nc, divide_round_up(nc, max_nc * nr) * nr); - } - } + size_t nc = xnn_gemm_best_nc(batch_size_c, m, n, mr, nr, num_threads); - #if XNN_MAX_UARCH_TYPES > 1 +#if XNN_MAX_UARCH_TYPES > 1 if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { gemm_compute->type = xnn_parallelization_type_3d_tile_2d_with_uarch; gemm_compute->task_3d_tile_2d_with_id = diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index 0692799b06b..5269796d896 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -1946,15 +1946,8 @@ static enum xnn_status reshape_gemm( memcpy(&convolution_op->context.gemm.gemm.gemm.params, &convolution_op->params, sizeof(convolution_op->context.gemm.gemm.gemm.params)); convolution_op->context.gemm.gemm.gemm.fused_params = &convolution_op->context.gemm.gemm.gemm.params; - size_t nc = group_output_channels; - if (num_threads > 1) { - const size_t num_other_tiles = groups * divide_round_up(batch_output_size, mr); - const size_t target_tiles_per_thread = 5; - const size_t max_nc = divide_round_up(group_output_channels * num_other_tiles, num_threads * target_tiles_per_thread); - if (max_nc < nc) { - nc = min(nc, divide_round_up(nc, max_nc * nr) * nr); - } - } + size_t nc = xnn_gemm_best_nc(groups, batch_output_size, group_output_channels, + mr, nr, num_threads); if (groups == 1) { #if XNN_MAX_UARCH_TYPES > 1 diff --git a/src/operators/dynamic-fully-connected-nc.c b/src/operators/dynamic-fully-connected-nc.c index c3eb10fab03..cf5be914310 100644 --- a/src/operators/dynamic-fully-connected-nc.c +++ b/src/operators/dynamic-fully-connected-nc.c @@ -17,7 +17,9 @@ #include "xnnpack/config.h" #include "xnnpack/log.h" #include "xnnpack/math.h" +#include "xnnpack/microfnptr.h" #include "xnnpack/microkernel-type.h" +#include "xnnpack/microkernel-utils.h" #include "xnnpack/microparams.h" #include "xnnpack/operator-type.h" #include "xnnpack/operator.h" @@ -389,20 +391,11 @@ static enum xnn_status reshape_dynamic_fully_connected_nc( } dynamic_fully_connected_op->context.gemm.gemm.gemm.fused_params = &dynamic_fully_connected_op->context.gemm.gemm.gemm.params; - size_t nc = output_channels; - const size_t num_threads = pthreadpool_get_threads_count(threadpool); - if (num_threads > 1) { - const size_t num_other_tiles = divide_round_up(batch_size, mr); - const size_t target_tiles_per_thread = 5; - const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread); - if (max_nc < nc) { - nc = min(nc, divide_round_up(output_channels, - divide_round_up(nc, max_nc) * nr) * - nr); - } - } + size_t nc = + xnn_gemm_best_nc(/*num_groups=*/1, batch_size, output_channels, mr, nr, + pthreadpool_get_threads_count(threadpool)); - #if XNN_MAX_UARCH_TYPES > 1 +#if XNN_MAX_UARCH_TYPES > 1 if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d_with_uarch; dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm; diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 5a81fc86e83..c74784e279f 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -24,6 +24,7 @@ #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microkernel-type.h" +#include "xnnpack/microkernel-utils.h" #include "xnnpack/microparams-init.h" #include "xnnpack/microparams.h" #include "xnnpack/operator-type.h" @@ -1981,20 +1982,11 @@ static enum xnn_status reshape_fully_connected_nc( memcpy(&fully_connected_op->context.gemm.gemm.gemm.params, params, params_size); fully_connected_op->context.gemm.gemm.gemm.fused_params = &fully_connected_op->context.gemm.gemm.gemm.params; - size_t nc = output_channels; - const size_t num_threads = pthreadpool_get_threads_count(threadpool); - if (num_threads > 1) { - const size_t num_other_tiles = divide_round_up(batch_size, mr); - const size_t target_tiles_per_thread = 5; - const size_t max_nc = divide_round_up(output_channels * num_other_tiles, num_threads * target_tiles_per_thread); - if (max_nc < nc) { - nc = min(nc, divide_round_up(output_channels, - divide_round_up(nc, max_nc) * nr) * - nr); - } - } + size_t nc = + xnn_gemm_best_nc(/*num_groups=*/1, batch_size, output_channels, mr, nr, + pthreadpool_get_threads_count(threadpool)); - #if XNN_MAX_UARCH_TYPES > 1 +#if XNN_MAX_UARCH_TYPES > 1 if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d_with_uarch; if (dynamic_quantization) { diff --git a/src/xnnpack/microkernel-utils.h b/src/xnnpack/microkernel-utils.h index dda4f47fa57..17965325b19 100644 --- a/src/xnnpack/microkernel-utils.h +++ b/src/xnnpack/microkernel-utils.h @@ -13,6 +13,11 @@ extern "C" { #endif +// Computes the largest `nc`, the largest multiple of `nr` such that there are +// at least five tiles per thread (if `num_threads > 1`). +size_t xnn_gemm_best_nc(size_t num_groups, size_t m, size_t n, size_t mr, + size_t nr, size_t num_threads); + // The total tile size needed to cover kernel_size. XNN_INTERNAL size_t xnn_dwconv_multipass_tile_size( size_t kernel_size,