Skip to content

Commit

Permalink
#9874: Overload binary eq in unary eq
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN authored and Aswinmcw committed Jul 10, 2024
1 parent ed73256 commit 10ec3f8
Show file tree
Hide file tree
Showing 11 changed files with 50 additions and 61 deletions.
3 changes: 1 addition & 2 deletions docs/source/ttnn/ttnn/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ Pointwise Unary
ttnn/assign_bw
ttnn/multigammaln_bw
ttnn/add_bw
ttnn/unary_eq_bw
ttnn/eq_bw

Pointwise Binary
================
Expand Down Expand Up @@ -220,7 +220,6 @@ Pointwise Binary
ttnn/logaddexp_bw
ttnn/logaddexp2_bw
ttnn/squared_difference_bw
ttnn/binary_eq_bw
ttnn/concat_bw
ttnn/binary_le_bw
ttnn/rsub_bw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_bw_binary_eq(input_shapes, device):
other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True)
_, grad_tensor = data_gen_with_range(input_shapes, -20, 40, device)

tt_output_tensor_on_device = ttnn.binary_eq_bw(grad_tensor, input_tensor, other_tensor)
tt_output_tensor_on_device = ttnn.eq_bw(grad_tensor, input_tensor, other_tensor)
in_grad = torch.zeros_like(in_data)
other_grad = torch.zeros_like(other_data)

Expand Down Expand Up @@ -51,7 +51,7 @@ def test_bw_binary_eq_opt_output(input_shapes, device, are_required_outputs):
if are_required_outputs[1]:
_, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

tt_output_tensor_on_device = ttnn.binary_eq_bw(
tt_output_tensor_on_device = ttnn.eq_bw(
grad_tensor,
input_tensor,
other_tensor,
Expand Down Expand Up @@ -94,7 +94,7 @@ def test_bw_binary_eq_opt_output_qid(input_shapes, device, are_required_outputs)

cq_id = 0

tt_output_tensor_on_device = ttnn.binary_eq_bw(
tt_output_tensor_on_device = ttnn.eq_bw(
grad_tensor,
input_tensor,
other_tensor,
Expand All @@ -114,3 +114,23 @@ def test_bw_binary_eq_opt_output_qid(input_shapes, device, are_required_outputs)
if are_required_outputs[i]:
status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize("other", [1.0])
def test_bw_unary_eq(input_shapes, other, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)

tt_output_tensor_on_device = ttnn.eq_bw(grad_tensor, input_tensor, other)
pt_y = torch.zeros_like(grad_data)
golden_tensor = [pt_y]
comp_pass = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert comp_pass

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,4 @@ constexpr auto lerp_bw = ttnn::register_operation<operations::binary_backward::E
//type 2
constexpr auto addalpha_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::ADDALPHA_BW>>("ttnn::addalpha_bw");

//type 3
constexpr auto binary_eq_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::BINARY_EQ_BW>>("ttnn::binary_eq_bw");
} // namespace ttnn
Original file line number Diff line number Diff line change
Expand Up @@ -273,12 +273,6 @@ void py_module(py::module& module) {
ttnn::squared_difference_bw,
R"doc(Performs backward operations for squared_difference of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");

detail::bind_binary_backward(
module,
ttnn::binary_eq_bw,
R"doc(Performs backward operations for equal to comparison on :attr:`input_tensor_a` , attr:`input_tensor_b` with given attr:`grad_tensor`.
Returns an tensor of zeros like :attr:`input_tensor_a` and :attr:`input_tensor_b` tensor.)doc");

detail::bind_binary_backward(
module,
ttnn::concat_bw,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ std::vector<ttnn::Tensor> _squared_difference_bw(
return grad_tensor;
}

std::vector<std::optional<Tensor>> _binary_eq_bw(
std::vector<std::optional<Tensor>> _eq_bw(
uint8_t cq_id,
const Tensor& grad,
const Tensor& input,
Expand Down Expand Up @@ -326,9 +326,9 @@ std::vector<std::optional<Tensor>> _binary_eq_bw(
return std::move(result);
}

std::vector<ttnn::Tensor> _binary_eq_bw_inter(
std::vector<ttnn::Tensor> _eq_bw_inter(
const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) {
auto result = _binary_eq_bw(0, grad, input, other, output_mem_config, {true, true}, std::nullopt, std::nullopt);
auto result = _eq_bw(0, grad, input, other, output_mem_config, {true, true}, std::nullopt, std::nullopt);
std::vector<ttnn::Tensor> output_tensors;
output_tensors.reserve(result.size());

Expand All @@ -342,7 +342,7 @@ std::vector<ttnn::Tensor> _binary_eq_bw_inter(
return output_tensors;
}

std::vector<std::optional<Tensor>> _binary_eq_bw_overload(
std::vector<std::optional<Tensor>> _eq_bw_overload(
const Tensor& grad,
const Tensor& input,
const Tensor& other,
Expand All @@ -351,7 +351,7 @@ std::vector<std::optional<Tensor>> _binary_eq_bw_overload(
std::optional<Tensor> input_grad,
std::optional<Tensor> other_grad) {
uint8_t default_queue_id = 0;
return _binary_eq_bw(default_queue_id, grad, input, other, output_mem_config, are_required_outputs, input_grad, other_grad);
return _eq_bw(default_queue_id, grad, input, other, output_mem_config, are_required_outputs, input_grad, other_grad);
}

std::vector<Tensor> _assign_bw(
Expand Down Expand Up @@ -636,8 +636,8 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tens
return _squared_difference_bw;
case BinaryBackwardOpType::ADD_BW:
return _add_bw_inter;
case BinaryBackwardOpType::BINARY_EQ_BW:
return _binary_eq_bw_inter;
case BinaryBackwardOpType::EQ_BW:
return _eq_bw_inter;
case BinaryBackwardOpType::ASSIGN_BW:
return _assign_bw;
case BinaryBackwardOpType::BINARY_LE_BW:
Expand Down Expand Up @@ -716,8 +716,8 @@ std::function<std::vector<std::optional<ttnn::Tensor>>(uint8_t , const Tensor&,
switch (OpType) {
case BinaryBackwardOpType::ADD_BW:
return _add_bw;
case BinaryBackwardOpType::BINARY_EQ_BW:
return _binary_eq_bw;
case BinaryBackwardOpType::EQ_BW:
return _eq_bw;
case BinaryBackwardOpType::MUL_BW:
return _mul_bw;
default:
Expand All @@ -730,8 +730,8 @@ std::function<std::vector<std::optional<ttnn::Tensor>>(const Tensor&, const Tens
switch (OpType) {
case BinaryBackwardOpType::ADD_BW:
return _add_bw_overload;
case BinaryBackwardOpType::BINARY_EQ_BW:
return _binary_eq_bw_overload;
case BinaryBackwardOpType::EQ_BW:
return _eq_bw_overload;
case BinaryBackwardOpType::MUL_BW:
return _mul_bw_overload;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ enum class BinaryBackwardOpType {
LOGADDEXP2_BW,
SQUARED_DIFFERENCE_BW,
ADD_BW,
BINARY_EQ_BW,
EQ_BW,
ASSIGN_BW,
CONCAT_BW,
BINARY_LE_BW,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ std::vector<Tensor> _unary_comp_bw(const Tensor& grad, const Tensor& input, floa
return grad_tensor;
}

std::vector<Tensor> _unary_eq_bw(
std::vector<Tensor> _eq_bw(
const Tensor& grad, const Tensor& input, float other, const MemoryConfig& output_mem_config) {
return _unary_comp_bw(grad, input, other, output_mem_config);
}
Expand All @@ -113,8 +113,8 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, float, con
return _clamp_min_bw;
case UnaryBackwardOpType::ADD_BW:
return _add_bw;
case UnaryBackwardOpType::UNARY_EQ_BW:
return _unary_eq_bw;
case UnaryBackwardOpType::EQ_BW:
return _eq_bw;
default:
TT_ASSERT(false && "Undefined op type");
return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ enum class UnaryBackwardOpType {
ASSIGN_BW,
MULTIGAMMALN_BW,
ADD_BW,
UNARY_EQ_BW,
EQ_BW,
};


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,6 @@ constexpr auto clamp_bw = ttnn::register_operation<operations::unary_backward::E
constexpr auto assign_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ASSIGN_BW>>("ttnn::assign_bw");
constexpr auto multigammaln_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::MULTIGAMMALN_BW>>("ttnn::multigammaln_bw");
constexpr auto add_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::ADD_BW>>("ttnn::add_bw");
constexpr auto unary_eq_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::UNARY_EQ_BW>>("ttnn::unary_eq_bw");
constexpr auto eq_bw = ttnn::register_operation<operations::unary_backward::ExecuteUnaryBackward<operations::unary_backward::UnaryBackwardOpType::EQ_BW>>("ttnn::eq_bw");

} // namespace ttnn
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ Keyword args:
}else if(operation.base_name()=="add_bw"){
using BinaryBackwardOp = ttnn::operations::binary_backward::ExecuteBinaryBackward<binary_backward::BinaryBackwardOpType::ADD_BW>;
return BinaryBackwardOp::execute_on_worker_thread(grad_tensor, input_tensor_a, output_memory_config, input_tensor_b);
}else if(operation.base_name()=="eq_bw"){
using BinaryBackwardOp = ttnn::operations::binary_backward::ExecuteBinaryBackward<binary_backward::BinaryBackwardOpType::EQ_BW>;
return BinaryBackwardOp::execute_on_worker_thread(grad_tensor, input_tensor_a, output_memory_config, input_tensor_b);
}
return BinaryBackwardOp::execute_on_worker_thread(grad_tensor, input_tensor_a, output_memory_config, input_tensor_b);

Expand All @@ -87,6 +90,9 @@ Keyword args:
if(operation.base_name()=="add_bw"){
using BinaryBackwardOp = ttnn::operations::binary_backward::ExecuteBinaryBackward<binary_backward::BinaryBackwardOpType::ADD_BW>;
return BinaryBackwardOp::execute_on_main_thread(queue_id, grad_tensor, input_tensor_a, input_tensor_b, memory_config, are_required_outputs, input_a_grad, input_b_grad);
}else if(operation.base_name()=="eq_bw"){
using BinaryBackwardOp = ttnn::operations::binary_backward::ExecuteBinaryBackward<binary_backward::BinaryBackwardOpType::EQ_BW>;
return BinaryBackwardOp::execute_on_main_thread(queue_id, grad_tensor, input_tensor_a, input_tensor_b, memory_config, are_required_outputs, input_a_grad, input_b_grad);
}
return BinaryBackwardOp::execute_on_main_thread(queue_id, grad_tensor, input_tensor_a, input_tensor_b, memory_config, are_required_outputs, input_a_grad, input_b_grad);
},
Expand Down Expand Up @@ -185,8 +191,9 @@ void py_module(py::module& module) {

detail::bind_unary_backward(
module,
ttnn::unary_eq_bw,
R"doc(Performs backward operations for equal to comparison on :attr:`input_tensor`, :attr:`alpha` with given :attr:`grad_tensor`.)doc");
ttnn::eq_bw,
R"doc(Performs backward operations for equal to comparison on :attr:`input_tensor`, :attr:`alpha` or attr:`input_tensor_b` with given :attr:`grad_tensor`.
Returns an tensor of zeros like input tensors.)doc");

}

Expand Down

0 comments on commit 10ec3f8

Please sign in to comment.