From 3282d07bc13970ff51ae15e8d660d3db47d2cd4a Mon Sep 17 00:00:00 2001 From: Radomir Djogo Date: Tue, 25 Jun 2024 20:38:35 +0000 Subject: [PATCH] #4858: add support for typecast uint16<->fp32 --- .../python_api_testing/sweep_tests/pytorch_ops.py | 4 ++++ .../op_library/eltwise_unary/eltwise_unary_op.hpp | 2 ++ .../tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp | 2 ++ .../metal/llk_api/llk_math_unary_datacopy_api.h | 3 ++- .../metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h | 12 ++++++++++++ .../llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h | 12 ++++++++++++ .../compute_kernel_api/eltwise_unary/typecast.h | 4 +++- tt_metal/third_party/tt_llk_wormhole_b0 | 2 +- ttnn/cpp/ttnn/operations/copy.hpp | 1 + .../operations/eltwise/binary/device/binary_op.cpp | 4 +++- ttnn/cpp/ttnn/operations/unary.hpp | 1 + 11 files changed, 43 insertions(+), 4 deletions(-) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index 9160c347397..907989f8cb4 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -1429,6 +1429,10 @@ def eltwise_typecast(x, *args, tt_input_dtype, tt_output_dtype, **kwargs): return x.to(torch.bfloat16).to(torch.float32) elif tt_input_dtype[0] == ttl.tensor.DataType.FLOAT32 and tt_output_dtype[0] == ttl.tensor.DataType.BFLOAT16: return x.to(torch.bfloat16) + elif tt_input_dtype[0] == ttl.tensor.DataType.FLOAT32 and tt_output_dtype[0] == ttl.tensor.DataType.UINT16: + return torch.clamp(x.to(torch.int32), min=0, max=65535) # due to no uint16 support + elif tt_input_dtype[0] == ttl.tensor.DataType.UINT16 and tt_output_dtype[0] == ttl.tensor.DataType.FLOAT32: + return x.to(torch.float32) else: return x diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index ffe7874c4e2..87fde1e6592 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -218,6 +218,7 @@ inline Tensor run_eltwise_unary_with_output_tensor( preserve_fp32_precision or output_dtype == DataType::UINT32 or output_dtype == DataType::INT32 or + output_dtype == DataType::FLOAT32 or input_tensor.get_dtype() == DataType::UINT32 or input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to DST directly, fp32 is converted to fp16b @@ -276,6 +277,7 @@ inline Tensor run_eltwise_unary( preserve_fp32_precision or output_dtype == DataType::UINT32 or output_dtype == DataType::INT32 or + output_dtype == DataType::FLOAT32 or input_tensor.get_dtype() == DataType::UINT32 or input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to DST directly, fp32 is converted to fp16b diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp index 82f01945dd2..8f5defb0df3 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp @@ -80,6 +80,8 @@ namespace tt::tt_metal::detail { BFLOAT16 -> INT32 BFLOAT16 -> FLOAT32 FLOAT32 -> BFLOAT16 + UINT16 -> FLOAT32 + FLOAT32 -> UINT16 Input tensor must have tt_input_dtype data type. diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_datacopy_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_datacopy_api.h index 5195d7bd006..3abc91c3178 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_datacopy_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_datacopy_api.h @@ -44,6 +44,7 @@ inline void llk_math_eltwise_unary_datacopy_init( const std::uint32_t operand = 0) { const std::uint32_t operand_id = get_operand_id(operand); const std::uint32_t num_faces = get_operand_num_faces(operand_id); + const std::uint32_t dst_format = get_operand_dst_format(operand_id); _llk_math_eltwise_unary_datacopy_init_( - transpose_of_faces, within_face_16x16_transpose, num_faces); + transpose_of_faces, within_face_16x16_transpose, num_faces, dst_format); } diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h index fc106e1e858..6bd02f95fba 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h @@ -143,5 +143,17 @@ inline void calculate_typecast_fp32_to_fp16b() } } +template +inline void calculate_typecast_uint16_to_fp32() +{ + #pragma GCC unroll 0 + for (int d = 0; d < ITERATIONS; d++) { + TTI_SFPLOAD(0,6,3,0); + TTI_SFPCAST(0,1,0); + TTI_SFPSTORE(1,3,3,0); + dst_reg++; + } +} + } // namespace sfpu } // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h index 6779cb3bc3c..1483bb2b41a 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h @@ -53,6 +53,18 @@ inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode dst_index, vector_mode); } + else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float32 && OUT_DTYPE == (uint32_t)DataFormat::UInt16) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_typecast_fp16b_to_uint16, + dst_index, + vector_mode); + } + else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt16 && OUT_DTYPE == (uint32_t)DataFormat::Float32) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_typecast_uint16_to_fp32, + dst_index, + vector_mode); + } } template diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h b/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h index e6edb3b2a07..2344ae4da8c 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h @@ -28,8 +28,10 @@ namespace ckernel { * Float16_b -> Int32 * Float16_b -> Float32 * Float32 -> Float16_b + * Float32 -> UInt16 + * UInt16 -> Float32 * - * For output to be UInt32, Dest must be in 32 bit mode. + * For input/output to be UInt32, Int32, or Float32, Dest must be in 32 bit mode. * * Return value: None * diff --git a/tt_metal/third_party/tt_llk_wormhole_b0 b/tt_metal/third_party/tt_llk_wormhole_b0 index ff50541c140..bd5e985a345 160000 --- a/tt_metal/third_party/tt_llk_wormhole_b0 +++ b/tt_metal/third_party/tt_llk_wormhole_b0 @@ -1 +1 @@ -Subproject commit ff50541c1403c99f035995dad81e5944f057aa6b +Subproject commit bd5e985a3451b6edda05aef29f87221b17b85542 diff --git a/ttnn/cpp/ttnn/operations/copy.hpp b/ttnn/cpp/ttnn/operations/copy.hpp index f3e82703a4f..85cbb42cf74 100644 --- a/ttnn/cpp/ttnn/operations/copy.hpp +++ b/ttnn/cpp/ttnn/operations/copy.hpp @@ -52,6 +52,7 @@ struct Typecast { bool fp32_dest_acc_en = preserve_fp32_precision or output_dtype == DataType::UINT32 or output_dtype == DataType::INT32 or + output_dtype == DataType::FLOAT32 or input_dtype == DataType::UINT32 or input_dtype == DataType::INT32; auto unary_op = UnaryWithParam{UnaryOpType::TYPECAST, {static_cast(input_dtype), static_cast(output_dtype)}}; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp index 2c63e29e63f..876c29d2243 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp @@ -109,7 +109,9 @@ std::map get_defines( (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT16) || (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT16) || (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::BFLOAT16) || - (input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::FLOAT32))){ + (input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::FLOAT32) || + (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::UINT16) || + (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::FLOAT32))){ TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined"); auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value())); diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp index 2ff9814e06f..794bfe512e7 100644 --- a/ttnn/cpp/ttnn/operations/unary.hpp +++ b/ttnn/cpp/ttnn/operations/unary.hpp @@ -49,6 +49,7 @@ inline Tensor execute_on_worker_thread( bool fp32_dest_acc_en = preserve_fp32_precision or output_dtype == DataType::UINT32 or output_dtype == DataType::INT32 or + output_dtype == DataType::FLOAT32 or input_tensor.get_dtype() == DataType::UINT32 or input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to // DST directly, fp32 is converted to fp16b