Skip to content

Commit

Permalink
#10890: Update files
Browse files Browse the repository at this point in the history
  • Loading branch information
mouliraj-mcw committed Aug 10, 2024
1 parent d0c11b1 commit 2153fca
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -666,16 +666,16 @@ Tensor _polygamma(const Tensor& input_a, int32_t k, const std::optional<MemoryCo
}

//rdiv
Tensor ExecuteRdiv::operator()(uint8_t queue_id, const Tensor& input_tensor, float value, string round_mode, const std::optional<MemoryConfig>& memory_config, std::optional<Tensor> optional_output_tensor) {
Tensor ExecuteRdiv::operator()(uint8_t queue_id, const Tensor& input_tensor, float value, const std::string& round_mode, const std::optional<MemoryConfig>& memory_config, std::optional<Tensor> optional_output_tensor) {
float t_inf = std::numeric_limits<float>::infinity();
Tensor recip_result = ttnn::reciprocal(queue_id, input_tensor, memory_config, optional_output_tensor);
Tensor result = ttnn::multiply(queue_id, recip_result, value, std::nullopt, memory_config, optional_output_tensor);

if(round_mode == "trunc"){
result = trunc(result);
result = ttnn::trunc(result);
}
else if(round_mode == "floor"){
result = floor(result);
result = ttnn::floor(result);
}
return ttnn::where(ttnn::eqz(queue_id, input_tensor, memory_config), t_inf, result, memory_config, optional_output_tensor);
}
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary_composite.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ struct ExecuteRdiv {
uint8_t queue_id,
const Tensor& input_tensor,
float value,
string round_mode = "None",
const std::string& round_mode = "None",
const std::optional<MemoryConfig>& memory_config = std::nullopt,
std::optional<Tensor> optional_output_tensor = std::nullopt);
};
Expand Down
2 changes: 1 addition & 1 deletion ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ void bind_unary_rdiv(py::module& module, const unary_operation_t& operation, con
[](const unary_operation_t& self,
const ttnn::Tensor& input_tensor,
float parameter_a,
string parameter_b,
const std::string& parameter_b,
const std::optional<MemoryConfig>& memory_config,
const std::optional<ttnn::Tensor>& output_tensor,
const uint8_t& queue_id) {
Expand Down

0 comments on commit 2153fca

Please sign in to comment.