
Commit

#0: perf check
KalaivaniMCW authored and Aswinmcw committed Nov 18, 2024
1 parent 55912dc commit e54a7c5
Showing 1 changed file with 32 additions and 6 deletions.
ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp (38 changes: 32 additions & 6 deletions)
@@ -108,28 +108,54 @@ auto preprocess_inputs(const Tensor &input_tensor_a_arg, const Tensor &input_ten
     Tensor input_tensor_a = input_tensor_a_arg;
     Tensor input_tensor_b = input_tensor_b_arg;
 
+    auto shape_a = input_tensor_a.get_shape();
+    auto shape_b = input_tensor_b.get_shape();
+    auto rank_a = shape_a.rank();
+    auto rank_b = shape_b.rank();
+
+    if(rank_a != rank_b){
+
+        auto max_rank = std::max(rank_a, rank_b);
+        auto min_rank = std::min(rank_a, rank_b);
+
+        // if(optional_output_tensor.has_value()) {
+        //     auto opt_rank = optional_output_tensor.value().get_shape().rank();
+        //     TT_FATAL( max_rank == opt_rank,
+        //         "Output Tensor rank {} doesn't match input tensor rank {}.", opt_rank, max_rank );
+        // }
+
+        std::vector<int32_t> shape_vector(max_rank, 1);
+        auto& reshaped_tensor = (rank_a > rank_b) ? input_tensor_b : input_tensor_a;
+        auto s_b = reshaped_tensor.get_shape();
+        for(int i=0; i < min_rank; ++i){
+            shape_vector[(max_rank - min_rank) + i] = s_b[i];
+        }
+        reshaped_tensor = ttnn::reshape(reshaped_tensor, shape_vector);
+
+    }
+
     // TODO: #7731 (Remove calls to repeat )
-    auto repeat_smaller = [](const auto &first, auto &second) {
+    auto repeat_smaller = [](const auto &first, auto &second, auto first_rank, auto second_rank) {
         const auto first_shape = first.get_shape();
         const auto second_shape = second.get_shape();
         // repeats second if it is smaller
-        if (first_shape.rank() == 4 and second_shape.rank() == 4 and first_shape[0] > second_shape[0]) {
+        if (first_rank == 4 and second_rank == 4 and first_shape[0] > second_shape[0]) {
             TT_FATAL(second_shape[0] == 1, "Dimension trying to broadcast is not equal to 1");
             Shape repeats(std::array<uint32_t, 4>{first_shape[0], 1, 1, 1});
             second = ttnn::repeat(second, repeats);
         }
         // repeats second if it is smaller
-        if (first_shape.rank() >= 3 and second_shape.rank() >= 3 and first_shape[-3] > second_shape[-3]) {
+        if (first_rank >= 3 and second_rank >= 3 and first_shape[-3] > second_shape[-3]) {
             TT_FATAL(second_shape[-3] == 1, "Dimension trying to broadcast is not equal to 1");
-            int rank_a = first_shape.rank();
+            int rank_a = first_rank;
             std::vector<uint32_t> repeat_dim(rank_a, 1);
             repeat_dim[rank_a - 3] = first_shape[rank_a - 3];
             Shape repeats(repeat_dim);
             second = ttnn::repeat(second, repeats);
         }
     };
-    repeat_smaller(input_tensor_a, input_tensor_b);
-    repeat_smaller(input_tensor_b, input_tensor_a);
+    repeat_smaller(input_tensor_a, input_tensor_b, rank_a, rank_b);
+    repeat_smaller(input_tensor_b, input_tensor_a, rank_b, rank_a);
 
     return [](const auto &input_tensor_a, const auto &input_tensor_b) {
         if constexpr (detail::is_associative(binary_op_type)) {
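
Note (editorial): the block added at the top of this hunk aligns the ranks of the two inputs before the existing repeat-based broadcasting runs, by reshaping the lower-rank tensor to a shape padded with leading 1s. Below is a minimal standalone sketch of that padding step under stated assumptions: it uses plain std::vector<int32_t> in place of ttnn shapes, and the helper name pad_shape_to_rank is hypothetical, not part of ttnn.

// Sketch only: mirrors the loop `shape_vector[(max_rank - min_rank) + i] = s_b[i]`
// from the diff above, using a hypothetical helper and plain std::vector shapes.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int32_t> pad_shape_to_rank(const std::vector<int32_t>& shape, size_t max_rank) {
    const size_t min_rank = shape.size();
    std::vector<int32_t> padded(max_rank, 1);                 // start from an all-ones shape
    for (size_t i = 0; i < min_rank; ++i) {
        padded[(max_rank - min_rank) + i] = shape[i];         // copy the dims right-aligned
    }
    return padded;
}

int main() {
    const std::vector<int32_t> shape_a{2, 3, 32, 32};         // rank 4
    const std::vector<int32_t> shape_b{32, 32};               // rank 2
    const size_t max_rank = std::max(shape_a.size(), shape_b.size());

    const auto padded_b = pad_shape_to_rank(shape_b, max_rank);   // {1, 1, 32, 32}
    for (int32_t d : padded_b) { std::cout << d << ' '; }
    std::cout << '\n';
    return 0;
}

Once the ranks match, the subsequent repeat_smaller calls only need to expand size-1 dimensions, which is what the new first_rank/second_rank parameters feed into.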
