diff --git a/ttnn/cpp/ttnn/operations/copy.hpp b/ttnn/cpp/ttnn/operations/copy.hpp index d324a1d78342..60cb33419ab7 100644 --- a/ttnn/cpp/ttnn/operations/copy.hpp +++ b/ttnn/cpp/ttnn/operations/copy.hpp @@ -46,9 +46,10 @@ struct Typecast { TT_FATAL(output_dtype == optional_output_tensor.value().get_dtype(), "If both output dtype and output tensor provided dtype should match"); } + DataType input_dtype = input.get_dtype(); auto memory_config = memory_config_arg.value_or(input.memory_config()); - bool fp32_dest_acc_en = output_dtype == DataType::UINT32; - auto unary_op = UnaryWithParam{UnaryOpType::TYPECAST, static_cast(output_dtype)}; + bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_dtype == DataType::INT32; + auto unary_op = UnaryWithParam{UnaryOpType::TYPECAST, {static_cast(input_dtype), static_cast(output_dtype)}}; auto eltwise_op = EltwiseUnary{{unary_op}, memory_config, fp32_dest_acc_en, output_dtype}; return operation::run(eltwise_op, {input}, {}, {optional_output_tensor}, queue_id).at(0); }