#9628: Remove get_function_type2
VirdhatchaniKN committed Jul 20, 2024
1 parent 1920676 commit 9e122ea
Showing 4 changed files with 0 additions and 99 deletions.
@@ -189,41 +189,6 @@ struct ExecuteBinaryBackward {
return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, output_memory_config, are_required_outputs, input_a_grad, input_b_grad);
}

// Type 2 : Q_ID, type1 args, optional output tensor for inputs based on are_required_outputs value

static std::vector<std::optional<ttnn::Tensor>> execute_on_main_thread(
uint8_t queue_id,
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
float alpha,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
const std::vector<bool>& are_required_outputs = std::vector<bool>{true, true},
std::optional<Tensor> input_a_grad = std::nullopt,
std::optional<Tensor> input_b_grad = std::nullopt) {

auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
auto op_type = BinaryBackwardFunction::get_function_type2(binary_backward_op_type);
return op_type(queue_id, grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, alpha, output_memory_config, are_required_outputs, input_a_grad, input_b_grad);
}

// Type 2 : type1 args, optional output tensor for inputs based on are_required_outputs value

static std::vector<std::optional<ttnn::Tensor>> execute_on_main_thread(
const Tensor &grad_tensor_arg,
const Tensor &input_tensor_a_arg,
const Tensor &input_tensor_b_arg,
float alpha,
const std::optional<MemoryConfig> &memory_config = std::nullopt,
const std::vector<bool>& are_required_outputs = std::vector<bool>{true, true},
std::optional<Tensor> input_a_grad = std::nullopt,
std::optional<Tensor> input_b_grad = std::nullopt) {

auto output_memory_config = memory_config.value_or(input_tensor_a_arg.memory_config());
auto op_type = BinaryBackwardFunction::get_function_type2_wo_qid(binary_backward_op_type);
return op_type(grad_tensor_arg, input_tensor_a_arg, input_tensor_b_arg, alpha, output_memory_config, are_required_outputs, input_a_grad, input_b_grad);
}

};

} // operations::binary
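For context, the removed Type 2 overloads follow the same pattern as the surviving ones: execute_on_main_thread does not implement the gradient itself, it fetches a std::function for the current binary_backward_op_type and forwards its arguments to it. Below is a minimal, self-contained sketch of that lookup-and-forward pattern; the names (OpType::SubAlphaBw, get_kernel, execute) are illustrative stand-ins rather than ttnn code, and Tensor is reduced to a float.

#include <functional>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <vector>

using Tensor = float;  // stand-in for ttnn::Tensor
enum class OpType { SubAlphaBw };

// Analogue of a BinaryBackwardFunction getter: return the kernel registered
// for a given op type.
std::function<std::vector<std::optional<Tensor>>(const Tensor&, const Tensor&, const Tensor&, float)>
get_kernel(OpType op) {
    switch (op) {
        case OpType::SubAlphaBw:
            return [](const Tensor& grad, const Tensor&, const Tensor&, float alpha) {
                // d(a - alpha*b)/da = 1, d(a - alpha*b)/db = -alpha
                return std::vector<std::optional<Tensor>>{grad, -alpha * grad};
            };
    }
    throw std::runtime_error("Undefined op type");
}

// Analogue of the removed execute_on_main_thread: look up the kernel once and
// forward the call arguments to it unchanged.
std::vector<std::optional<Tensor>> execute(OpType op, const Tensor& grad, const Tensor& a, const Tensor& b, float alpha) {
    auto kernel = get_kernel(op);
    return kernel(grad, a, b, alpha);
}

int main() {
    auto grads = execute(OpType::SubAlphaBw, 1.0f, 2.0f, 3.0f, 0.5f);
    std::cout << *grads[0] << " " << *grads[1] << "\n";  // prints: 1 -0.5
}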
@@ -322,52 +322,6 @@ Keyword args:
py::arg("memory_config") = std::nullopt,
py::arg("are_required_outputs") = std::vector<bool>{true, true},
py::arg("input_a_grad") = std::nullopt,
py::arg("input_b_grad") = std::nullopt},

ttnn::pybind_overload_t{
[](const binary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const float alpha,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::vector<bool>& are_required_outputs,
const std::optional<ttnn::Tensor>& input_a_grad,
const std::optional<ttnn::Tensor>& input_b_grad,
const uint8_t& queue_id) -> std::vector<optional<ttnn::Tensor>> {
return self(queue_id, grad_tensor, input_tensor_a, input_tensor_b, alpha, memory_config, are_required_outputs, input_a_grad, input_b_grad);
},
py::arg("grad_tensor"),
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::arg("alpha") = 1.0f,
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("are_required_outputs") = std::vector<bool>{true, true},
py::arg("input_a_grad") = std::nullopt,
py::arg("input_b_grad") = std::nullopt,
py::arg("queue_id") = 0},

ttnn::pybind_overload_t{
[](const binary_backward_operation_t& self,
const ttnn::Tensor& grad_tensor,
const ttnn::Tensor& input_tensor_a,
const ttnn::Tensor& input_tensor_b,
const float alpha,
const std::optional<ttnn::MemoryConfig>& memory_config,
const std::vector<bool>& are_required_outputs,
const std::optional<ttnn::Tensor>& input_a_grad,
const std::optional<ttnn::Tensor>& input_b_grad) -> std::vector<optional<ttnn::Tensor>> {
return self(grad_tensor, input_tensor_a, input_tensor_b, alpha, memory_config, are_required_outputs, input_a_grad, input_b_grad);
},
py::arg("grad_tensor"),
py::arg("input_tensor_a"),
py::arg("input_tensor_b"),
py::arg("alpha") = 1.0f,
py::kw_only(),
py::arg("memory_config") = std::nullopt,
py::arg("are_required_outputs") = std::vector<bool>{true, true},
py::arg("input_a_grad") = std::nullopt,
py::arg("input_b_grad") = std::nullopt});
}
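The two bindings deleted above registered the alpha-taking C++ overloads with Python, one accepting a trailing queue_id keyword and one without. The sketch below shows the underlying mechanism with plain pybind11 instead of ttnn's pybind_overload_t helper: overloads registered under the same name are tried in registration order, so a call that passes queue_id= binds to the first lambda and other calls fall through to the second. The module and op names are illustrative, and unlike the removed bindings queue_id is left without a default here so the two overloads stay distinguishable.

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include <cstdint>
#include <optional>
#include <vector>

namespace py = pybind11;

// Stand-in for the real backward kernel; returns the two optional gradients.
static std::vector<std::optional<float>> run(float grad, float /*a*/, float /*b*/, float alpha, std::uint8_t /*queue_id*/) {
    return {grad, -alpha * grad};
}

PYBIND11_MODULE(example_bw, m) {
    // Overload that takes an explicit, keyword-only queue_id.
    m.def(
        "subalpha_bw",
        [](float grad, float a, float b, float alpha, std::uint8_t queue_id) {
            return run(grad, a, b, alpha, queue_id);
        },
        py::arg("grad_tensor"),
        py::arg("input_tensor_a"),
        py::arg("input_tensor_b"),
        py::arg("alpha") = 1.0f,
        py::kw_only(),
        py::arg("queue_id"));

    // Fallback overload without queue_id; it defaults the queue to 0 itself.
    m.def(
        "subalpha_bw",
        [](float grad, float a, float b, float alpha) {
            return run(grad, a, b, alpha, /*queue_id=*/0);
        },
        py::arg("grad_tensor"),
        py::arg("input_tensor_a"),
        py::arg("input_tensor_b"),
        py::arg("alpha") = 1.0f);
}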

@@ -660,22 +660,6 @@ std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tens
}
}

std::function<std::vector<std::optional<ttnn::Tensor>>(uint8_t , const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&, const std::vector<bool>&, std::optional<Tensor>, std::optional<Tensor>)> BinaryBackwardFunction::get_function_type2(BinaryBackwardOpType OpType){
switch (OpType) {
default:
TT_ASSERT(false && "Undefined op type");
return 0;
}
}

std::function<std::vector<std::optional<ttnn::Tensor>>(const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&, const std::vector<bool>&, std::optional<Tensor>, std::optional<Tensor>)> BinaryBackwardFunction::get_function_type2_wo_qid(BinaryBackwardOpType OpType){
switch (OpType) {
default:
TT_ASSERT(false && "Undefined op type");
return 0;
}
}

std::function<std::vector<std::optional<ttnn::Tensor>>(uint8_t , const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, const std::vector<bool>&, std::optional<Tensor>, std::optional<Tensor>)> BinaryBackwardFunction::get_function_type3(BinaryBackwardOpType OpType){
switch (OpType) {
case BinaryBackwardOpType::ADD_BW:
@@ -46,8 +46,6 @@ struct BinaryBackwardFunction{
static std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&)> get_function_type1(BinaryBackwardOpType OpType); //get_function_binary_bw_type1
static std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&)> get_function_type1_w_float(BinaryBackwardOpType OpType);
static std::function<std::vector<ttnn::Tensor>(const Tensor&, const Tensor&, const Tensor&, std::string, const MemoryConfig&)> get_function_type1_w_string(BinaryBackwardOpType OpType);
static std::function<std::vector<std::optional<ttnn::Tensor>>(uint8_t , const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&, const std::vector<bool>&, std::optional<Tensor>, std::optional<Tensor>)> get_function_type2(BinaryBackwardOpType OpType);
static std::function<std::vector<std::optional<ttnn::Tensor>>(const Tensor&, const Tensor&, const Tensor&, float, const MemoryConfig&, const std::vector<bool>&, std::optional<Tensor>, std::optional<Tensor>)> get_function_type2_wo_qid(BinaryBackwardOpType OpType);
static std::function<std::vector<std::optional<ttnn::Tensor>>(uint8_t , const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, const std::vector<bool>&, std::optional<Tensor>, std::optional<Tensor>)> get_function_type3(BinaryBackwardOpType OpType);
static std::function<std::vector<std::optional<ttnn::Tensor>>(const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, const std::vector<bool>&, std::optional<Tensor>, std::optional<Tensor>)> get_function_type3_wo_qid(BinaryBackwardOpType OpType);
};
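The surviving declarations keep the convention visible throughout this diff: are_required_outputs selects which gradients are actually computed, and callers may hand in preallocated optional output tensors. A minimal sketch of that calling convention follows; add_bw here is an illustrative stand-in (ttnn's real ADD_BW path sits behind get_function_type3), Tensor is reduced to a float, so the in-place reuse of input_a_grad/input_b_grad that real tensors would allow is only noted in comments.

#include <iostream>
#include <optional>
#include <vector>

using Tensor = float;  // stand-in for ttnn::Tensor

// Compute only the gradients requested via are_required_outputs; slots that
// were not requested stay std::nullopt in the returned vector.
std::vector<std::optional<Tensor>> add_bw(
    const Tensor& grad,
    const Tensor& /*input_a*/,
    const Tensor& /*input_b*/,
    const std::vector<bool>& are_required_outputs = {true, true},
    std::optional<Tensor> /*input_a_grad*/ = std::nullopt,   // preallocated output, written in place with real tensors
    std::optional<Tensor> /*input_b_grad*/ = std::nullopt) {
    std::vector<std::optional<Tensor>> result(2);  // both slots start empty
    if (are_required_outputs.at(0)) {
        result[0] = grad;  // d(a+b)/da = 1
    }
    if (are_required_outputs.at(1)) {
        result[1] = grad;  // d(a+b)/db = 1
    }
    return result;
}

int main() {
    // Ask for the gradient of input_a only.
    auto grads = add_bw(/*grad=*/2.0f, 1.0f, 1.0f, {true, false});
    std::cout << std::boolalpha << grads[0].has_value() << " "  // true
              << grads[1].has_value() << "\n";                  // false
}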
