Commit
#9628: Remove std::function for BW Binary ops (#10492)
* #9628: Remove std::function for atan2_bw

* #9628: Remove std::function for addalpha_bw

* #9628: Update rsub_bw

* #9628: Update embedding_bw
VirdhatchaniKN authored Jul 20, 2024
1 parent 5df2e14 commit c88bb74
Showing 5 changed files with 258 additions and 63 deletions.
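The handler structs introduced below obtain their implementation through get_function_binary_bw_type1 / get_function_binary_bw_type1_opt_float_default, which are defined outside the files shown here. The following sketch only illustrates the dispatch pattern the diff relies on; it assumes the getter selects the op with if constexpr and returns a stateless lambda, and the _atan2_bw, _rsub_bw and _embedding_bw calls are hypothetical placeholders for the real device functions, not names taken from this commit.

// Editorial sketch, not part of this commit: compile-time dispatch that
// replaces a stored std::function. The returned lambda captures nothing,
// so it can be inlined or converted to a plain function pointer.
template <BinaryBackwardOpType binary_backward_op_type>
auto get_function_binary_bw_type1() {
    return [](const Tensor& grad, const Tensor& input_a, const Tensor& input_b,
              const MemoryConfig& output_mem_config) -> std::vector<Tensor> {
        if constexpr (binary_backward_op_type == BinaryBackwardOpType::ATAN2_BW) {
            return _atan2_bw(grad, input_a, input_b, output_mem_config);      // hypothetical kernel entry
        } else if constexpr (binary_backward_op_type == BinaryBackwardOpType::RSUB_BW) {
            return _rsub_bw(grad, input_a, input_b, output_mem_config);       // hypothetical kernel entry
        } else {
            static_assert(binary_backward_op_type == BinaryBackwardOpType::EMBEDDING_BW,
                          "unsupported type-1 binary backward op");
            return _embedding_bw(grad, input_a, input_b, output_mem_config);  // hypothetical kernel entry
        }
    };
}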
@@ -13,6 +13,73 @@ namespace ttnn {

namespace operations::binary_backward {

//OpHandler_binary_bw : get_function_binary_bw_type1
template <BinaryBackwardOpType binary_backward_op_type>
struct ExecuteBinaryBackwardType1 {

static inline std::vector<ttnn::Tensor> create_async_output_tensors(
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
const auto& input_tensor = input_tensors.at(0);
return {Tensor(operation::get_workers_for_op_output({input_tensor})),
Tensor(operation::get_workers_for_op_output({input_tensor}))};
}

//Type 1: 1 grad tensor, 2 input tensors, optional memory config
static std::vector<Tensor> execute_on_main_thread(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt) {
auto op_type = get_function_binary_bw_type1<binary_backward_op_type>();
auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, output_memory_config);
}
};

//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
template <BinaryBackwardOpType binary_backward_op_type>
struct ExecuteBinaryBackwardOptionalFloatDefault {

static inline std::vector<ttnn::Tensor> create_async_output_tensors(
const std::vector<Tensor> &input_tensors, const std::vector<std::optional<const Tensor>>& optional_inputs) {
const auto& input_tensor = input_tensors.at(0);
return {Tensor(operation::get_workers_for_op_output({input_tensor})),
Tensor(operation::get_workers_for_op_output({input_tensor}))};
}

static std::vector<std::optional<Tensor>> execute_on_main_thread(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
float parameter,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
const std::vector<bool>& are_required_outputs = std::vector<bool>{true, true},
std::optional<Tensor> input_a_grad = std::nullopt,
std::optional<Tensor> input_b_grad = std::nullopt) {

auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
auto op_type = get_function_binary_bw_type1_opt_float_default<binary_backward_op_type>();
return op_type(queue_id, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config, are_required_outputs, input_a_grad, input_b_grad);
}

static std::vector<std::optional<Tensor>> execute_on_main_thread(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
float parameter,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
const std::vector<bool>& are_required_outputs = std::vector<bool>{true, true},
std::optional<Tensor> input_a_grad = std::nullopt,
std::optional<Tensor> input_b_grad = std::nullopt) {

auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
auto op_type = get_function_binary_bw_type1_opt_float_default<binary_backward_op_type>();
return op_type(DefaultQueueId, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config, are_required_outputs, input_a_grad, input_b_grad);
}

};

template <BinaryBackwardOpType binary_backward_op_type>
struct ExecuteBinaryBackward {
static inline std::vector<ttnn::Tensor> create_async_output_tensors(
@@ -138,9 +205,15 @@ struct ExecuteBinaryBackward {

} // operations::binary

//OpHandler_binary_bw : get_function_binary_bw_type1
constexpr auto atan2_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackwardType1<operations::binary_backward::BinaryBackwardOpType::ATAN2_BW>>("ttnn::atan2_bw");
constexpr auto rsub_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackwardType1<operations::binary_backward::BinaryBackwardOpType::RSUB_BW>>("ttnn::rsub_bw");
constexpr auto embedding_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackwardType1<operations::binary_backward::BinaryBackwardOpType::EMBEDDING_BW>>("ttnn::embedding_bw");

//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
constexpr auto addalpha_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackwardOptionalFloatDefault<operations::binary_backward::BinaryBackwardOpType::ADDALPHA_BW>>("ttnn::addalpha_bw");

//type 1
constexpr auto atan2_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::ATAN2_BW>>("ttnn::atan2_bw");
constexpr auto embedding_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::EMBEDDING_BW>>("ttnn::embedding_bw");
constexpr auto subalpha_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::SUBALPHA_BW>>("ttnn::subalpha_bw");
constexpr auto xlogy_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::XLOGY_BW>>("ttnn::xlogy_bw");
constexpr auto hypot_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::HYPOT_BW>>("ttnn::hypot_bw");
@@ -149,12 +222,10 @@ constexpr auto logaddexp_bw = ttnn::register_operation<operations::binary_backwa
constexpr auto logaddexp2_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::LOGADDEXP2_BW>>("ttnn::logaddexp2_bw");
constexpr auto squared_difference_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::SQUARED_DIFFERENCE_BW>>("ttnn::squared_difference_bw");
constexpr auto concat_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::CONCAT_BW>>("ttnn::concat_bw");
constexpr auto rsub_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::RSUB_BW>>("ttnn::rsub_bw");
constexpr auto min_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::MIN_BW>>("ttnn::min_bw");
constexpr auto max_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::MAX_BW>>("ttnn::max_bw");
constexpr auto lerp_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::LERP_BW>>("ttnn::lerp_bw");

//type 2
constexpr auto addalpha_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackward<operations::binary_backward::BinaryBackwardOpType::ADDALPHA_BW>>("ttnn::addalpha_bw");

} // namespace ttnn
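For orientation, a call-site sketch of the newly registered entry points. This assumes the registered operations forward their arguments unchanged to the execute_on_main_thread overloads above; grad, a and b stand for tensors already placed on the device and are not taken from this commit.

// Type-1 form: gradients for both inputs; memory_config defaults to input_a's config.
std::vector<ttnn::Tensor> atan2_grads = ttnn::atan2_bw(grad, a, b);

// Optional-float-default form: explicit queue id, alpha, and optionally preallocated outputs.
std::vector<std::optional<ttnn::Tensor>> addalpha_grads = ttnn::addalpha_bw(
    /*queue_id=*/0, grad, a, b, /*alpha=*/1.0f,
    /*memory_config=*/std::nullopt,
    /*are_required_outputs=*/std::vector<bool>{true, false},   // skip input_b's gradient
    /*input_a_grad=*/std::nullopt, /*input_b_grad=*/std::nullopt);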
@@ -20,6 +20,115 @@ namespace binary_backward {

namespace detail {

//OpHandler_binary_bw : get_function_binary_bw_type1
template <typename binary_backward_operation_t>
void bind_binary_backward_type_1(py::module& module, const binary_backward_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
{2}
Args:
* :attr:`grad_tensor`
* :attr:`input_tensor_a`
* :attr:`input_tensor_b`
Keyword args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
>>> output = {1}(grad_tensor, tensor1, tensor2)
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
description);

bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const binary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const std::optional<ttnn::MemoryConfig>& memory_config) -> std::vector<ttnn::Tensor> {
auto output_memory_config = memory_config.value_or(input_tensor_a.memory_config());
return self(grad_tensor, input_tensor_a, input_tensor_b, output_memory_config);
},
py::arg("grad_tensor"),
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::kw_only(),
py::arg("memory_config") = std::nullopt});
}

//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
template <typename binary_backward_operation_t>
void bind_binary_backward_opt_float_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string& description) {
auto doc = fmt::format(
R"doc({0}(grad_tensor: ttnn.Tensor, input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, {2}: float, *, memory_config: ttnn.MemoryConfig) -> std::vector<Tensor>
{5}
Args:
* :attr:`grad_tensor`
* :attr:`input_tensor_a`
* :attr:`input_tensor_b`
* :attr:`{2}` (float): {3}. Default value = {4}
Keyword args:
* :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
Example:
>>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device)
>>> output = {1}(grad_tensor, tensor1, tensor2, {4})
)doc",
operation.base_name(),
operation.python_fully_qualified_name(),
parameter_name,
parameter_doc,
parameter_value,
description);


bind_registered_operation(
module,
operation,
doc,
ttnn::pybind_overload_t{
[](const binary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
float parameter,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::vector<bool>& are_required_outputs,
const std::optional<ttnn::Tensor>& input_a_grad,
const std::optional<ttnn::Tensor>& input_b_grad,
const uint8_t& queue_id) -> std::vector<std::optional<ttnn::Tensor>> {
return self(queue_id, grad_tensor, input_tensor_a, input_tensor_b, parameter, memory_config, are_required_outputs, input_a_grad, input_b_grad);
},
py::arg("grad_tensor"),
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::arg(parameter_name.c_str()) = parameter_value,
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("are_required_outputs") = std::vector<bool>{true, true},
py::arg("input_a_grad") = std::nullopt,
py::arg("input_b_grad") = std::nullopt,
py::arg("queue_id") = 0}
);
}

template <typename binary_backward_operation_t>
void bind_binary_backward(py::module& module, const binary_backward_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
@@ -212,12 +321,12 @@ Keyword args:


void py_module(py::module& module) {
detail::bind_binary_backward(
detail::bind_binary_backward_type_1(
module,
ttnn::atan2_bw,
R"doc(Performs backward operations for atan2 of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");

detail::bind_binary_backward(
detail::bind_binary_backward_type_1(
module,
ttnn::embedding_bw,
R"doc(Performs backward operations for the embedding_bw function, returning the specific indices of the embedding table specified by :attr:`grad_tensor`.
@@ -228,9 +337,10 @@ void py_module(py::module& module) {
ttnn::subalpha_bw,
R"doc(Performs backward operations for subalpha of :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");

detail::bind_binary_backward(
detail::bind_binary_backward_opt_float_default(
module,
ttnn::addalpha_bw,
"alpha", "Alpha value", 1.0f,
R"doc(Performs backward operations for addalpha on :attr:`input_tensor_b`, :attr:`input_tensor_a` and :attr:`alpha` with given :attr:`grad_tensor`.)doc");

detail::bind_binary_backward(
@@ -273,7 +383,7 @@ void py_module(py::module& module) {
ttnn::concat_bw,
R"doc(Performs backward operations for concat on :attr:`input_tensor_a` and :attr:`input_tensor_b` with given :attr:`grad_tensor`.)doc");

detail::bind_binary_backward(
detail::bind_binary_backward_type_1(
module,
ttnn::rsub_bw,
R"doc(Performs backward operations for subtraction of :attr:`input_tensor_a` from :attr:`input_tensor_b` with given :attr:`grad_tensor` (reversed order of subtraction operator).)doc");
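As background for the reversed-subtraction wording above (this is ordinary calculus, not a claim about the kernel internals, which are outside this diff): with rsub(a, b) = b - a, the backward pass reduces to a sign flip for the first operand, i.e. ∂L/∂a = -grad_tensor and ∂L/∂b = grad_tensor.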