#5644: Add backward support for repeat
ruthreshx committed Feb 26, 2024
1 parent 634a8a6 commit 9b67f2a
Showing 5 changed files with 93 additions and 0 deletions.
2 changes: 2 additions & 0 deletions docs/source/ttnn/dependencies/tt_lib.rst
@@ -898,6 +898,8 @@ Backward Operations

.. autofunction:: tt_lib.tensor.hypot_bw

.. autofunction:: tt_lib.tensor.repeat_bw

Loss Functions
==============

34 additions & 0 deletions — new test under tests/tt_eager/python_api_testing/unit_testing/backward_ops/
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_results


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
    ),
)
@pytest.mark.parametrize("sizes", [[12, 1, 1, 1], [6, 1, 1, 1], [1, 24, 1, 1], [1, 3, 1, 1]])
def test_bw_repeat(input_shapes, sizes, device):
    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)

    pyt_y = in_data.repeat(sizes)

    grad_data, grad_tensor = data_gen_pt_tt(pyt_y.shape, device, True)

    tt_output_tensor_on_device = tt_lib.tensor.repeat_bw(grad_tensor, input_tensor, sizes)

    in_data.retain_grad()

    pyt_y.backward(gradient=grad_data)

    golden_tensor = [in_data.grad]
    status = compare_results(tt_output_tensor_on_device, golden_tensor)
    assert status
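
The golden value in this test comes straight from PyTorch autograd: the input is repeated, grad_data is backpropagated through the repeated output, and in_data.grad is compared with the device result. Below is a minimal, self-contained sketch of that reference computation in plain PyTorch (no tt_lib involved; the shape and repeat counts match the first parametrization above):

```python
import torch

# Reference gradient of repeat for one parametrization of the test above:
# input of shape [1, 1, 32, 32], repeated 12 times along dim 0.
in_data = torch.randn(1, 1, 32, 32, requires_grad=True)
sizes = [12, 1, 1, 1]

out = in_data.repeat(sizes)    # shape [12, 1, 32, 32]
grad = torch.randn_like(out)   # incoming gradient for the repeated output

out.backward(gradient=grad)

# Because the repeated dimension of the input has size 1, autograd's gradient
# is the incoming gradient summed over that dimension.
expected = grad.sum(dim=0, keepdim=True)
assert torch.allclose(in_data.grad, expected)
```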
39 changes: 39 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -5,6 +5,8 @@
#include "tt_dnn/op_library/composite/composite_ops.hpp"
#include "tt_dnn/op_library/backward/backward_ops.hpp"
#include "tt_dnn/op_library/reduce/reduce_op.hpp"
#include "tt_dnn/op_library/reshape/reshape_op.hpp"
#include "tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp"
#include "tt_dnn/op_library/embeddings/embeddings_op.hpp"
#include "tt_numpy/functions.hpp"
#include "tt_eager/tensor/tensor_utils.hpp"
@@ -570,6 +572,43 @@ std::vector<Tensor> exp2_bw(const Tensor& grad, const Tensor& input, const Memor
}


// Repeat Backward
std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) {
    std::vector<Tensor> grad_tensor;
    auto shape_wh = input.shape();
    TT_ASSERT(shape_wh[0] == 1 && "input shape[0] should be 1");
    // If any repeat count is 0, the gradient is a zero tensor with the input's shape.
    if (shape[0] == 0 || shape[1] == 0 || shape[2] == 0 || shape[3] == 0) {
        Tensor zero_tensor = zeros_like(input, output_mem_config);
        grad_tensor.emplace_back(zero_tensor);
        return grad_tensor;
    } else if (shape[0] > 1) {
        std::vector<int64_t> dim = {0};
        TT_ASSERT(shape[1] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[1], [2], [3] should be 1");
        Shape required = {1, shape_wh[1], shape_wh[2], shape_wh[3]};
        Tensor result = tt::operations::primary::moreh_sum(grad, zeros(required, input.dtype(), input.layout(), input.device(), output_mem_config), dim, output_mem_config);
        grad_tensor.emplace_back(result);
        return grad_tensor;
    } else if (shape[1] > 1) {
        std::vector<int64_t> dim = {1};
        TT_ASSERT(shape[0] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[0], [2], [3] should be 1");
        Shape required = {shape_wh[0], 1, shape_wh[2], shape_wh[3]};
        Tensor result = tt::operations::primary::moreh_sum(grad, zeros(required, input.dtype(), input.layout(), input.device(), output_mem_config), dim, output_mem_config);
        grad_tensor.emplace_back(result);
        return grad_tensor;
    }
    return grad_tensor;
}
std::vector<Tensor> repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _repeat_bw)(grad, input, shape, output_mem_config);
}


}//namespace tt_metal

}//namespace tt
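
In general, the gradient of repeat is obtained by folding the incoming gradient so that each repeated copy gets its own axis and then summing over those axes. Because _repeat_bw restricts the input's N and C extents to 1 and allows a repeat count greater than 1 along only one of those dimensions, that reduction collapses to a single sum along the repeated dimension, which is exactly what the moreh_sum calls above accumulate into a zero tensor of the input's shape. A hedged PyTorch sketch of the same reduction (illustrative only, not part of tt_lib):

```python
import torch

def repeat_bw_reference(grad: torch.Tensor, input_shape: torch.Size, sizes) -> torch.Tensor:
    # Mirrors the cases handled by _repeat_bw above, under the kernel's constraints:
    # the input's repeated dimension has extent 1, at most one of sizes[0] / sizes[1]
    # exceeds 1, and all other repeat counts are 1.
    if 0 in sizes:
        # A zero repeat count yields an empty output, so the input gradient is zero.
        return torch.zeros(input_shape)
    if sizes[0] > 1:
        return grad.sum(dim=0, keepdim=True)
    if sizes[1] > 1:
        return grad.sum(dim=1, keepdim=True)
    # All repeat counts are 1: repeat is the identity, so the gradient is grad itself.
    return grad
```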
2 changes: 2 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -103,6 +103,8 @@ std::vector<Tensor> exp2_bw(const Tensor& grad, const Tensor& input, const Memor

std::vector<Tensor> expm1_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config);

} //namespace tt_metal

} //namespace tt
16 changes: 16 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -471,6 +471,22 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("repeat_bw", &tt::tt_metal::repeat_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("shape"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns a new tensor filled with repetition of input ``input`` tensor according to number of times specified in ``shape``. The rank of ``shape`` should be same as rank of tensor ``input_a``.
The limitation in our implementation is N and C should be 1 and the repeat is of any number for such dim, other should be 1.
Output tensor will have BFLOAT16 data type.
.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"
"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Input tensor for which repetition is computed", "Tensor", "Tensor of shape [1, Z, Y, X]", "Yes"
"shape", "Shape value", "Shape", "The number of times to repeat this tensor along each dimension", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("unary_sub_bw", &tt::tt_metal::unary_sub_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for subraction of ``input`` tensors with given ``grad``.