#12164: Add queue_id and optional output tensors to sub_bw
mouliraj-mcw committed Sep 12, 2024
1 parent 07ca39c · commit f67b09e
Showing 6 changed files with 164 additions and 30 deletions.
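In short: ttnn.sub_bw now accepts a queue_id keyword and optional preallocated gradient tensors, and delegates to subalpha_bw under the hood. A minimal sketch of the new call shape — `device` and the `data_gen_with_range` helper come from the test harness in the diffs below, and shapes/ranges are illustrative:

import torch
import ttnn

# Sketch based on the new test_bw_sub_opt test; assumes an open ttnn `device`
# and the test helper `data_gen_with_range`.
shape = torch.Size([1, 1, 32, 32])
_, input_tensor = data_gen_with_range(shape, -100, 100, device, True)
_, other_tensor = data_gen_with_range(shape, -100, 100, device, True)
_, grad_tensor = data_gen_with_range(shape, -100, 100, device)
_, input_grad = data_gen_with_range(shape, -1, 1, device)  # preallocated output
_, other_grad = data_gen_with_range(shape, -1, 1, device)  # preallocated output

# Gradients are written in place into the preallocated tensors on queue 0.
ttnn.sub_bw(grad_tensor, input_tensor, other_tensor,
            input_grad=input_grad, other_grad=other_grad, queue_id=0)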
20 changes: 9 additions & 11 deletions tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py
@@ -39,7 +39,7 @@ def test_bw_rsub(input_shapes, device):
(torch.Size([1, 3, 320, 384])),
),
)
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True], [False, False]])
@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
def test_bw_rsub_opt(input_shapes, device, are_required_outputs):
in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
other_data, other_tensor = data_gen_with_range(input_shapes, -5, 5, device, True)
@@ -57,18 +57,16 @@ def test_bw_rsub_opt(input_shapes, device, are_required_outputs):

    cq_id = 0

-    if are_required_outputs[0] and are_required_outputs[1]:
-        pages_before = ttnn._ttnn.reports.get_buffer_pages()
-        ttnn.rsub_bw(
-            grad_tensor, input_tensor, other_tensor, input_grad=input_grad, other_grad=other_grad, queue_id=cq_id
-        )
-        assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
-        tt_output_tensor_on_device = [input_grad, other_grad]
-    else:
-        tt_output_tensor_on_device = ttnn.rsub_bw(grad_tensor, input_tensor, other_tensor, queue_id=cq_id)
+    pages_before = ttnn._ttnn.reports.get_buffer_pages()
+    ttnn.rsub_bw(grad_tensor, input_tensor, other_tensor, input_grad=input_grad, other_grad=other_grad, queue_id=cq_id)
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+    tt_output_tensor_on_device = [input_grad, other_grad]

    golden_function = ttnn.get_golden_function(ttnn.rsub_bw)
    golden_tensor = golden_function(grad_data, in_data, other_data)

-    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
+    status = True
+    for i in range(len(are_required_outputs)):
+        if are_required_outputs[i]:
+            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
    assert status
40 changes: 40 additions & 0 deletions tests/ttnn/unit_tests/operations/backward/test_backward_sub.py
@@ -50,3 +50,43 @@ def test_bw_unary_sub(input_shapes, scalar, device):

    status = compare_pcc(tt_output_tensor_on_device, golden_tensor)
    assert status


+@pytest.mark.parametrize(
+    "input_shapes",
+    (
+        (torch.Size([1, 1, 32, 32])),
+        (torch.Size([1, 1, 320, 384])),
+        (torch.Size([1, 3, 320, 384])),
+    ),
+)
+@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]])
+def test_bw_sub_opt(input_shapes, device, are_required_outputs):
+    in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True)
+    grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device)
+
+    input_grad = None
+    other_grad = None
+    tt_output_tensor_on_device = None
+
+    if are_required_outputs[0]:
+        _, input_grad = data_gen_with_range(input_shapes, -1, 1, device)
+    if are_required_outputs[1]:
+        _, other_grad = data_gen_with_range(input_shapes, -1, 1, device)
+
+    cq_id = 0
+
+    pages_before = ttnn._ttnn.reports.get_buffer_pages()
+    ttnn.sub_bw(grad_tensor, input_tensor, other_tensor, input_grad=input_grad, other_grad=other_grad, queue_id=cq_id)
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+    tt_output_tensor_on_device = [input_grad, other_grad]
+
+    golden_function = ttnn.get_golden_function(ttnn.sub_bw)
+    golden_tensor = golden_function(grad_data, in_data, other_data)
+
+    status = True
+    for i in range(len(are_required_outputs)):
+        if are_required_outputs[i]:
+            status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]])
+    assert status
@@ -151,11 +151,15 @@ struct ExecuteBackwardAdd {
        float scalar,
        const std::optional<MemoryConfig> &memory_config = std::nullopt);

-    static std::vector<Tensor> invoke(
+    static std::vector<std::optional<Tensor>> invoke(
+        uint8_t queue_id,
        const Tensor &grad_tensor_arg,
        const Tensor &input_tensor_a_arg,
        const Tensor &input_tensor_b_arg,
-        const std::optional<MemoryConfig> &memory_config = std::nullopt);
+        const std::optional<MemoryConfig> &memory_config = std::nullopt,
+        const std::vector<bool> &are_required_outputs = std::vector<bool>{true, true},
+        std::optional<Tensor> input_grad = std::nullopt,
+        std::optional<Tensor> other_grad = std::nullopt);

    static std::vector<ComplexTensor> invoke(
        const ComplexTensor &grad_tensor_arg,
@@ -583,6 +583,104 @@ void bind_binary_bw_mul(py::module& module, const binary_backward_operation_t& o
py::arg("memory_config") = std::nullopt});
}


+template <typename binary_backward_operation_t>
+void bind_binary_bw_sub(py::module& module, const binary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
+    auto doc = fmt::format(
+        R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: Union[ttnn.Tensor, ComplexTensor], input_tensor_b: Union[ComplexTensor, ttnn.Tensor, int, float], *, memory_config: Optional[ttnn.MemoryConfig] = None, are_required_outputs: Optional[List[bool]] = None, input_grad: Optional[ttnn.Tensor] = None, other_grad: Optional[ttnn.Tensor] = None, queue_id: Optional[uint8] = 0) -> List[ttnn.Tensor]
+
+        {2}
+
+        Supports broadcasting.
+
+        Args:
+            * :attr:`grad_tensor` (ttnn.Tensor)
+            * :attr:`input_tensor_a` (ComplexTensor or ttnn.Tensor)
+            * :attr:`input_tensor_b` (ComplexTensor or ttnn.Tensor or Number): the tensor or number to subtract from :attr:`input_tensor_a`.
+
+        Keyword args:
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensors
+            * :attr:`are_required_outputs` (Optional[List[bool]]): flags indicating which gradients to compute
+            * :attr:`input_grad` (Optional[ttnn.Tensor]): preallocated tensor for the gradient of :attr:`input_tensor_a`
+            * :attr:`other_grad` (Optional[ttnn.Tensor]): preallocated tensor for the gradient of :attr:`input_tensor_b`
+            * :attr:`queue_id` (Optional[uint8]): command queue id
+
+        Supported dtypes, layouts, and ranks:
+            {3}
+
+        Note: bfloat8_b/bfloat4_b are only supported on TILE_LAYOUT.
+
+        Example:
+            >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 1), dtype=torch.bfloat16)), device)
+            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
+            >>> output = {1}(grad_tensor, tensor1, tensor2)
+    )doc",
+        operation.base_name(),
+        operation.python_fully_qualified_name(),
+        description,
+        supported_dtype);
+
+    bind_registered_operation(
+        module,
+        operation,
+        doc,
+        // tensor and scalar
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const Tensor& grad_tensor,
+               const Tensor& input_tensor_a,
+               const float scalar,
+               const std::optional<MemoryConfig>& memory_config) {
+                return self(grad_tensor, input_tensor_a, scalar, memory_config);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("scalar"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt},
+
+        // tensor and tensor
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const ttnn::Tensor& grad_tensor,
+               const ttnn::Tensor& input_tensor,
+               const ttnn::Tensor& other_tensor,
+               const std::optional<ttnn::MemoryConfig>& memory_config,
+               const std::vector<bool>& are_required_outputs,
+               const std::optional<ttnn::Tensor>& input_grad,
+               const std::optional<ttnn::Tensor>& other_grad,
+               const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
+                return self(queue_id, grad_tensor, input_tensor, other_tensor, memory_config, are_required_outputs, input_grad, other_grad);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor"),
+            py::arg("other_tensor"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt,
+            py::arg("are_required_outputs") = std::vector<bool>{true, true},
+            py::arg("input_grad") = std::nullopt,
+            py::arg("other_grad") = std::nullopt,
+            py::arg("queue_id") = 0},

+        // complex tensor
+        ttnn::pybind_overload_t{
+            [](const binary_backward_operation_t& self,
+               const ComplexTensor& grad_tensor,
+               const ComplexTensor& input_tensor_a,
+               const ComplexTensor& input_tensor_b,
+               float alpha,
+               const std::optional<MemoryConfig>& memory_config) {
+                return self(grad_tensor, input_tensor_a, input_tensor_b, alpha, memory_config);
+            },
+            py::arg("grad_tensor"),
+            py::arg("input_tensor_a"),
+            py::arg("input_tensor_b"),
+            py::arg("alpha"),
+            py::kw_only(),
+            py::arg("memory_config") = std::nullopt});
+}

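Since the tensor-and-tensor overload above also exposes are_required_outputs, a caller can ask for just one gradient. A hedged sketch — the assumption (not shown in this diff) being that an unrequested slot comes back as None:

# Hypothetical: request only the gradient w.r.t. other_tensor.
result = ttnn.sub_bw(
    grad_tensor, input_tensor, other_tensor,
    are_required_outputs=[False, True],
    queue_id=0,
)
other_grad = result[1]  # result[0] is presumably None here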
template <typename binary_backward_operation_t>
void bind_binary_bw_div(py::module& module, const binary_backward_operation_t& operation, std::string_view description, std::string_view supported_dtype) {
auto doc = fmt::format(
@@ -913,7 +1011,7 @@ void py_module(py::module& module) {
| BFLOAT16, BFLOAT8_B | ROW_MAJOR, TILE | 2, 3, 4 |
+----------------------------+---------------------------------+-------------------+)doc");

-    detail::bind_binary_bw_operation(
+    detail::bind_binary_bw_sub(
        module,
        ttnn::sub_bw,
        R"doc(Performs backward operations for subtraction of :attr:`input_tensor_a` and :attr:`input_tensor_b` or :attr:`scalar` with given :attr:`grad_tensor`.)doc",
@@ -177,13 +177,16 @@ std::vector<Tensor> ExecuteBackwardSub::invoke(
    return grad_tensor;
}

-std::vector<Tensor> ExecuteBackwardSub::invoke(
-    const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config) {
-    std::vector<Tensor> grad_tensor;
-    grad_tensor.emplace_back(grad);
-    Tensor grad_b = ttnn::multiply(ttnn::neg(grad, output_mem_config), 1.0f, std::nullopt, output_mem_config);
-    grad_tensor.emplace_back(grad_b);
-    return grad_tensor;
+std::vector<std::optional<Tensor>> ExecuteBackwardSub::invoke(
+    uint8_t queue_id,
+    const Tensor& grad,
+    const Tensor& input,
+    const Tensor& other,
+    const std::optional<MemoryConfig>& output_mem_config,
+    const std::vector<bool>& are_required_outputs,
+    std::optional<Tensor> input_grad,
+    std::optional<Tensor> other_grad) {
+    return ttnn::subalpha_bw(queue_id, grad, input, other, 1.0f, output_mem_config, are_required_outputs, input_grad, other_grad);
}

std::vector<ComplexTensor> ExecuteBackwardSub::invoke(
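The rewritten invoke simply delegates to ttnn::subalpha_bw with alpha = 1.0f. That reproduces the removed hand-written path: for c = a - alpha * b, the gradients are dc/da = grad and dc/db = -alpha * grad, which at alpha = 1 matches the old grad_b = neg(grad) * 1.0f. A quick torch check of that identity (illustrative only, independent of ttnn):

import torch

# Verify the subalpha gradients that sub_bw now inherits (alpha = 1).
alpha = 1.0
a = torch.randn(4, requires_grad=True)
b = torch.randn(4, requires_grad=True)
grad = torch.randn(4)

(a - alpha * b).backward(grad)
assert torch.equal(a.grad, grad)           # input_grad == grad
assert torch.equal(b.grad, -alpha * grad)  # other_grad == -grad, as the old code computed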
@@ -15,7 +15,6 @@ namespace ttnn::operations::binary_backward {
enum class BinaryBackwardOpType {
    ATAN2_BW,
    ADDALPHA_BW,
-    SUB_BW,
    XLOGY_BW,
    HYPOT_BW,
    LDEXP_BW,
@@ -39,7 +38,6 @@ std::vector<Tensor> _xlogy_bw( const Tensor& grad, const Tensor& input, const Te
std::vector<Tensor> _hypot_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _ldexp_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _logaddexp_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
-std::vector<Tensor> _sub_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _gt_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _logaddexp2_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _squared_difference_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
@@ -92,13 +90,6 @@ struct OpHandler<BinaryBackwardOpType::LOGADDEXP_BW> {
    }
};

-template <>
-struct OpHandler<BinaryBackwardOpType::SUB_BW> {
-    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config ) {
-        return _sub_bw(grad, input, other, output_mem_config);
-    }
-};
-
template <>
struct OpHandler<BinaryBackwardOpType::LOGADDEXP2_BW> {
    static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config ) {
