From a813ef3867ba429daf01fe78cce37b35a124f0af Mon Sep 17 00:00:00 2001
From: mouliraj-mcw
Date: Tue, 3 Dec 2024 08:54:58 +0000
Subject: [PATCH] #15642: Replace shapes in eltwise

---
 .../eltwise/binary/device/binary_composite_op.cpp         | 2 +-
 .../eltwise/binary/device/binary_device_operation.cpp     | 8 ++++----
 ...adcast_height_and_width_multi_core_program_factory.cpp | 8 ++++----
 .../operations/eltwise/unary_backward/unary_backward.cpp  | 6 +++---
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
index f09d2b08e8ac..e064634da176 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp
@@ -339,7 +339,7 @@ Tensor ExecutePrelu::invoke(
     if (s_a.rank() > 2) {
         SmallVector<uint32_t> reshape(s_a.rank(), 1);
         reshape[1] = s_a[1];
-        b = ttnn::reshape(input_b, ttnn::Shape(reshape));
+        b = ttnn::reshape(input_b, ttnn::SimpleShape(reshape));
     }
 
     Tensor result = ttnn::where(ttnn::ltz(input_a, output_mem_config), ttnn::multiply(input_a, b), input_a);
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp
index ce524ac4ae6c..2026ab875514 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp
@@ -81,10 +81,10 @@ BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_f
     }
     if (height_b == 1) {
         if (tensor_args.input_tensor_a.is_sharded()) {
-            if (tensor_args.input_tensor_a.get_padded_shape()[0] ==
-                    tensor_args.input_tensor_b->get_padded_shape()[0] ||
-                tensor_args.input_tensor_a.get_padded_shape()[0] > 1 and
-                    tensor_args.input_tensor_b->get_padded_shape()[0] == 1) {
+            if (tensor_args.input_tensor_a.get_logical_shape()[0] ==
+                    tensor_args.input_tensor_b->get_logical_shape()[0] ||
+                tensor_args.input_tensor_a.get_logical_shape()[0] > 1 and
+                    tensor_args.input_tensor_b->get_logical_shape()[0] == 1) {
                 return BroadcastHeightMultiCoreShardedOptimized{};
             }
             return BroadcastHeightMultiCoreSharded{};
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp
index d54c9bdef6fe..37552a8bb443 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp
@@ -44,8 +44,8 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
     const auto& b = tensor_args.input_tensor_b;
     auto& output = tensor_return_value;
     auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type);
-    const auto ashape = a.get_padded_shape();
-    const auto bshape = b.has_value() ? b->get_padded_shape() : Shape{1, 1};
+    const auto ashape = a.get_logical_shape();
+    const auto bshape = b.has_value() ? b->get_logical_shape() : Shape{1, 1};
     uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
     uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
     uint32_t H = ashape[-2];
@@ -298,8 +298,8 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a
 
     auto dst_buffer = output_tensor.buffer();
 
-    const auto ashape = input_tensor_a.get_padded_shape();
-    const auto bshape = input_tensor_b.has_value() ? input_tensor_b->get_padded_shape() : Shape{1, 1};
+    const auto ashape = input_tensor_a.get_logical_shape();
+    const auto bshape = input_tensor_b.has_value() ? input_tensor_b->get_logical_shape() : Shape{1, 1};
     uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
     uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
     uint32_t H = ashape[-2];
diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp
index 423a0a6775ca..f012ecb14a18 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp
@@ -1794,7 +1794,7 @@ std::vector<Tensor> ExecuteUnaryBackwardRepeat::invoke(
         ttnn::SmallVector<int64_t> dim = {0};
         TT_FATAL(shape[1] == 1 && shape[2] == 1 && shape[3] == 1, "repeat[1], [2], [3] should be 1");
         std::array<uint32_t, 4> intended_shape_array = {1, shape_wh[1], shape_wh[2], shape_wh[3]};
-        const ttnn::Shape required = ttnn::Shape(intended_shape_array);
+        const auto required = ttnn::SimpleShape(intended_shape_array);
         Tensor result = ttnn::moreh_sum(
             grad,
             dim,
@@ -1813,7 +1813,7 @@ std::vector<Tensor> ExecuteUnaryBackwardRepeat::invoke(
         ttnn::SmallVector<int64_t> dim = {1};
         TT_FATAL(shape[0] == 1 && shape[2] == 1 && shape[3] == 1, "repeat[0], [2], [3] should be 1");
         std::array<uint32_t, 4> intended_shape_array = {shape_wh[0], 1, shape_wh[2], shape_wh[3]};
-        const ttnn::Shape required = ttnn::Shape(intended_shape_array);
+        const auto required = ttnn::SimpleShape(intended_shape_array);
         Tensor result = ttnn::moreh_sum(
             grad,
             dim,
@@ -1874,7 +1874,7 @@ std::vector<Tensor> ExecuteUnaryBackwardProd::invoke(
     // all_dimensions = False
     Tensor updated_grad = prod_result;
     auto step = ttnn::SmallVector<uint32_t>({1, 1, 1, 1});
-    if (prod_result.get_logical_shape() != grad.get_padded_shape()) {
+    if (prod_result.get_logical_shape() != grad.get_logical_shape()) {
         if (dim == 3 || dim == -1) {
             ttnn::SmallVector<int64_t> after_permute_dims = {0, 3, 1, 2};
             Tensor required = ttnn::permute(grad, after_permute_dims, output_memory_config);
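
Reviewer note, not part of the patch: a minimal sketch of the distinction this change standardizes on. get_logical_shape() returns the user-visible dimensions, while get_padded_shape() returns dimensions rounded up to tile boundaries under TILE layout, so a comparison that mixes the two domains misfires whenever a dimension is not a tile multiple. The helper name, example shapes, and include path below are assumptions for illustration; the accessors and the shape comparison are the ones used in the diff.

#include "ttnn/tensor/tensor.hpp"  // include path assumed from the tt-metal tree layout

// Hypothetical helper, illustration only. Under TILE layout a tensor with
// logical shape {1, 1, 30, 30} reports a padded shape of {1, 1, 32, 32}
// (H and W rounded up to 32x32 tiles). The pre-patch unary_backward code
// compared prod_result's logical shape against grad's padded shape, which
// flags a mismatch for such tensors even when they agree logically;
// comparing logical against logical, as below, does not.
bool logical_shapes_match(const ttnn::Tensor& a, const ttnn::Tensor& b) {
    return a.get_logical_shape() == b.get_logical_shape();
}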