diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
index 8cbf4e001b64..b0020048eb11 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
@@ -108,28 +108,54 @@ auto preprocess_inputs(const Tensor &input_tensor_a_arg, const Tensor &input_ten
     Tensor input_tensor_a = input_tensor_a_arg;
     Tensor input_tensor_b = input_tensor_b_arg;
 
+    auto shape_a = input_tensor_a.get_shape();
+    auto shape_b = input_tensor_b.get_shape();
+    auto rank_a = shape_a.rank();
+    auto rank_b = shape_b.rank();
+
+    if(rank_a != rank_b){
+
+        auto max_rank = std::max(rank_a, rank_b);
+        auto min_rank = std::min(rank_a, rank_b);
+
+        // if(optional_output_tensor.has_value()) {
+        //     auto opt_rank = optional_output_tensor.value().get_shape().rank();
+        //     TT_FATAL( max_rank == opt_rank,
+        //         "Output Tensor rank {} doesn't match input tensor rank {}.", opt_rank, max_rank );
+        // }
+
+        std::vector shape_vector(max_rank, 1);
+        auto& reshaped_tensor = (rank_a > rank_b) ? input_tensor_b : input_tensor_a;
+        auto s_b = reshaped_tensor.get_shape();
+        for(int i=0; i < min_rank; ++i){
+            shape_vector[(max_rank - min_rank) + i] = s_b[i];
+        }
+        reshaped_tensor = ttnn::reshape(reshaped_tensor, shape_vector);
+
+    }
+
     // TODO: #7731 (Remove calls to repeat )
-    auto repeat_smaller = [](const auto &first, auto &second) {
+    auto repeat_smaller = [](const auto &first, auto &second, auto first_rank, auto second_rank) {
         const auto first_shape = first.get_shape();
         const auto second_shape = second.get_shape();
         // repeats second if it is smaller
-        if (first_shape.rank() == 4 and second_shape.rank() == 4 and first_shape[0] > second_shape[0]) {
+        if (first_rank == 4 and second_rank == 4 and first_shape[0] > second_shape[0]) {
             TT_FATAL(second_shape[0] == 1, "Dimension trying to broadcast is not equal to 1");
             Shape repeats(std::array{first_shape[0], 1, 1, 1});
             second = ttnn::repeat(second, repeats);
         }
         // repeats second if it is smaller
-        if (first_shape.rank() >= 3 and second_shape.rank() >= 3 and first_shape[-3] > second_shape[-3]) {
+        if (first_rank >= 3 and second_rank >= 3 and first_shape[-3] > second_shape[-3]) {
             TT_FATAL(second_shape[-3] == 1, "Dimension trying to broadcast is not equal to 1");
-            int rank_a = first_shape.rank();
+            int rank_a = first_rank;
             std::vector repeat_dim(rank_a, 1);
             repeat_dim[rank_a - 3] = first_shape[rank_a - 3];
             Shape repeats(repeat_dim);
             second = ttnn::repeat(second, repeats);
         }
     };
-    repeat_smaller(input_tensor_a, input_tensor_b);
-    repeat_smaller(input_tensor_b, input_tensor_a);
+    repeat_smaller(input_tensor_a, input_tensor_b, rank_a, rank_b);
+    repeat_smaller(input_tensor_b, input_tensor_a, rank_b, rank_a);
 
     return [](const auto &input_tensor_a, const auto &input_tensor_b) {
         if constexpr (detail::is_associative(binary_op_type)) {
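
Note (not part of the patch): the block added above implements rank alignment before broadcasting, i.e. the lower-rank tensor's shape is left-padded with 1s via ttnn::reshape so both operands have the same rank, after which the existing repeat-based broadcast logic can compare dimensions position by position. The following is a minimal standalone C++ sketch of that padding step, independent of the ttnn APIs; the helper name pad_shape_to_rank and the example shapes are hypothetical illustrations.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper: left-pad `shape` with 1s until it has `target_rank` dimensions,
// mirroring what the patch does to the lower-rank tensor with ttnn::reshape.
std::vector<uint32_t> pad_shape_to_rank(const std::vector<uint32_t>& shape, size_t target_rank) {
    std::vector<uint32_t> padded(target_rank, 1);
    // The original dimensions occupy the trailing positions, e.g. {32, 32} -> {1, 1, 32, 32}.
    std::copy(shape.begin(), shape.end(), padded.begin() + (target_rank - shape.size()));
    return padded;
}

int main() {
    std::vector<uint32_t> shape_a{2, 3, 32, 32};  // rank 4
    std::vector<uint32_t> shape_b{32, 32};        // rank 2
    auto padded_b = pad_shape_to_rank(shape_b, std::max(shape_a.size(), shape_b.size()));
    for (auto d : padded_b) std::cout << d << ' ';  // prints: 1 1 32 32
    std::cout << '\n';
    return 0;
}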