diff --git a/src/operator-run.c b/src/operator-run.c
index 4b4f4028e14..6da02bba8d1 100644
--- a/src/operator-run.c
+++ b/src/operator-run.c
@@ -379,8 +379,8 @@ void xnn_compute_batched_packw_gemm_goi(
 
 void xnn_compute_hmp_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t group_index, size_t mr_block_start,
-    size_t nr_block_start, size_t mr_block_size, size_t nr_block_size) {
+    uint32_t uarch_index, size_t group_index, size_t nr_block_start,
+    size_t mr_block_start, size_t nr_block_size, size_t mr_block_size) {
   const size_t k_scaled = context->k_scaled;
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
@@ -448,20 +448,17 @@ void xnn_compute_hmp_grouped_gemm(
 
 void xnn_compute_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t group_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size) {
+    size_t group_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
   xnn_compute_hmp_grouped_gemm(context, XNN_UARCH_DEFAULT, group_index,
-                               mr_block_start, nr_block_start, mr_block_size,
-                               nr_block_size);
+                               nr_block_start, mr_block_start, nr_block_size,
+                               mr_block_size);
 }
 
 void xnn_compute_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
-    size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size)
-{
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size) {
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
@@ -480,11 +477,8 @@ void xnn_compute_gemm(
 
 void xnn_compute_dqgemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
-    size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size)
-{
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size) {
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
@@ -504,8 +498,8 @@ void xnn_compute_dqgemm(
 
 void xnn_compute_hmp_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size) {
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
   const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset(
       mr_block_start, context->k_scaled, context->mr, context->kr,
      context->sr);
   const size_t cm_stride = context->cm_stride;
@@ -523,10 +517,10 @@ void xnn_compute_hmp_qp8gemm(
 
 void xnn_compute_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
-    size_t nr_block_size) {
-  xnn_compute_hmp_qp8gemm(context, XNN_UARCH_DEFAULT, mr_block_start,
-                          nr_block_start, mr_block_size, nr_block_size);
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size) {
+  xnn_compute_hmp_qp8gemm(context, XNN_UARCH_DEFAULT, nr_block_start,
+                          mr_block_start, nr_block_size, mr_block_size);
 }
 
 void xnn_compute_spmm(
@@ -2368,8 +2362,8 @@ void xnn_compute_rope(
 #if XNN_MAX_UARCH_TYPES > 1
 void xnn_compute_hmp_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size) {
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
   const size_t a_stride = context->a_stride;
   const size_t cm_stride = context->cm_stride;
@@ -2382,32 +2376,31 @@ void xnn_compute_hmp_gemm(
       (void*)((uintptr_t)context->c + mr_block_start * cm_stride +
               (nr_block_start << context->log2_csize)),
       cm_stride, context->cn_stride, context->fused_params);
-  }
-
-  void xnn_compute_hmp_dqgemm(
-      const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-      uint32_t uarch_index,
-      size_t mr_block_start,
-      size_t nr_block_start,
-      size_t mr_block_size,
-      size_t nr_block_size)
-  {
-    const size_t a_stride = context->a_stride;
-    const size_t cm_stride = context->cm_stride;
+}
 
-    context->dq_ukernel.function[uarch_index](
-        mr_block_size,
-        nr_block_size,
-        context->k_scaled,
-        (const void*) ((uintptr_t) context->a + mr_block_start * a_stride),
-        a_stride,
-        (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride),
-        (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)),
-        cm_stride,
-        context->cn_stride,
-        context->fused_params,
-        (const void*) ((uintptr_t) &context->quantization_params[mr_block_start]));
-  }
+void xnn_compute_hmp_dqgemm(
+    const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size) {
+  const size_t a_stride = context->a_stride;
+  const size_t cm_stride = context->cm_stride;
+  const size_t k_scaled = context->k_scaled;
+  const size_t cn_stride = context->cn_stride;
+  const uintptr_t a = (uintptr_t)context->a;
+  const void* packed_w = (const void*)((uintptr_t)context->packed_w +
+                                       nr_block_start * context->w_stride);
+  const uintptr_t c =
+      (uintptr_t)context->c + (nr_block_start << context->log2_csize);
+  const void* fused_params = context->fused_params;
+  const void* quantization_params =
+      (const void*)((uintptr_t)&context->quantization_params[mr_block_start]);
+
+  context->dq_ukernel.function[uarch_index](
+      mr_block_size, nr_block_size, k_scaled,
+      (const void*)(a + mr_block_start * a_stride), a_stride,
+      (const void*)packed_w, (void*)(c + mr_block_start * cm_stride), cm_stride,
+      cn_stride, fused_params, quantization_params);
+}
 
 void xnn_compute_hmp_grouped_batch_igemm(
     const struct igemm_context context[restrict XNN_MIN_ELEMENTS(1)],
diff --git a/src/operators/batch-matrix-multiply-nc.c b/src/operators/batch-matrix-multiply-nc.c
index ccf49c4f8c5..2f408bcf446 100644
--- a/src/operators/batch-matrix-multiply-nc.c
+++ b/src/operators/batch-matrix-multiply-nc.c
@@ -646,10 +646,10 @@ static enum xnn_status reshape_batch_matrix_multiply_nc(
         (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_gemm;
 #endif
   gemm_compute->range[0] = batch_size_c;
-  gemm_compute->range[1] = m;
-  gemm_compute->range[2] = n;
-  gemm_compute->tile[0] = mr;
-  gemm_compute->tile[1] = nc;
+  gemm_compute->range[2] = m;
+  gemm_compute->range[1] = n;
+  gemm_compute->tile[1] = mr;
+  gemm_compute->tile[0] = nc;
 
   batch_matrix_multiply_op->state = xnn_run_state_needs_setup;
   return xnn_status_success;
diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c
index 5269796d896..8d13befc74b 100644
--- a/src/operators/convolution-nhwc.c
+++ b/src/operators/convolution-nhwc.c
@@ -1962,10 +1962,10 @@ static enum xnn_status reshape_gemm(
       convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d;
       convolution_op->compute[0].task_2d_tile_2d =
          (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
 #endif
-    convolution_op->compute[0].range[0] = batch_output_size;
-    convolution_op->compute[0].range[1] = group_output_channels;
-    convolution_op->compute[0].tile[0] = mr;
-    convolution_op->compute[0].tile[1] = nc;
+    convolution_op->compute[0].range[1] = batch_output_size;
+    convolution_op->compute[0].range[0] = group_output_channels;
+    convolution_op->compute[0].tile[1] = mr;
+    convolution_op->compute[0].tile[0] = nc;
   } else {
 #if XNN_MAX_UARCH_TYPES > 1
     if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
@@ -1980,10 +1980,10 @@ static enum xnn_status reshape_gemm(
       convolution_op->compute[0].task_3d_tile_2d =
          (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm;
 #endif
     convolution_op->compute[0].range[0] = groups;
-    convolution_op->compute[0].range[1] = batch_output_size;
-    convolution_op->compute[0].range[2] = group_output_channels;
-    convolution_op->compute[0].tile[0] = mr;
-    convolution_op->compute[0].tile[1] = nc;
+    convolution_op->compute[0].range[2] = batch_output_size;
+    convolution_op->compute[0].range[1] = group_output_channels;
+    convolution_op->compute[0].tile[1] = mr;
+    convolution_op->compute[0].tile[0] = nc;
   }
 
   convolution_op->state = xnn_run_state_needs_setup;
diff --git a/src/operators/dynamic-fully-connected-nc.c b/src/operators/dynamic-fully-connected-nc.c
index cf5be914310..c11351647ae 100644
--- a/src/operators/dynamic-fully-connected-nc.c
+++ b/src/operators/dynamic-fully-connected-nc.c
@@ -396,21 +396,27 @@ static enum xnn_status reshape_dynamic_fully_connected_nc(
       pthreadpool_get_threads_count(threadpool));
 
 #if XNN_MAX_UARCH_TYPES > 1
-  if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
-    dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d_with_uarch;
-    dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm;
-  } else {
-    dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d;
-    dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
-  }
-  #else
-  dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d;
-  dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
-  #endif
-  dynamic_fully_connected_op->compute[1].range[0] = batch_size;
-  dynamic_fully_connected_op->compute[1].range[1] = output_channels;
-  dynamic_fully_connected_op->compute[1].tile[0] = mr;
-  dynamic_fully_connected_op->compute[1].tile[1] = nc;
+  if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) {
+    dynamic_fully_connected_op->compute[1].type =
+        xnn_parallelization_type_2d_tile_2d_with_uarch;
+    dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id =
+        (pthreadpool_task_2d_tile_2d_with_id_t)xnn_compute_hmp_gemm;
+  } else {
+    dynamic_fully_connected_op->compute[1].type =
+        xnn_parallelization_type_2d_tile_2d;
+    dynamic_fully_connected_op->compute[1].task_2d_tile_2d =
+        (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm;
+  }
+#else
+  dynamic_fully_connected_op->compute[1].type =
+      xnn_parallelization_type_2d_tile_2d;
+  dynamic_fully_connected_op->compute[1].task_2d_tile_2d =
+      (pthreadpool_task_2d_tile_2d_t)xnn_compute_gemm;
+#endif
+  dynamic_fully_connected_op->compute[1].range[1] = batch_size;
+  dynamic_fully_connected_op->compute[1].range[0] = output_channels;
+  dynamic_fully_connected_op->compute[1].tile[1] = mr;
+  dynamic_fully_connected_op->compute[1].tile[0] = nc;
 
   dynamic_fully_connected_op->state = xnn_run_state_needs_setup;
   return xnn_status_success;
diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c
index c74784e279f..2f61f53d636 100644
--- a/src/operators/fully-connected-nc.c
+++ b/src/operators/fully-connected-nc.c
@@ -2019,10 +2019,10 @@ static enum xnn_status reshape_fully_connected_nc(
     fully_connected_op->compute[0].task_2d_tile_2d =
         (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm;
   }
 #endif
-  fully_connected_op->compute[0].range[0] = batch_size;
-  fully_connected_op->compute[0].range[1] = output_channels;
-  fully_connected_op->compute[0].tile[0] = mr;
-  fully_connected_op->compute[0].tile[1] = nc;
+  fully_connected_op->compute[0].range[1] = batch_size;
+  fully_connected_op->compute[0].range[0] = output_channels;
+  fully_connected_op->compute[0].tile[1] = mr;
+  fully_connected_op->compute[0].tile[0] = nc;
 
   fully_connected_op->state = xnn_run_state_needs_setup;
   return xnn_status_success;
diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h
index 835e7d13e2e..fddfe229f49 100644
--- a/src/xnnpack/compute.h
+++ b/src/xnnpack/compute.h
@@ -345,59 +345,59 @@ struct gemm_context {
 XNN_PRIVATE void xnn_compute_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     size_t group_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_dqgemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    size_t mr_block_start, size_t nr_block_start, size_t mr_block_size,
-    size_t nr_block_size);
+    size_t nr_block_start, size_t mr_block_start, size_t nr_block_size,
+    size_t mr_block_size);
 
 #if XNN_MAX_UARCH_TYPES > 1
 XNN_PRIVATE void xnn_compute_hmp_grouped_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     uint32_t uarch_index,
     size_t group_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_hmp_gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     uint32_t uarch_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_hmp_dqgemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
     uint32_t uarch_index,
-    size_t mr_block_start,
     size_t nr_block_start,
-    size_t mr_block_size,
-    size_t nr_block_size);
+    size_t mr_block_start,
+    size_t nr_block_size,
+    size_t mr_block_size);
 
 XNN_PRIVATE void xnn_compute_hmp_qp8gemm(
     const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)],
-    uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start,
-    size_t mr_block_size, size_t nr_block_size);
+    uint32_t uarch_index, size_t nr_block_start, size_t mr_block_start,
+    size_t nr_block_size, size_t mr_block_size);
 #endif // XNN_MAX_UARCH_TYPES > 1
 #endif
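
Note on the argument reordering, for review: pthreadpool hands a 2d_tile_2d task its indices in dimension order, i.e. (index0, index1, tile0_size, tile1_size), with dimension 0 forming the outermost split. Once the N dimension moves into range[0]/tile[0], the callbacks receive nr_block_start/nr_block_size ahead of mr_block_start/mr_block_size, which is exactly the signature change above. The snippet below is a minimal serial sketch of that decomposition to make the mapping concrete; parallelize_2d_tile_2d and print_tile are illustrative stand-ins written for this note, not pthreadpool or XNNPACK code.

#include <stddef.h>
#include <stdio.h>

/* Mirrors the reordered compute-function signatures: the index for dimension
 * 0 (N, tiled by nc) arrives before the index for dimension 1 (M, tiled by
 * mr). */
typedef void (*gemm_tile_fn)(void* context, size_t nr_block_start,
                             size_t mr_block_start, size_t nr_block_size,
                             size_t mr_block_size);

static void print_tile(void* context, size_t nr_block_start,
                       size_t mr_block_start, size_t nr_block_size,
                       size_t mr_block_size) {
  (void)context;
  printf("tile: n=[%zu,%zu) m=[%zu,%zu)\n", nr_block_start,
         nr_block_start + nr_block_size, mr_block_start,
         mr_block_start + mr_block_size);
}

/* Serial stand-in for pthreadpool_parallelize_2d_tile_2d(): dimension 0 is
 * walked outermost, so with range[0] = N the N dimension is now the primary
 * unit of work distribution. */
static void parallelize_2d_tile_2d(gemm_tile_fn fn, void* context,
                                   size_t range_0, size_t range_1,
                                   size_t tile_0, size_t tile_1) {
  for (size_t i = 0; i < range_0; i += tile_0) {
    const size_t size_0 = (range_0 - i < tile_0) ? range_0 - i : tile_0;
    for (size_t j = 0; j < range_1; j += tile_1) {
      const size_t size_1 = (range_1 - j < tile_1) ? range_1 - j : tile_1;
      fn(context, i, j, size_0, size_1);
    }
  }
}

int main(void) {
  const size_t n = 64, m = 10;  /* output_channels (N), batch_size (M) */
  const size_t nc = 32, mr = 4; /* tile[0], tile[1] after this change  */
  parallelize_2d_tile_2d(print_tile, NULL, n, m, nc, mr);
  return 0;
}

The 3d_tile_2d variants used by the grouped and batched paths behave the same way, with the group or batch index delivered ahead of the (nr, mr) pair, matching the updated xnn_compute_grouped_gemm and xnn_compute_hmp_grouped_gemm signatures.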