Skip to content

Commit

Permalink
#4858: update ttnn typecast to provide 2 params
Browse files Browse the repository at this point in the history
  • Loading branch information
rdjogoTT committed Jun 11, 2024
1 parent b8ed2e8 commit 4fbf03d
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions ttnn/cpp/ttnn/operations/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@ struct Typecast {
TT_FATAL(output_dtype == optional_output_tensor.value().get_dtype(), "If both output dtype and output tensor provided dtype should match");
}

DataType input_dtype = input.get_dtype();
auto memory_config = memory_config_arg.value_or(input.memory_config());
bool fp32_dest_acc_en = output_dtype == DataType::UINT32;
auto unary_op = UnaryWithParam{UnaryOpType::TYPECAST, static_cast<float>(output_dtype)};
bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_dtype == DataType::INT32;
auto unary_op = UnaryWithParam{UnaryOpType::TYPECAST, {static_cast<float>(input_dtype), static_cast<float>(output_dtype)}};
auto eltwise_op = EltwiseUnary{{unary_op}, memory_config, fp32_dest_acc_en, output_dtype};
return operation::run(eltwise_op, {input}, {}, {optional_output_tensor}, queue_id).at(0);
}
Expand Down

0 comments on commit 4fbf03d

Please sign in to comment.