Skip to content

Commit

Permalink
Exclude Padding from Shape Validation in Concat Operation #15308
Browse files Browse the repository at this point in the history
  • Loading branch information
shwetankTT committed Nov 21, 2024
1 parent 0ab59b7 commit 87ed00b
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ConcatOpParallelizationStrategy ConcatDeviceOperation::get_parallelization_strat

void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) const {
const auto &first_input = input_tensors[0];
tt::tt_metal::LegacyShape shape_first = first_input.get_legacy_shape();
ttnn::SimpleShape shape_first = first_input.get_logical_shape();
TT_FATAL(this->dim < shape_first.rank(), "ConcatDeviceOperation dim specified is larger than input tensor rank.");
shape_first[this->dim] = 0;
bool shard_first = input_tensors[0].is_sharded();
Expand All @@ -38,7 +38,7 @@ void ConcatDeviceOperation::validate(const std::vector<Tensor> &input_tensors) c
TT_FATAL(in_ref.device() == first_input.device(), "Operands to concat need to be on the same device.");
TT_FATAL(in_ref.get_layout() == first_input.get_layout(), "All Tensors should have same layouts.");
TT_FATAL(in_ref.get_dtype() == first_input.get_dtype(), "All Tensors should have same dtypes.");
tt::tt_metal::LegacyShape curr_shape = in_ref.get_legacy_shape();
ttnn::SimpleShape curr_shape = in_ref.get_logical_shape();
TT_FATAL(curr_shape.rank() == shape_first.rank(), "Input tensor ranks must be equal");
curr_shape[this->dim] = 0;
// last tensor can support without any kernel changes
Expand Down

0 comments on commit 87ed00b

Please sign in to comment.