Skip to content

Commit

Permalink
#4858: add support for typecast uint16<->fp32
Browse files Browse the repository at this point in the history
  • Loading branch information
rdjogoTT committed Jun 25, 2024
1 parent 70fcce4 commit 3282d07
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 4 deletions.
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,10 @@ def eltwise_typecast(x, *args, tt_input_dtype, tt_output_dtype, **kwargs):
return x.to(torch.bfloat16).to(torch.float32)
elif tt_input_dtype[0] == ttl.tensor.DataType.FLOAT32 and tt_output_dtype[0] == ttl.tensor.DataType.BFLOAT16:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttl.tensor.DataType.FLOAT32 and tt_output_dtype[0] == ttl.tensor.DataType.UINT16:
return torch.clamp(x.to(torch.int32), min=0, max=65535) # due to no uint16 support
elif tt_input_dtype[0] == ttl.tensor.DataType.UINT16 and tt_output_dtype[0] == ttl.tensor.DataType.FLOAT32:
return x.to(torch.float32)
else:
return x

Expand Down
2 changes: 2 additions & 0 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ inline Tensor run_eltwise_unary_with_output_tensor(
preserve_fp32_precision or
output_dtype == DataType::UINT32 or
output_dtype == DataType::INT32 or
output_dtype == DataType::FLOAT32 or
input_tensor.get_dtype() == DataType::UINT32 or
input_tensor.get_dtype() ==
DataType::INT32; // MT: Currently only uint32/int32 is moved to DST directly, fp32 is converted to fp16b
Expand Down Expand Up @@ -276,6 +277,7 @@ inline Tensor run_eltwise_unary(
preserve_fp32_precision or
output_dtype == DataType::UINT32 or
output_dtype == DataType::INT32 or
output_dtype == DataType::FLOAT32 or
input_tensor.get_dtype() == DataType::UINT32 or
input_tensor.get_dtype() ==
DataType::INT32; // MT: Currently only uint32/int32 is moved to DST directly, fp32 is converted to fp16b
Expand Down
2 changes: 2 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ namespace tt::tt_metal::detail {
BFLOAT16 -> INT32
BFLOAT16 -> FLOAT32
FLOAT32 -> BFLOAT16
UINT16 -> FLOAT32
FLOAT32 -> UINT16
Input tensor must have tt_input_dtype data type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ inline void llk_math_eltwise_unary_datacopy_init(
const std::uint32_t operand = 0) {
const std::uint32_t operand_id = get_operand_id(operand);
const std::uint32_t num_faces = get_operand_num_faces(operand_id);
const std::uint32_t dst_format = get_operand_dst_format(operand_id);
_llk_math_eltwise_unary_datacopy_init_<type, src_b_bcast_type, is_fp32_dest_acc_en>(
transpose_of_faces, within_face_16x16_transpose, num_faces);
transpose_of_faces, within_face_16x16_transpose, num_faces, dst_format);
}
Original file line number Diff line number Diff line change
Expand Up @@ -143,5 +143,17 @@ inline void calculate_typecast_fp32_to_fp16b()
}
}

// Typecast values held in the Dest register file from uint16 to fp32 in place,
// one SFPU row per loop iteration, using raw TTI_* instruction macros.
// APPROXIMATION_MODE is unused here; it is kept so this kernel matches the
// common SFPU calculate_* template signature used by the dispatch layer.
// ITERATIONS is the number of Dest rows to process (callers pass 8 per tile face).
template <bool APPROXIMATION_MODE, int ITERATIONS>
inline void calculate_typecast_uint16_to_fp32()
{
// Keep the loop rolled: each TTI_* macro issues a Tensix instruction, so
// unrolling would replicate the instruction stream with no benefit.
#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++) {
TTI_SFPLOAD(0,6,3,0);   // load current Dest row into LREG0 — instr_mod 6 presumably selects the 16-bit unsigned-int load format; TODO confirm against the SFPU ISA reference
TTI_SFPCAST(0,1,0);     // cast LREG0 -> LREG1 — mode 0 appears to be int-to-fp32 conversion; verify operand order/mode in the ISA docs
TTI_SFPSTORE(1,3,3,0);  // store LREG1 back to the same Dest row — instr_mod 3 presumably the 32-bit float store format; TODO confirm
dst_reg++;              // advance the implicit Dest row pointer to the next row
}
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,18 @@ inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode
dst_index,
vector_mode);
}
else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float32 && OUT_DTYPE == (uint32_t)DataFormat::UInt16) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_uint16<APPROXIMATE,8>,
dst_index,
vector_mode);
}
else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt16 && OUT_DTYPE == (uint32_t)DataFormat::Float32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_uint16_to_fp32<APPROXIMATE,8>,
dst_index,
vector_mode);
}
}

template <bool APPROXIMATE>
Expand Down
4 changes: 3 additions & 1 deletion tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ namespace ckernel {
* Float16_b -> Int32
* Float16_b -> Float32
* Float32 -> Float16_b
* Float32 -> UInt16
* UInt16 -> Float32
*
* For output to be UInt32, Dest must be in 32 bit mode.
* For input/output to be UInt32, Int32, or Float32, Dest must be in 32 bit mode.
*
* Return value: None
*
Expand Down
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/copy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ struct Typecast {
bool fp32_dest_acc_en = preserve_fp32_precision or
output_dtype == DataType::UINT32 or
output_dtype == DataType::INT32 or
output_dtype == DataType::FLOAT32 or
input_dtype == DataType::UINT32 or
input_dtype == DataType::INT32;
auto unary_op = UnaryWithParam{UnaryOpType::TYPECAST, {static_cast<float>(input_dtype), static_cast<float>(output_dtype)}};
Expand Down
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ std::map<string, string> get_defines(
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT16) ||
(input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT16) ||
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::BFLOAT16) ||
(input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::FLOAT32))){
(input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::FLOAT32) ||
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::UINT16) ||
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::FLOAT32))){
TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined");

auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value()));
Expand Down
1 change: 1 addition & 0 deletions ttnn/cpp/ttnn/operations/unary.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ inline Tensor execute_on_worker_thread(
bool fp32_dest_acc_en = preserve_fp32_precision or
output_dtype == DataType::UINT32 or
output_dtype == DataType::INT32 or
output_dtype == DataType::FLOAT32 or
input_tensor.get_dtype() == DataType::UINT32 or
input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to
// DST directly, fp32 is converted to fp16b
Expand Down

0 comments on commit 3282d07

Please sign in to comment.