#9628: Rename functions
VirdhatchaniKN committed Jul 20, 2024
1 parent 2d18571 commit 7224325
Showing 3 changed files with 21 additions and 21 deletions.
@@ -13,7 +13,7 @@ namespace ttnn {

namespace operations::binary_backward {

-//OpHandler_binary_bw : get_function_binary_bw_type1
+//OpHandler_binary_bw : get_function_binary_bw
template <BinaryBackwardOpType binary_backward_op_type>
struct ExecuteBinaryBackwardType1 {

@@ -30,13 +30,13 @@ struct ExecuteBinaryBackwardType1 {
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-auto op_type = get_function_binary_bw_type1<binary_backward_op_type>();
+auto op_type = get_function_binary_bw<binary_backward_op_type>();
auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, output_memory_config);
}
};

-//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
+//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_opt_float_default
template <BinaryBackwardOpType unary_backward_op_type>
struct ExecuteBinaryBackwardOptionalFloatDefault {

@@ -59,7 +59,7 @@ struct ExecuteBinaryBackwardOptionalFloatDefault {
std::optional<Tensor> input_b_grad = std::nullopt) {

auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
-auto op_type = get_function_binary_bw_type1_opt_float_default<unary_backward_op_type>();
+auto op_type = get_function_binary_bw_opt_float_default<unary_backward_op_type>();
return op_type(queue_id, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config, are_required_outputs, input_a_grad, input_b_grad);
}

@@ -74,13 +74,13 @@ struct ExecuteBinaryBackwardOptionalFloatDefault {
std::optional<Tensor> input_b_grad = std::nullopt) {

auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
-auto op_type = get_function_binary_bw_type1_opt_float_default<unary_backward_op_type>();
+auto op_type = get_function_binary_bw_opt_float_default<unary_backward_op_type>();
return op_type(DefaultQueueId, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config, are_required_outputs, input_a_grad, input_b_grad);
}

};

-//OpHandler_binary_bw_float : get_function_binary_bw_type1_float
+//OpHandler_binary_bw_float : get_function_binary_bw_float
template <BinaryBackwardOpType binary_backward_op_type>
struct ExecuteBinaryBackwardType1Float {

@@ -98,7 +98,7 @@ struct ExecuteBinaryBackwardType1Float {
const Tensor &input_tensor_b_arg,
float parameter,
const std::optional<MemoryConfig> &memory_config = std::nullopt) {
-auto op_type = get_function_binary_bw_type1<binary_backward_op_type>();
+auto op_type = get_function_binary_bw_float<binary_backward_op_type>();
auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, parameter, output_memory_config);
}
@@ -193,15 +193,15 @@ struct ExecuteBinaryBackward {

} // operations::binary

-//OpHandler_binary_bw : get_function_binary_bw_type1
+//OpHandler_binary_bw : get_function_binary_bw
constexpr auto atan2_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackwardType1<operations::binary_backward::BinaryBackwardOpType::ATAN2_BW>>("ttnn::atan2_bw");
constexpr auto rsub_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackwardType1<operations::binary_backward::BinaryBackwardOpType::RSUB_BW>>("ttnn::rsub_bw");
constexpr auto embedding_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackwardType1<operations::binary_backward::BinaryBackwardOpType::EMBEDDING_BW>>("ttnn::embedding_bw");

-//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
+//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_opt_float_default
constexpr auto addalpha_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackwardOptionalFloatDefault<operations::binary_backward::BinaryBackwardOpType::ADDALPHA_BW>>("ttnn::addalpha_bw");

-//OpHandler_binary_bw_float : get_function_binary_bw_type1_float
+//OpHandler_binary_bw_float : get_function_binary_bw_float
constexpr auto subalpha_bw = ttnn::register_operation<operations::binary_backward::ExecuteBinaryBackwardType1Float<operations::binary_backward::BinaryBackwardOpType::SUBALPHA_BW>>("ttnn::subalpha_bw");

//type 1
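Each Execute* struct above follows the same dispatch idiom: resolve the handler once through the (now renamed) get_function_binary_bw* getter, default the output memory config to the first input's via std::optional::value_or, and, for the optional-output variant, provide a second overload that forwards with DefaultQueueId. A minimal, compilable sketch of that idiom with placeholder types (Input, Config, and run_op are illustrative stand-ins, not ttnn types):

#include <cstdint>
#include <iostream>
#include <optional>
#include <string>

constexpr uint8_t DefaultQueueId = 0;

// Placeholder for a memory config attached to each input tensor.
struct Config { std::string name; };

struct Input {
    Config cfg;
    const Config& config() const { return cfg; }
};

// Stand-in for the resolved op handler.
std::string run_op(uint8_t queue_id, const Input& in, const Config& cfg) {
    return "queue " + std::to_string(queue_id) + ": " + in.cfg.name + " -> " + cfg.name;
}

// Overload 1: caller supplies the queue id; the optional memory config
// falls back to the input's own config via value_or, as in the diff.
std::string execute(uint8_t queue_id, const Input& in,
                    const std::optional<Config>& memory_config = std::nullopt) {
    auto output_config = memory_config.value_or(in.config());
    return run_op(queue_id, in, output_config);
}

// Overload 2: no queue id; forward to overload 1 with the default queue.
std::string execute(const Input& in,
                    const std::optional<Config>& memory_config = std::nullopt) {
    return execute(DefaultQueueId, in, memory_config);
}

int main() {
    Input a{{"dram"}};
    std::cout << execute(a) << '\n';                   // queue 0: dram -> dram
    std::cout << execute(3, a, Config{"l1"}) << '\n';  // queue 3: dram -> l1
}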
@@ -20,7 +20,7 @@ namespace binary_backward {

namespace detail {

-//OpHandler_binary_bw : get_function_binary_bw_type1
+//OpHandler_binary_bw : get_function_binary_bw
template <typename binary_backward_operation_t>
void bind_binary_backward_type_1(py::module& module, const binary_backward_operation_t& operation, const std::string& description) {
auto doc = fmt::format(
@@ -67,7 +67,7 @@ Keyword args:
py::arg("memory_config") = std::nullopt});
}

-//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
+//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_opt_float_default
template <typename binary_backward_operation_t>
void bind_binary_backward_opt_float_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string& description) {
auto doc = fmt::format(
@@ -129,7 +129,7 @@ void bind_binary_backward_opt_float_default(py::module& module, const binary_bac
);
}

-//OpHandler_binary_bw_float : get_function_binary_bw_type1_float
+//OpHandler_binary_bw_float : get_function_binary_bw_float
template <typename binary_backward_operation_t>
void bind_binary_backward_float_default(py::module& module, const binary_backward_operation_t& operation, const std::string& parameter_name, const std::string& parameter_doc, float parameter_value, const std::string& description) {
auto doc = fmt::format(
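The bind_* helpers in this file each build a doc string with fmt::format from the operation's parameter metadata and then register the C++ overloads with pybind11. A stripped-down sketch of that shape using plain pybind11 (the op, module name, and bind_with_doc helper are hypothetical; the ttnn binding helpers themselves are elided above):

#include <pybind11/pybind11.h>
#include <fmt/format.h>
#include <string>

namespace py = pybind11;

// Hypothetical free function standing in for a registered operation.
double subalpha(double a, double b, float alpha) { return a - alpha * b; }

// Mirrors the shape of bind_binary_backward_float_default: format a doc
// string from the parameter name/doc/default, then bind with py::arg
// defaults, as the diff does with py::arg("memory_config") = std::nullopt.
void bind_with_doc(py::module& module, const std::string& parameter_name,
                   const std::string& parameter_doc, float parameter_value) {
    auto doc = fmt::format(
        "Performs subalpha on the inputs.\n\nKeyword args:\n    {0} (float): {1}. Defaults to {2}.",
        parameter_name, parameter_doc, parameter_value);
    module.def("subalpha", &subalpha, doc.c_str(),
               py::arg("a"), py::arg("b"),
               py::arg(parameter_name.c_str()) = parameter_value);
}

PYBIND11_MODULE(example, m) {
    bind_with_doc(m, "alpha", "scaling factor for the second input", 1.0f);
}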
@@ -43,22 +43,22 @@ enum class BinaryBackwardOpType {
MUL_BW,
};
struct BinaryBackwardFunction{
-static std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(BinaryBackwardOpType OpType); //get_function_binary_bw_type1
+static std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(BinaryBackwardOpType OpType); //get_function_binary_bw
static std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(BinaryBackwardOpType OpType);
static std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tensor&, std::string, const MemoryConfig&)> get_function_type1_w_string(BinaryBackwardOpType OpType);
static std::function<std::vector<std::optional<ttnn::Tensor>>(uint8_t , const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, const std::vector<bool>&, std::optional<Tensor>, std::optional<Tensor>)> get_function_type3(BinaryBackwardOpType OpType);
static std::function<std::vector<std::optional<ttnn::Tensor>>(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, const std::vector<bool>&, std::optional<Tensor>, std::optional<Tensor>)> get_function_type3_wo_qid(BinaryBackwardOpType OpType);
};

-//OpHandler_binary_bw : get_function_binary_bw_type1
+//OpHandler_binary_bw : get_function_binary_bw
std::vector<Tensor> _atan2_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _rsub_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);
std::vector<Tensor> _embedding_bw( const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional<MemoryConfig>& output_mem_config);

-//OpHandler_binary_bw_float : get_function_binary_bw_type1_float
+//OpHandler_binary_bw_float : get_function_binary_bw_float
std::vector<ttnn::Tensor> _subalpha_bw( const Tensor& grad, const Tensor& input, const Tensor& other, float alpha = 1.0f, const std::optional<MemoryConfig>& output_mem_config = std::nullopt);

-//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_type1_opt_float_default
+//OpHandler_binary_bw_opt_float_default : get_function_binary_bw_opt_float_default
std::vector<std::optional<ttnn::Tensor>> _addalpha_bw( uint8_t queue_id, const Tensor& grad, const Tensor& input, const Tensor& other, float alpha = 1.0f, const std::optional<MemoryConfig>& output_mem_config = std::nullopt, const std::vector<bool>& are_required_outputs = std::vector<bool>{true, true}, std::optional<Tensor> input_grad = std::nullopt, std::optional<Tensor> other_grad = std::nullopt);

// OpHandler struct template
@@ -100,25 +100,25 @@ struct OpHandler_binary_bw<BinaryBackwardOpType::EMBEDDING_BW> {
};

template <>
-struct OpHandler_binary_bw<BinaryBackwardOpType::SUBALPHA_BW> {
+struct OpHandler_binary_bw_float<BinaryBackwardOpType::SUBALPHA_BW> {
static std::vector<Tensor> handle( const Tensor& grad, const Tensor& input, const Tensor& other, float alpha, const std::optional<MemoryConfig>& output_mem_config ) {
return _subalpha_bw(grad, input, other, alpha, output_mem_config);
}
};

// Template functions to get the function pointers
template <BinaryBackwardOpType OpType>
-auto get_function_binary_bw_type1() {
+auto get_function_binary_bw() {
return &OpHandler_binary_bw<OpType>::handle;
}

template <BinaryBackwardOpType OpType>
-auto get_function_binary_bw_type1_opt_float_default() {
+auto get_function_binary_bw_opt_float_default() {
return &OpHandler_binary_bw_opt_float_default<OpType>::handle;
}

template <BinaryBackwardOpType OpType>
-auto get_function_binary_bw_type1_float() {
+auto get_function_binary_bw_float() {
return &OpHandler_binary_bw_float<OpType>::handle;
}

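For reference, the dispatch pattern these renamed getters implement is plain template specialization: each op type gets an OpHandler specialization with a static handle, and the getter returns &OpHandler<OpType>::handle as an ordinary function pointer. A self-contained sketch of the same pattern (hypothetical MyOpType enum and float-based handlers, not the ttnn tensor types):

#include <iostream>
#include <vector>

enum class MyOpType { ADD_BW, SUB_BW };

// Primary template is left undefined; only specializations exist,
// so requesting an unsupported op type fails at compile time.
template <MyOpType op_type>
struct OpHandler;

template <>
struct OpHandler<MyOpType::ADD_BW> {
    static std::vector<float> handle(float grad, float a, float b) {
        return {grad, grad};  // d(a+b)/da = 1, d(a+b)/db = 1
    }
};

template <>
struct OpHandler<MyOpType::SUB_BW> {
    static std::vector<float> handle(float grad, float a, float b) {
        return {grad, -grad};  // d(a-b)/da = 1, d(a-b)/db = -1
    }
};

// Mirrors get_function_binary_bw<OpType>(): resolve the handler at
// compile time and hand back a plain function pointer.
template <MyOpType op_type>
auto get_function() {
    return &OpHandler<op_type>::handle;
}

int main() {
    auto op = get_function<MyOpType::SUB_BW>();
    for (float g : op(/*grad=*/2.0f, /*a=*/5.0f, /*b=*/3.0f))
        std::cout << g << ' ';  // prints: 2 -2
}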
