#9849: Move checks on batch dims for matmul to validate
- As an intended side effect, this also adds these checks to matmul_multicore and matmul_multicore_reuse, which previously lacked them
TT-BrianLiu committed Jul 8, 2024
1 parent 9506160 commit 7d03e02
Showing 5 changed files with 18 additions and 48 deletions.
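For context, the checks being centralized enforce the batch-dimension contract below. What follows is a minimal, self-contained C++ sketch of that contract; Shape, batch_size, and check_batch_dims are hypothetical stand-ins rather than tt-metal APIs, and batch_size mirrors what get_batch_size in the diff presumably computes (the product of every dimension except the last two).

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for illustration only, not tt-metal APIs.
using Shape = std::vector<uint32_t>;

// Assumed to mirror get_batch_size: product of all dims except the last two.
uint64_t batch_size(const Shape& shape) {
    uint64_t b = 1;
    for (size_t i = 0; i + 2 < shape.size(); ++i) b *= shape[i];
    return b;
}

void check_batch_dims(const Shape& a, const Shape& b, bool bcast_batch) {
    if (bcast_batch) {
        // BCMK * 11KN = BCMN: every batch dim of the second input must be 1.
        if (batch_size(b) != 1)
            throw std::invalid_argument("batch bcast matmul expects b with batch size 1");
    } else {
        // BCMK * BCKN = BCMN: ranks and every batch dim must match exactly.
        if (a.size() != b.size())
            throw std::invalid_argument("bmm expects inputs of the same rank");
        for (size_t i = 0; i + 2 < a.size(); ++i)
            if (a[i] != b[i])
                throw std::invalid_argument("bmm expects matching batch dims");
    }
}

int main() {
    check_batch_dims({2, 3, 128, 64}, {1, 1, 64, 96}, /*bcast_batch=*/true);   // ok
    check_batch_dims({2, 3, 128, 64}, {2, 3, 64, 96}, /*bcast_batch=*/false);  // ok
    // check_batch_dims({2, 3, 128, 64}, {2, 64, 96}, false);  // would throw: rank mismatch
    return 0;
}

Note that the committed code folds each explanatory string into the TT_FATAL condition with &&; a non-empty string literal is always truthy, so it leaves the check's result unchanged while surfacing the message in the stringified assertion.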
21 changes: 18 additions & 3 deletions tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp
@@ -750,8 +750,9 @@ void Matmul::validate(
     TT_FATAL(input_tensors.size() == 2);
     const auto& input_tensor_a = input_tensors.at(0);
     const auto& input_tensor_b = input_tensors.at(1);
-    auto a_shape = input_tensor_a.get_shape();
-    auto b_shape = input_tensor_b.get_shape();
+    const auto& a_shape = input_tensor_a.get_shape();
+    const auto& b_shape = input_tensor_b.get_shape();
+
     TT_FATAL(
         (input_tensor_a.get_layout() == Layout::TILE && input_tensor_b.get_layout() == Layout::TILE),
         "Inputs to matmul must be tilized");
@@ -761,6 +762,20 @@ void Matmul::validate(
         a_shape[-1],
         b_shape[-2]);
 
+    if (this->bcast_batch) {
+        TT_FATAL(
+            get_batch_size(b_shape) == 1 &&
+            "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN or equivalent");
+    } else {
+        // same condition as above, different message
+        TT_FATAL(a_shape.rank() == b_shape.rank() && "bmm (non-bcast matmul) expects input tensors of the same rank");
+        for (auto i = 0; i < a_shape.rank() - 2; i++) {
+            TT_FATAL(
+                a_shape[i] == b_shape[i] &&
+                "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
+        }
+    }
+
     TT_FATAL(is_floating_point(input_tensor_a.get_dtype()), "Unsupported data format");
     TT_FATAL(
         input_tensor_a.storage_type() == StorageType::DEVICE and input_tensor_b.storage_type() == StorageType::DEVICE,
@@ -781,7 +796,7 @@ void Matmul::validate(
         uint32_t bias_batch_size = get_batch_size(bias_shape);
         TT_FATAL(bias_batch_size == 1, "Unsupported bias shape: batch size not equal to 1.");
         TT_FATAL(bias_shape[-2] == TILE_HEIGHT, "Unsupported bias shape: second last dimension not equal to tile height");
-        TT_FATAL(bias_shape[-1] == input_tensor_b.get_legacy_shape()[-1], "Unsupported bias shape: last dimension not equal to second input's last dimension.");
+        TT_FATAL(bias_shape[-1] == b_shape[-1], "Unsupported bias shape: last dimension not equal to second input's last dimension.");
     }
 
     if (this->untilize_out) {
@@ -1509,19 +1509,6 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_(
     uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
     tt_metal::Buffer* in0_buffer = a.buffer();
     tt_metal::Buffer* in1_buffer = b.buffer();
-    if (bcast_batch)
-        TT_FATAL(
-            get_batch_size(bshape) == 1 &&
-            "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN or equivalent");
-    else {
-        // same condition as above, different message
-        TT_FATAL(ashape.rank() == bshape.rank() && "bmm (non-bcast matmul) expects input tensors of the same rank");
-        for (auto i = 0; i < ashape.rank() - 2; i++) {
-            TT_FATAL(
-                ashape[i] == bshape[i] &&
-                "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
-        }
-    }
     TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0);
     TT_FATAL(in1_buffer->size() % in1_single_tile_size == 0);
 
@@ -1232,19 +1232,6 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_2d_optimized_(
     uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
     tt_metal::Buffer* in0_buffer = a.buffer();
     tt_metal::Buffer* in1_buffer = b.buffer();
-    if (bcast_batch)
-        TT_FATAL(
-            get_batch_size(bshape) == 1 &&
-            "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN or equivalent");
-    else {
-        // same condition as above, different message
-        TT_FATAL(ashape.rank() == bshape.rank() && "bmm (non-bcast matmul) expects input tensors of the same rank");
-        for (auto i = 0; i < ashape.rank() - 2; i++) {
-            TT_FATAL(
-                ashape[i] == bshape[i] &&
-                "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
-        }
-    }
     TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0);
     TT_FATAL(in1_buffer->size() % in1_single_tile_size == 0);
 
@@ -1152,12 +1152,6 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_dram_sharded_optimized_(
     uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
     tt_metal::Buffer* in0_buffer = a.buffer();
     tt_metal::Buffer* in1_buffer = b.buffer();
-    TT_FATAL(ashape.rank() == bshape.rank() && ashape.rank() >= 2 && "bmm (non-bcast matmul) expects input tensors of the same rank and must have rank >= 2");
-    for (auto i = 0; i < ashape.rank() - 2; i++) {
-        TT_FATAL(
-            ashape[i] == bshape[i] &&
-            "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
-    }
     TT_FATAL(in0_buffer->size() % in0_single_tile_size == 0);
     TT_FATAL(in1_buffer->size() % in1_single_tile_size == 0);
 
@@ -463,19 +463,6 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_optimized_(const Tensor
     uint32_t in1_single_tile_size = tt_metal::detail::TileSize(in1_data_format);
     tt_metal::Buffer *in0_buffer = a.buffer();
     tt_metal::Buffer *in1_buffer = b.buffer();
-    if (bcast_batch)
-        TT_FATAL(
-            get_batch_size(bshape) == 1 &&
-            "matmul (batch bcast variant) expects input tensors of shapes BCMK*11KN=BCMN or equivalent");
-    else {
-        // same condition as above, different message
-        TT_FATAL(ashape.rank() == bshape.rank() && "bmm (non-bcast matmul) expects input tensors of the same rank");
-        for (auto i = 0; i < ashape.rank() - 2; i++) {
-            TT_FATAL(
-                ashape[i] == bshape[i] &&
-                "bmm (non-bcast matmul) expects input tensors of shapes BCMK*BCKN=BCMN or equivalent");
-        }
-    }
 
     MathFidelity math_fidelity;
     bool math_approx_mode;
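The deleted blocks above were per-factory copies of the same checks. The sketch below shows why moving them into Matmul::validate suffices; the dispatch ordering it assumes (validation runs once before any program factory is selected) is drawn from the commit message, not from the actual tt-metal run flow, and Op, FakeMatmul, and run_operation are hypothetical names.

#include <iostream>

// Hypothetical sketch, not tt-metal code: the host-side run flow is assumed
// to call validate() exactly once before choosing a program factory, so the
// factories (matmul_multi_core_reuse*, etc.) no longer need their own copies
// of the batch-dim checks.
struct Op {
    virtual ~Op() = default;
    virtual void validate() const = 0;
    virtual void create_program() const = 0;
};

struct FakeMatmul : Op {
    void validate() const override { std::cout << "batch-dim checks run here, once\n"; }
    void create_program() const override { std::cout << "factory builds program, no re-check\n"; }
};

void run_operation(const Op& op) {
    op.validate();        // fails fast on bad shapes
    op.create_program();  // every variant is reached only after validation
}

int main() {
    FakeMatmul op;
    run_operation(op);
    return 0;
}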
