diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index 0013318547b..48cb87c71ff 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -1437,6 +1437,14 @@ def eltwise_typecast(x, *args, tt_input_dtype, tt_output_dtype, **kwargs): return x.to(torch.int32) elif tt_input_dtype[0] == ttl.tensor.DataType.INT32 and tt_output_dtype[0] == ttl.tensor.DataType.FLOAT32: return x.to(torch.float32) + elif tt_input_dtype[0] == ttl.tensor.DataType.BFLOAT8_B and tt_output_dtype[0] == ttl.tensor.DataType.UINT16: + return torch.clamp(x.to(torch.bfloat16).to(torch.int32), min=0, max=65535) # due to no uint16 support + elif tt_input_dtype[0] == ttl.tensor.DataType.UINT16 and tt_output_dtype[0] == ttl.tensor.DataType.BFLOAT8_B: + return x.to(torch.bfloat16) + elif tt_input_dtype[0] == ttl.tensor.DataType.BFLOAT8_B and tt_output_dtype[0] == ttl.tensor.DataType.INT32: + return x.to(torch.bfloat16).to(torch.int32) + elif tt_input_dtype[0] == ttl.tensor.DataType.INT32 and tt_output_dtype[0] == ttl.tensor.DataType.BFLOAT8_B: + return x.to(torch.bfloat16) else: return x diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp index fab80a92de8..0ea805e083d 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp @@ -84,6 +84,10 @@ namespace tt::tt_metal::detail { FLOAT32 -> UINT16 INT32 -> FLOAT32 FLOAT32 -> INT32 + UINT16 -> BFLOAT8_B + BFLOAT8_B -> UINT16 + INT32 -> BFLOAT8_B + BFLOAT8_B -> INT32 Input tensor must have tt_input_dtype data type. diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h index ed4959684e6..82c9733e989 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h @@ -77,6 +77,30 @@ inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode dst_index, vector_mode); } + else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::UInt16) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_typecast_fp16b_to_uint16, + dst_index, + vector_mode); + } + else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt16 && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_typecast_uint16_to_fp16b, + dst_index, + vector_mode); + } + else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::Int32) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_typecast_fp16b_to_int32, + dst_index, + vector_mode); + } + else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Int32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) { + llk_math_eltwise_unary_sfpu_params( + ckernel::sfpu::calculate_typecast_int32_to_fp16b, + dst_index, + vector_mode); + } } template diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h b/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h index baf111e58ed..c6b1f95bd57 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h @@ -32,6 +32,10 @@ namespace ckernel { * UInt16 -> Float32 * Float32 -> Int32 * Int32 -> Float32 + * Bfp8_b -> UInt16 + * UInt16 -> Bfp8_b + * Bfp8_b -> Int32 + * Int32 -> Bfp8_b * * For input/output to be UInt32, Int32, or Float32, Dest must be in 32 bit mode. * diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp index 5e40aeb1b9b..501f5d75ac6 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp @@ -113,7 +113,11 @@ std::map get_defines( (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::UINT16) || (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::FLOAT32) || (input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::INT32) || - (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::FLOAT32))){ + (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::FLOAT32) || + (input_dtype.value() == DataType::BFLOAT8_B && output_dtype.value() == DataType::UINT16) || + (input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT8_B) || + (input_dtype.value() == DataType::BFLOAT8_B && output_dtype.value() == DataType::INT32) || + (input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT8_B))){ TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined"); auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value()));