From 87ed00b996c3f83bbb81772ec8c4c2812b869c5e Mon Sep 17 00:00:00 2001 From: Shwetank Singh Date: Thu, 21 Nov 2024 12:34:27 +0000 Subject: [PATCH] Exclude Padding from Shape Validation in Concat Operation #15308 --- .../data_movement/concat/device/concat_device_operation.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp index dd7054b7b43..377ecb7f83e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp @@ -25,7 +25,7 @@ ConcatOpParallelizationStrategy ConcatDeviceOperation::get_parallelization_strat void ConcatDeviceOperation::validate(const std::vector &input_tensors) const { const auto &first_input = input_tensors[0]; - tt::tt_metal::LegacyShape shape_first = first_input.get_legacy_shape(); + ttnn::SimpleShape shape_first = first_input.get_logical_shape(); TT_FATAL(this->dim < shape_first.rank(), "ConcatDeviceOperation dim specified is larger than input tensor rank."); shape_first[this->dim] = 0; bool shard_first = input_tensors[0].is_sharded(); @@ -38,7 +38,7 @@ void ConcatDeviceOperation::validate(const std::vector &input_tensors) c TT_FATAL(in_ref.device() == first_input.device(), "Operands to concat need to be on the same device."); TT_FATAL(in_ref.get_layout() == first_input.get_layout(), "All Tensors should have same layouts."); TT_FATAL(in_ref.get_dtype() == first_input.get_dtype(), "All Tensors should have same dtypes."); - tt::tt_metal::LegacyShape curr_shape = in_ref.get_legacy_shape(); + ttnn::SimpleShape curr_shape = in_ref.get_logical_shape(); TT_FATAL(curr_shape.rank() == shape_first.rank(), "Input tensor ranks must be equal"); curr_shape[this->dim] = 0; // last tensor can support without any kernel changes