Skip to content

Commit

Permalink
#10082: Update memory_config
Browse files Browse the repository at this point in the history
  • Loading branch information
VirdhatchaniKN committed Jul 13, 2024
1 parent 42397b2 commit e9b5260
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ std::vector<Tensor> _clamp_max_bw(
std::vector<Tensor> _clamp_bw(
const Tensor& grad, const Tensor& input, float min, float max, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config());
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
Tensor minT = ttnn::ge(input, min, std::nullopt, output_memory_config);
Tensor maxT = ttnn::le(input, max, std::nullopt, output_memory_config);
Tensor result = ttnn::logical_and(minT, maxT, std::nullopt, output_memory_config);
Expand All @@ -61,7 +61,7 @@ std::vector<Tensor> _clamp_bw(
std::vector<Tensor> _hardtanh_bw(
const Tensor& grad, const Tensor& input, float min, float max, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config());
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
Tensor grad_result = where(
ttnn::le(input, ttnn::operations::creation::full_like(input, min), std::nullopt, output_memory_config),
0.0,
Expand All @@ -76,7 +76,7 @@ std::vector<Tensor> _hardtanh_bw(
std::vector<Tensor> _threshold_bw(
const Tensor& grad, const Tensor& input, float threshold, float value, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config());
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
Tensor result = where(
ttnn::gtz(ttnn::add(input, -threshold, std::nullopt, output_memory_config), output_memory_config),
grad,
Expand All @@ -90,7 +90,7 @@ std::vector<Tensor> _threshold_bw(
std::vector<Tensor> _softplus_bw(
const Tensor& grad, const Tensor& input, float beta, float threshold, const std::optional<MemoryConfig>& output_mem_config) {
std::vector<Tensor> grad_tensor;
auto output_memory_config = output_mem_config.value_or(input.memory_config());
auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed
Tensor mul_input_beta = ttnn::multiply(input, beta, std::nullopt, output_memory_config);
Tensor exp_beta_self = ttnn::exp(mul_input_beta, false, output_memory_config);
Tensor sub_result = ttnn::add(mul_input_beta, -threshold, std::nullopt, output_memory_config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ struct ExecuteUnaryBackwardTwoFloat {
float max,
const std::optional<MemoryConfig> &memory_config = std::nullopt) {
auto op_type = get_function_type1_w_two_float<unary_backward_op_type>();
return op_type(grad_tensor_arg, input_tensor_arg, min, max, memory_config);
auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
return op_type(grad_tensor_arg, input_tensor_arg, min, max, output_memory_config);
}

};
Expand All @@ -54,7 +55,8 @@ struct ExecuteUnaryBackwardTwoFloatWithDefault {
float parameter_b,
const std::optional<MemoryConfig> &memory_config = std::nullopt) {
auto op_type = get_function_type1_w_two_float_with_default<unary_backward_op_type>();
return op_type(grad_tensor_arg, input_tensor_arg, parameter_a, parameter_b, memory_config);
auto output_memory_config = memory_config.value_or(input_tensor_arg.memory_config());
return op_type(grad_tensor_arg, input_tensor_arg, parameter_a, parameter_b, output_memory_config);
}

};
Expand Down

0 comments on commit e9b5260

Please sign in to comment.