#12164: Add queue_id and optional output tensors to rsub_bw, subalpha_bw
mouliraj-mcw committed Sep 12, 2024
1 parent c672ab6 commit 07ca39c
Showing 6 changed files with 321 additions and 38 deletions.
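In brief, this commit lets callers pass a command queue id and preallocated gradient tensors to ttnn.rsub_bw and ttnn.subalpha_bw. A minimal sketch of the new call pattern, mirroring the tests added below; the shapes and values are illustrative, and `device` is assumed to be an already-open ttnn device (e.g. from ttnn.open_device):

    import torch
    import ttnn

    # Illustrative inputs; not taken from this commit.
    shape = (1, 1, 32, 32)
    grad_tensor = ttnn.from_torch(torch.randn(shape, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    input_tensor = ttnn.from_torch(torch.randn(shape, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    other_tensor = ttnn.from_torch(torch.randn(shape, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    input_grad = ttnn.from_torch(torch.zeros(shape, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)
    other_grad = ttnn.from_torch(torch.zeros(shape, dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT, device=device)

    # New in this commit: write results into preallocated tensors on a chosen command queue.
    ttnn.rsub_bw(grad_tensor, input_tensor, other_tensor,
                 input_grad=input_grad, other_grad=other_grad, queue_id=0)
    ttnn.subalpha_bw(grad_tensor, input_tensor, other_tensor,
                     input_grad=input_grad, queue_id=0)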
43 changes: 43 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py
@@ -29,3 +29,46 @@ def test_bw_rsub(input_shapes, device):

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True], [False, False]])
def test_bw_rsub_opt(input_shapes, device, are_required_outputs):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
other_data, other_tensor = data_gen_with_range(input_shapes, -5, 5, device, True)

grad_data, grad_tensor = data_gen_with_range(input_shapes, -10, 5, device)

input_grad = None
other_grad = None
tt_output_tensor_on_device = None

if are_required_outputs[0]:
_, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
if are_required_outputs[1]:
_, other_grad = data_gen_with_range(input_shapes, -1, 1, device)

cq_id = 0

if are_required_outputs[0] and are_required_outputs[1]:
pages_before = ttnn._ttnn.reports.get_buffer_pages()
ttnn.rsub_bw(
grad_tensor, input_tensor, other_tensor, input_grad=input_grad, other_grad=other_grad, queue_id=cq_id
)
assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
tt_output_tensor_on_device = [input_grad, other_grad]
else:
tt_output_tensor_on_device = ttnn.rsub_bw(grad_tensor, input_tensor, other_tensor, queue_id=cq_id)

golden_function = ttnn.get_golden_function(ttnn.rsub_bw)
golden_tensor = golden_function(grad_data, in_data, other_data)

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status
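Not exercised by the test above, but the binding registered in this commit also exposes are_required_outputs, so a caller can request a single gradient. A hedged sketch, with argument names taken from the pybind defaults later in this diff:

    # Compute only the gradient w.r.t. input_tensor_a, writing into a
    # preallocated buffer; the slot for the other gradient is left unset.
    grads = ttnn.rsub_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        are_required_outputs=[True, False],
        input_grad=input_grad,
        queue_id=0,
    )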
@@ -51,3 +51,32 @@ def test_bw_subalpha_default(input_shapes, device):

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status


@pytest.mark.parametrize(
"input_shapes",
(
(torch.Size([1, 1, 32, 32])),
(torch.Size([1, 1, 320, 384])),
(torch.Size([1, 3, 320, 384])),
),
)
def test_bw_subalpha_opt_output(input_shapes, device):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)

_, input_grad = data_gen_with_range(input_shapes, -1, 1, device)

cq_id = 0
pages_before = ttnn._ttnn.reports.get_buffer_pages()
ttnn.subalpha_bw(grad_tensor, input_tensor, other_tensor, input_grad=input_grad, queue_id=cq_id)
assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())

tt_output_tensor_on_device = [input_grad]

golden_function = ttnn.get_golden_function(ttnn.subalpha_bw)
golden_tensor = golden_function(grad_data, in_data, other_data)

status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
assert status
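The test keeps alpha at its default of 1.0. A hedged sketch of a call that sets it explicitly together with both preallocated gradients; the alpha value is illustrative:

    grads = ttnn.subalpha_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        alpha=0.5,
        input_grad=input_grad,
        other_grad=other_grad,
        queue_id=0,
    )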
@@ -264,10 +264,49 @@ struct ExecuteAddalphaBW {
std::optional<Tensor> input_b_grad = std::nullopt);
};

struct ExecuteBackwardSubAlpha {
static std::vector<std::optional<ttnn::Tensor>> invoke(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
float alpha,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
std::optional<Tensor> input_grad = std::nullopt,
std::optional<Tensor> other_grad = std::nullopt);

static std::vector<std::optional<ttnn::Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
float alpha,
const std::optional<MemoryConfig> &memory_config = std::nullopt);

};

struct ExecuteBackwardRsub {
static std::vector<std::optional<ttnn::Tensor>> invoke(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
std::optional<Tensor> input_grad = std::nullopt,
std::optional<Tensor> other_grad = std::nullopt);

static std::vector<std::optional<ttnn::Tensor>> invoke(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt);

};

} // operations::binary

constexpr auto atan2_bw = ttnn::register_operation<"ttnn::atan2_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::ATAN2_BW>>();
constexpr auto rsub_bw = ttnn::register_operation<"ttnn::rsub_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::RSUB_BW>>();
constexpr auto xlogy_bw = ttnn::register_operation<"ttnn::xlogy_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::XLOGY_BW>>();
constexpr auto hypot_bw = ttnn::register_operation<"ttnn::hypot_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::HYPOT_BW>>();
constexpr auto ldexp_bw = ttnn::register_operation<"ttnn::ldexp_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::LDEXP_BW>>();
@@ -278,7 +317,13 @@ constexpr auto min_bw = ttnn::register_operation<"ttnn::min_bw", operations::bin
constexpr auto max_bw = ttnn::register_operation<"ttnn::max_bw", operations::binary_backward::ExecuteBinaryBackwardTensor<operations::binary_backward::BinaryBackwardOpType::MAX_BW>>();


constexpr auto subalpha_bw = ttnn::register_operation<"ttnn::subalpha_bw", operations::binary_backward::ExecuteBinaryBackwardFloatDefault<operations::binary_backward::BinaryBackwardOpType::SUBALPHA_BW>>();
constexpr auto subalpha_bw = ttnn::register_operation<
"ttnn::subalpha_bw",
operations::binary_backward::ExecuteBackwardSubAlpha>();

constexpr auto rsub_bw = ttnn::register_operation<
"ttnn::rsub_bw",
operations::binary_backward::ExecuteBackwardRsub>();

constexpr auto concat_bw = ttnn::register_operation<"ttnn::concat_bw", operations::binary_backward::ExecuteBinaryBackwardIntDefault<operations::binary_backward::BinaryBackwardOpType::CONCAT_BW>>();

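The structs above declare two invoke overloads, with and without queue_id and preallocated outputs. From Python they surface as a single call whose keyword arguments are all optional; a hedged sketch of both call styles, assuming the tensors are already on device:

    # Plain call: the op allocates and returns both gradient tensors.
    input_grad, other_grad = ttnn.rsub_bw(grad_tensor, input_tensor, other_tensor)

    # Extended call: explicit command queue and preallocated outputs.
    ttnn.rsub_bw(
        grad_tensor,
        input_tensor,
        other_tensor,
        input_grad=input_grad,
        other_grad=other_grad,
        queue_id=0,
    )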
@@ -350,6 +350,141 @@ void bind_binary_backward_float_default(py::module& module, const binary_backwar
);
}

template <typename binary_backward_operation_t>
void bind_binary_backward_sub_alpha(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, std::string_view description, std::string_view supported_dtype) {
auto doc = fmt::format(
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, {2}: float, *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
{5}
Args:
* :attr:`grad_tensor`
* :attr:`input_tensor_a`
* :attr:`input_tensor_b`
* :attr:`{2}` (float): {3}. Default value = {4}
Keyword args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensors
* :attr:`are_required_outputs` (Optional[List[bool]]): which gradients to compute, default [True, True]
* :attr:`input_grad` (Optional[ttnn.Tensor]): preallocated output tensor for the gradient of :attr:`input_tensor_a`
* :attr:`other_grad` (Optional[ttnn.Tensor]): preallocated output tensor for the gradient of :attr:`input_tensor_b`
* :attr:`queue_id` (Optional[uint8]): command queue id, default 0
Supported dtypes, layouts, and ranks:
{6}
Note: bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
>>> output = {1}(grad_tensor, tensor1, tensor2, {2} = {4})
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
parameter_name,
parameter_doc,
parameter_value,
description,
supported_dtype);


bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const binary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor,
const ttnn::Tensor& other_tensor,
float alpha,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::vector<bool>& are_required_outputs,
const std::optional<ttnn::Tensor>& input_grad,
const std::optional<ttnn::Tensor>& other_grad,
const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
return self(queue_id, grad_tensor, input_tensor, other_tensor, alpha, memory_config, are_required_outputs, input_grad, other_grad);
},
py::arg("grad_tensor"),
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::arg(parameter_name.c_str()) = parameter_value,
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("are_required_outputs") = std::vector<bool>{true, true},
py::arg("input_grad") = std::nullopt,
py::arg("other_grad") = std::nullopt,
py::arg("queue_id") = 0}
);
}

template <typename binary_backward_operation_t>
void bind_binary_backward_rsub(py::module& module, const binary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {

auto doc = fmt::format(
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
{2}
Args:
* :attr:`grad_tensor`
* :attr:`input_tensor_a`
* :attr:`input_tensor_b`
Keyword args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensors
* :attr:`are_required_outputs` (Optional[List[bool]]): which gradients to compute, default [True, True]
* :attr:`input_grad` (Optional[ttnn.Tensor]): preallocated output tensor for the gradient of :attr:`input_tensor_a`
* :attr:`other_grad` (Optional[ttnn.Tensor]): preallocated output tensor for the gradient of :attr:`input_tensor_b`
* :attr:`queue_id` (Optional[uint8]): command queue id, default 0
Supported dtypes, layouts, and ranks:
{3}
Note: bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
>>> output = {1}(grad_tensor, tensor1, tensor2)
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
description,
supported_dtype);


bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const binary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor,
const ttnn::Tensor& other_tensor,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::vector<bool>& are_required_outputs,
const std::optional<ttnn::Tensor>& input_grad,
const std::optional<ttnn::Tensor>& other_grad,
const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
return self(queue_id, grad_tensor, input_tensor, other_tensor, memory_config, are_required_outputs, input_grad, other_grad);
},
py::arg("grad_tensor"),
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("are_required_outputs") = std::vector<bool>{true, true},
py::arg("input_grad") = std::nullopt,
py::arg("other_grad") = std::nullopt,
py::arg("queue_id") = 0}
);
}

template <typename binary_backward_operation_t>
void bind_binary_bw_mul(py::module& module, const binary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
auto doc = fmt::format(
@@ -846,7 +981,7 @@ void py_module(py::module& module) {
| BFLOAT16, BFLOAT8_B | ROW_MAJOR, TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+)doc");

detail::bind_binary_backward_float_default(
detail::bind_binary_backward_sub_alpha(
module,
ttnn::subalpha_bw,
"alpha", "Alpha value", 1.0f,
@@ -948,7 +1083,7 @@ void py_module(py::module& module) {
| BFLOAT16, BFLOAT8_B | ROW_MAJOR, TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+)doc");

detail::bind_binary_backward_ops(
detail::bind_binary_backward_rsub(
module,
ttnn::rsub_bw,
R"doc(Performs backward operations for subraction of :attr:`input_tensor_a` from :attr:`input_tensor_b` with given attr:`grad_tensor` (reversed order of subtraction operator).)doc",
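For reference, the gradients these ops are expected to match follow from standard calculus and are not part of this diff: rsub computes b - a, so grad_a = -grad and grad_b = grad; subalpha computes a - alpha * b, so grad_a = grad and grad_b = -alpha * grad. A hedged torch sketch of those formulas (the tests retrieve the actual golden functions via ttnn.get_golden_function):

    import torch

    def golden_rsub_bw(grad: torch.Tensor, a: torch.Tensor, b: torch.Tensor):
        # out = b - a  =>  d(out)/da = -1, d(out)/db = +1
        return [-grad, grad]

    def golden_subalpha_bw(grad: torch.Tensor, a: torch.Tensor, b: torch.Tensor, alpha: float = 1.0):
        # out = a - alpha * b  =>  d(out)/da = +1, d(out)/db = -alpha
        return [grad, -alpha * grad]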