#6226: Add backward support for backward div
Aswinmcw committed Mar 11, 2024
1 parent 49a6887 commit 5136302
Showing 5 changed files with 69 additions and 26 deletions.
@@ -16,21 +16,37 @@
        (torch.Size([1, 3, 320, 384])),
    ),
)
def test_bw_div(input_shapes, device):
    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
@pytest.mark.parametrize(
    "round_mode",
    (
        None,
        "trunc",
        "floor",
    ),
)
def test_bw_div(input_shapes, round_mode, device):
    in_data = torch.Tensor(size=input_shapes).uniform_()
    in_data.requires_grad = True
    input_tensor = (
        tt_lib.tensor.Tensor(in_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
    )
    grad_data = torch.Tensor(size=input_shapes).uniform_()
    grad_tensor = (
        tt_lib.tensor.Tensor(grad_data, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)
    )
    other_data, other_tensor = data_gen_pt_tt(input_shapes, device, True)
    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)

    tt_output_tensor_on_device = tt_lib.tensor.div_bw(grad_tensor, input_tensor, other_tensor)
    pyt_y = torch.div(in_data, other_data, rounding_mode=round_mode)

    if round_mode == None:
        round_mode = "None"
    tt_output_tensor_on_device = tt_lib.tensor.div_bw(grad_tensor, input_tensor, other_tensor, round_mode=round_mode)

    in_data.retain_grad()
    other_data.retain_grad()

    pyt_y = torch.div(in_data, other_data)

    pyt_y.backward(gradient=grad_data)

    golden_tensor = [in_data.grad, other_data.grad]

    status = compare_results(tt_output_tensor_on_device, golden_tensor)
    assert status
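The golden values in this test come from PyTorch itself, which returns zero gradients whenever a rounding_mode is set; the zeros_like branch added to the TT backward ops mirrors that. A minimal stand-alone sketch of the reference behavior (an editorial note, not part of the commit; it assumes only torch, and shapes and values are illustrative):

import torch

a = torch.rand(2, 3, requires_grad=True)
b = torch.rand(2, 3) + 0.5            # keep the divisor away from zero
g = torch.ones_like(a)

# True division: d(a/b)/da = 1/b, so here a.grad equals 1/b elementwise.
torch.div(a, b).backward(gradient=g)
print(torch.allclose(a.grad, 1.0 / b))        # True

# trunc/floor division is piecewise constant, so the gradient is zero a.e.
a.grad = None
torch.div(a, b, rounding_mode="floor").backward(gradient=g)
print(int(torch.count_nonzero(a.grad)))       # 0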
@@ -16,16 +16,28 @@
        (torch.Size([1, 3, 320, 384])),
    ),
)
@pytest.mark.parametrize(
    "round_mode",
    (
        "None",
        "trunc",
        "floor",
    ),
)
@pytest.mark.parametrize("scalar", [0.05, 1.0, 0.5, 0.12])
def test_bw_unary_div(input_shapes, scalar, device):
def test_bw_unary_div(input_shapes, scalar, round_mode, device):
    in_data, input_tensor = data_gen_pt_tt(input_shapes, device, True)
    grad_data, grad_tensor = data_gen_pt_tt(input_shapes, device)

    tt_output_tensor_on_device = tt_lib.tensor.unary_div_bw(grad_tensor, input_tensor, scalar=scalar)
    tt_output_tensor_on_device = tt_lib.tensor.unary_div_bw(
        grad_tensor, input_tensor, scalar=scalar, round_mode=round_mode
    )

    in_data.retain_grad()

    pyt_y = torch.div(in_data, torch.tensor(scalar))
    if round_mode == "None":
        round_mode = None
    pyt_y = torch.div(in_data, torch.tensor(scalar), rounding_mode=round_mode)

    pyt_y.backward(gradient=grad_data)

39 changes: 27 additions & 12 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.cpp
@@ -151,30 +151,45 @@ std::vector<Tensor> sqrt_bw(const Tensor& grad, const Tensor& input, const Memor
}


std::vector<Tensor> _unary_div_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config) {
std::vector<Tensor> _unary_div_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config) {
    std::vector<Tensor> grad_tensor;
    float inv_scalar = 1.0f/scalar;
    Tensor result = mul_unary(grad, inv_scalar, output_mem_config);
    grad_tensor.emplace_back(result);
    if (round_mode=="None"){
        Tensor result = mul_unary(grad, inv_scalar, output_mem_config);
        grad_tensor.emplace_back(result);
    }
    else{
        Tensor result = zeros_like(grad, output_mem_config);
        grad_tensor.emplace_back(result);
    }
    return grad_tensor;
}
std::vector<Tensor> unary_div_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config)
std::vector<Tensor> unary_div_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config)
{
    return operation::decorate_as_composite(__func__, _unary_div_bw)(grad, input, scalar, output_mem_config);
    return operation::decorate_as_composite(__func__, _unary_div_bw)(grad, input, scalar, round_mode, output_mem_config);
}
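A reference note on the scalar case (not part of the diff): for true division the chain rule gives

\[
y = \frac{x}{s} \;\Rightarrow\; \frac{\partial L}{\partial x} = g \cdot \frac{1}{s},
\]

where g is the incoming gradient, which is the mul_unary(grad, 1/scalar) path above; trunc and floor rounding make the output piecewise constant, so the gradient is taken as zero, hence the zeros_like branch.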


std::vector<Tensor> _div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
std::vector<Tensor> _div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, string round_mode, const MemoryConfig& output_mem_config) {
    std::vector<Tensor> grad_tensor;
    Tensor grad_a = mul(grad, recip(other, output_mem_config), std::nullopt, output_mem_config);
    grad_tensor.emplace_back(grad_a);
    Tensor grad_b = mul(mul(neg(grad, output_mem_config), input, std::nullopt, output_mem_config), recip(square(other, output_mem_config), output_mem_config), std::nullopt, output_mem_config);
    grad_tensor.emplace_back(grad_b);
    if (round_mode=="None"){
        Tensor grad_a = mul(grad, recip(other, output_mem_config), std::nullopt, output_mem_config);
        grad_tensor.emplace_back(grad_a);
        Tensor grad_b = mul(neg(grad, output_mem_config) , (mul(input, recip(square(other, output_mem_config), output_mem_config), std::nullopt, output_mem_config)), std::nullopt, output_mem_config);
        grad_tensor.emplace_back(grad_b);
    }
    else{
        Tensor grad_a = zeros_like(grad, output_mem_config);
        grad_tensor.emplace_back(grad_a);
        Tensor grad_b = zeros_like(grad, output_mem_config);
        grad_tensor.emplace_back(grad_b);
    }

    return grad_tensor;
}
std::vector<Tensor> div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config)
std::vector<Tensor> div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, string round_mode, const MemoryConfig& output_mem_config)
{
    return operation::decorate_as_composite(__func__, _div_bw)(grad, input, other, output_mem_config);
    return operation::decorate_as_composite(__func__, _div_bw)(grad, input, other, round_mode, output_mem_config);
}
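Likewise for the two-tensor case, the quotient rule gives the gradients the "None" branch computes (a reference note, not part of the commit):

\[
y = \frac{a}{b} \;\Rightarrow\; \frac{\partial L}{\partial a} = \frac{g}{b}, \qquad \frac{\partial L}{\partial b} = -\,g \cdot \frac{a}{b^{2}},
\]

with g the incoming gradient; grad_a and grad_b realize exactly these products via recip and square, and both collapse to zeros_like when round_mode is "trunc" or "floor".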


4 changes: 2 additions & 2 deletions tt_eager/tt_dnn/op_library/backward/backward_ops.hpp
@@ -41,9 +41,9 @@ std::vector<Tensor> unary_assign_bw(const Tensor& grad, const Tensor& input, con

std::vector<Tensor> binary_assign_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> unary_div_bw(const Tensor& grad, const Tensor& input, float scalar, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
std::vector<Tensor> unary_div_bw(const Tensor& grad, const Tensor& input, float scalar, string round_mode, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);
std::vector<Tensor> div_bw(const Tensor& grad, const Tensor& input, const Tensor& other, string round_mode, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

std::vector<Tensor> max_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG);

4 changes: 2 additions & 2 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_backward_ops.cpp
@@ -317,7 +317,7 @@ namespace tt::tt_metal::detail{
)doc");

m_tensor.def("unary_div_bw", &tt::tt_metal::unary_div_bw,
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("scalar") = 1.0f, py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
py::arg("grad").noconvert(), py::arg("input").noconvert(), py::arg("scalar") = 1.0f, py::arg("round_mode") = "None", py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for division with given ``grad`` and ``scalar``.
Input tensors must have BFLOAT16 data type.
@@ -334,7 +334,7 @@ namespace tt::tt_metal::detail{
)doc");

m_tensor.def("div_bw", &tt::tt_metal::div_bw,
py::arg("grad").noconvert(), py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
py::arg("grad").noconvert(), py::arg("input_a").noconvert(), py::arg("input_b").noconvert(), py::arg("round_mode") = "None", py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
Performs backward operations for division of ``input_b`` with given ``grad``.
Input tensor must have BFLOAT16 data type.
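Based on the signatures registered above, a call through the Python bindings could look like the sketch below. The round_mode keyword takes the strings "None", "trunc" or "floor" and defaults to "None"; device setup via tt_lib.device.CreateDevice/CloseDevice and the 32x32 tile-aligned shape are assumptions borrowed from the unit tests, not part of this commit.

import torch
import tt_lib

def to_tt(t, device):
    # Host torch tensor -> tiled BFLOAT16 tensor on device, as in the tests.
    return tt_lib.tensor.Tensor(t, tt_lib.tensor.DataType.BFLOAT16).to(tt_lib.tensor.Layout.TILE).to(device)

device = tt_lib.device.CreateDevice(0)
shape = (1, 1, 32, 32)

grad = to_tt(torch.rand(shape), device)
a = to_tt(torch.rand(shape), device)
b = to_tt(torch.rand(shape) + 0.5, device)

# Two-tensor divide: returns [grad_a, grad_b]; both are zeros for "trunc"/"floor".
grad_a, grad_b = tt_lib.tensor.div_bw(grad, a, b, round_mode="None")

# Scalar divide: returns a single-element list [grad_input].
(grad_in,) = tt_lib.tensor.unary_div_bw(grad, a, scalar=0.5, round_mode="floor")

tt_lib.device.CloseDevice(device)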