#0: Add dim 0 support for repeat backward #5596

Merged · 1 commit · Mar 28, 2024
2 changes: 2 additions & 0 deletions docs/source/ttnn/ttnn/dependencies/tt_lib.rst
@@ -1058,6 +1058,8 @@ Backward Operations

.. autofunction:: tt_lib.tensor.multigammaln_bw

.. autofunction:: tt_lib.tensor.repeat_bw

Loss Functions
==============

@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import tt_lib
from tests.tt_eager.python_api_testing.unit_testing.backward_ops.utility_funcs import data_gen_pt_tt, compare_pcc


@pytest.mark.parametrize(
    "input_shapes",
    (
        (torch.Size([1, 1, 32, 32])),
        (torch.Size([1, 1, 320, 384])),
    ),
)
@pytest.mark.parametrize("sizes", [[12, 1, 1, 1], [6, 1, 1, 1], [1, 24, 1, 1], [1, 3, 1, 1]])
def test_bw_repeat(input_shapes, sizes, device):
    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)

    pyt_y = in_data.repeat(sizes)

    grad_data, grad_tensor = data_gen_pt_tt(pyt_y.shape, device, True)

    tt_output_tensor_on_device = tt_lib.tensor.repeat_bw(grad_tensor, input_tensor, sizes)

    in_data.retain_grad()

    pyt_y.backward(gradient=grad_data)

    golden_tensor = [in_data.grad]
    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
    assert status
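For reference, the golden path of this test relies on the gradient of `repeat` along a single dimension being the upstream gradient summed over that dimension's copies. A minimal PyTorch-only sketch of that identity (standalone; it uses only standard `torch` APIs and none of the device utilities above):

```python
import torch

# Repeat along dim 0: y stacks 3 copies of x along N.
x = torch.randn(1, 1, 32, 32, requires_grad=True)
y = x.repeat([3, 1, 1, 1])  # shape [3, 1, 32, 32]

# Backpropagate an arbitrary upstream gradient through the repeat.
grad_y = torch.randn_like(y)
y.backward(gradient=grad_y)

# The gradient w.r.t. x is the upstream gradient reduced over the repeated copies,
# i.e. a sum over dim 0 back to x's original shape.
expected = grad_y.sum(dim=0, keepdim=True)
assert torch.allclose(x.grad, expected)
```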
40 changes: 40 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -5,6 +5,8 @@
#include "tt_dnn/op_library/composite/composite_ops.hpp"
#include "tt_dnn/op_library/backward/backward_ops.hpp"
#include "tt_dnn/op_library/reduce/reduce_op.hpp"
#include "tt_dnn/op_library/reshape/reshape_op.hpp"
#include "tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp"
#include "tt_dnn/op_library/embeddings/embeddings_op.hpp"
#include "tt_numpy/functions.hpp"
#include "tt_eager/tensor/tensor_utils.hpp"
@@ -1783,6 +1785,44 @@ std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, con
return operation::decorate_as_composite(__func__, _multigammaln_bw)(grad, input, output_mem_config);
}

// Repeat Backward
std::vector<Tensor> _repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) {
    std::vector<Tensor> grad_tensor;
    auto shape_wh = input.get_legacy_shape();
    TT_FATAL(shape_wh[0] == 1 && "input shape[0] should be 1");
    // If the repeat shape contains a 0, the gradient is a zero tensor with the input's shape
    if (shape[0] == 0 || shape[1] == 0 || shape[2] == 0 || shape[3] == 0) {
        Tensor zero_tensor = zeros_like(input, output_mem_config);
        grad_tensor.emplace_back(zero_tensor);
        return grad_tensor;
    } else if (shape[0] > 1) {
        // Repeat along N (dim 0): sum grad over dim 0 back to the input's shape
        std::vector<int64_t> dim = {0};
        TT_FATAL(shape[1] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[1], [2], [3] should be 1");
        Shape required = {1, shape_wh[1], shape_wh[2], shape_wh[3]};
        Tensor result = tt::operations::primary::moreh_sum(
            grad, zeros(required, input.get_dtype(), input.get_layout(), input.device(), output_mem_config), dim, output_mem_config);
        grad_tensor.emplace_back(result);
        return grad_tensor;
    } else if (shape[1] > 1) {
        // Repeat along C (dim 1): sum grad over dim 1 back to the input's shape
        std::vector<int64_t> dim = {1};
        TT_FATAL(shape[0] == 1 && shape[2] == 1 && shape[3] == 1 && "repeat[0], [2], [3] should be 1");
        Shape required = {shape_wh[0], 1, shape_wh[2], shape_wh[3]};
        Tensor result = tt::operations::primary::moreh_sum(
            grad, zeros(required, input.get_dtype(), input.get_layout(), input.device(), output_mem_config), dim, output_mem_config);
        grad_tensor.emplace_back(result);
        return grad_tensor;
    }
    return grad_tensor;
}
std::vector<Tensor> repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config) {
    return operation::decorate_as_composite(__func__, _repeat_bw)(grad, input, shape, output_mem_config);
}


}//namespace tt_metal

}//namespace tt
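The control flow of `_repeat_bw` above is: return a zero tensor if any repeat count is 0, otherwise reduce `grad` with `moreh_sum` along the single repeated dimension (0 or 1) into a zero-initialized tensor of the input's shape. A rough host-side reference of that logic in PyTorch (a sketch for illustration only; `repeat_bw_reference` is a hypothetical name, not part of tt_lib):

```python
import torch

def repeat_bw_reference(grad: torch.Tensor, input: torch.Tensor, shape: list) -> torch.Tensor:
    """Reference gradient of repeat, restricted to repeats along dim 0 or dim 1."""
    assert input.shape[0] == 1, "input shape[0] should be 1"
    # Any repeat count of 0 yields an all-zero gradient with the input's shape.
    if 0 in shape:
        return torch.zeros_like(input)
    if shape[0] > 1:
        assert shape[1] == shape[2] == shape[3] == 1, "repeat[1], [2], [3] should be 1"
        return grad.sum(dim=0, keepdim=True)  # reduce the N copies
    if shape[1] > 1:
        assert shape[0] == shape[2] == shape[3] == 1, "repeat[0], [2], [3] should be 1"
        return grad.sum(dim=1, keepdim=True)  # reduce the C copies
    # All repeat counts are 1: a pure-PyTorch reference passes grad through unchanged
    # (the device op above returns an empty result vector in this case).
    return grad
```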
3 changes: 3 additions & 0 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -255,6 +255,9 @@ std::vector<Tensor> imag_bw(const Tensor& grad, const Tensor& input, const Memor
std::vector<Tensor> real_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> multigammaln_bw(const Tensor& grad, const Tensor& input, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> repeat_bw(const Tensor& grad, const Tensor& input, const Shape& shape, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

} //namespace tt_metal

} //namespace tt
16 changes: 16 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -504,6 +504,22 @@ namespace tt::tt_metal::detail{
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("repeat_bw", &tt::tt_metal::repeat_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("shape"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Returns a new tensor filled with repetition of input ``input`` tensor according to number of times specified in ``shape``. The rank of ``shape`` should be same as rank of tensor ``input_a``.
The limitation in our implementation is N and C should be 1 and the repeat is of any number for such dim, other should be 1.

Output tensor will have BFLOAT16 data type.

.. csv-table::
:header: "Argument", "Description", "Data type", "Valid range", "Required"

"grad", "Gradient tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
"input", "Input tensor for which repetition is computed", "Tensor", "Tensor of shape [1, Z, Y, X]", "Yes"
"shape", "Shape value", "Shape", "The number of times to repeat this tensor along each dimension", "Yes"
"output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
)doc");

m_tensor.def("unary_sub_bw", &tt::tt_metal::unary_sub_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for subtraction of ``input`` tensors with given ``grad``.