Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update brgemm and matmul initialization for some cases #2368

Merged
merged 2 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/cpu/x64/brgemm/brgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,7 @@ status_t brgemm_desc_set_attr(
if (brg->is_dgmm)
CHECK(brdgmm_blocking(brg));
else
CHECK(brgemm_blocking(brg));
CHECK(brgemm_blocking(brg, true));
}

if (!brg->is_dgmm) {
Expand Down
21 changes: 20 additions & 1 deletion src/cpu/x64/brgemm/brgemm_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ int calculate_max_bcast_block(brgemm_desc_t *brg, const int adj_ld_block2) {
return max_bcast_block;
}

status_t brgemm_blocking(brgemm_desc_t *brg) {
status_t brgemm_blocking(brgemm_desc_t *brg, bool attr_blocking) {
const data_type_t ld_step_compute_dt
= get_mac_emu_data_type(brg->dt_b, brg->isa_impl,
brg->isa_impl != avx2_vnni_2 && !brg->is_fp8_via_convert());
Expand Down Expand Up @@ -750,6 +750,25 @@ status_t brgemm_blocking(brgemm_desc_t *brg) {
brg->rdb = brg->reduce_dim / brg->rd_block;
brg->rdb_tail = brg->reduce_dim % brg->rd_block;

// Remove these guards in the future (add tail processing by reduction
// dimension)
// TODO: these checks do not work for fp8-f16 and f16-fp8 cfgs
if (attr_blocking
&& !IMPLICATION(brg->rdb > 0 && brg->rdb_tail,
brg->is_input_convert() || brg->amx_wary_k_tail())) {
return status::unimplemented;
}

if (attr_blocking
&& !IMPLICATION(
(brg->rdb_tail
% ((brg->is_bf16_tmm || brg->is_f16_tmm) ? 2
: 4))
!= 0,
brg->is_input_convert() || brg->amx_wary_k_tail())) {
return status::unimplemented;
}

//TODO: check this condition
brg->interleave_tilestores_ = brg->beta == 0
&& (brg->brgattr.use_interleave_stores
Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/brgemm/brgemm_utils.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2022-2024 Intel Corporation
* Copyright 2022-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -34,7 +34,7 @@ bool can_dispatch_uker(const brgemm_desc_t *brg);

void maybe_try_bf32(brgemm_desc_t *brg);

status_t brgemm_blocking(brgemm_desc_t *brg);
status_t brgemm_blocking(brgemm_desc_t *brg, bool attr_blocking = false);

status_t brdgmm_blocking(brgemm_desc_t *brg);

Expand Down
4 changes: 2 additions & 2 deletions src/cpu/x64/matmul/brgemm_matmul_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2037,8 +2037,8 @@ void matmul_amx_blocking_params_t::set_blocking_parameters(

const dim_t current_k_tail = K % k_blk_;

extendable_k_
= !use_buffer_a && K % wei_k_blk && k_chunk_elems_ > wei_k_blk;
extendable_k_ = !use_buffer_a && K % wei_k_blk
&& k_chunk_elems_ > wei_k_blk && !packed_sparse_weights;

if (extendable_k_) {
if (k_chunk_elems_ >= K) {
Expand Down
Loading