Skip to content

Commit

Permalink
#4858: add typecast fp32 <-> int32
Browse files Browse the repository at this point in the history
  • Loading branch information
rdjogoTT committed Jun 27, 2024
1 parent 0fa6c82 commit 0f9a509
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 2 deletions.
4 changes: 4 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,6 +1433,10 @@ def eltwise_typecast(x, *args, tt_input_dtype, tt_output_dtype, **kwargs):
return torch.clamp(x.to(torch.int32), min=0, max=65535) # due to no uint16 support
elif tt_input_dtype[0] == ttl.tensor.DataType.UINT16 and tt_output_dtype[0] == ttl.tensor.DataType.FLOAT32:
return x.to(torch.float32)
elif tt_input_dtype[0] == ttl.tensor.DataType.FLOAT32 and tt_output_dtype[0] == ttl.tensor.DataType.INT32:
return x.to(torch.int32)
elif tt_input_dtype[0] == ttl.tensor.DataType.INT32 and tt_output_dtype[0] == ttl.tensor.DataType.FLOAT32:
return x.to(torch.float32)
else:
return x

Expand Down
2 changes: 2 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ namespace tt::tt_metal::detail {
FLOAT32 -> BFLOAT16
UINT16 -> FLOAT32
FLOAT32 -> UINT16
INT32 -> FLOAT32
FLOAT32 -> INT32
Input tensor must have tt_input_dtype data type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ inline void calculate_typecast_fp16b_to_int32()
dst_reg[0] = 0;
} v_elseif (exp > 30) {
// set to int32 max value in case of overflow
vInt tmp = std::numeric_limits<int32_t>::max();;
vInt tmp = std::numeric_limits<int32_t>::max();
// check sign
v_if (in < 0) {
tmp = reinterpret<vInt>(setsgn(reinterpret<vFloat>(tmp), 1));
Expand Down Expand Up @@ -155,5 +155,17 @@ inline void calculate_typecast_uint16_to_fp32()
}
}

// SFPU typecast kernel for the int32 -> fp32 direction (per the function name).
//
// Runs ITERATIONS passes. Each pass issues one SFPLOAD / SFPCAST / SFPSTORE
// instruction triple and then advances dst_reg by one.
//
// Template parameters:
//   APPROXIMATION_MODE - not referenced in this body; presumably kept so the
//                        signature matches the other typecast kernels (confirm).
//   ITERATIONS         - number of load/cast/store passes to perform.
//
// NOTE(review): the operand encodings of the TTI_* macros are not visible in
// this file. From the operand positions it looks like SFPLOAD fills register 0,
// SFPCAST converts register 0 into register 1, and SFPSTORE writes register 1
// back out -- confirm the register, format and addressing-mode fields against
// the SFPU instruction reference before changing any of the constants.
template <bool APPROXIMATION_MODE, int ITERATIONS>
inline void calculate_typecast_int32_to_fp32()
{
// "unroll 0" asks GCC not to unroll this loop.
#pragma GCC unroll 0
for (int d = 0; d < ITERATIONS; d++) {
// Load the current source element(s); presumably as int32 -- TODO confirm operand meaning.
TTI_SFPLOAD(0,12,3,0);
// Convert the loaded value; presumably integer -> fp32 -- TODO confirm operand meaning.
TTI_SFPCAST(0,1,0);
// Store the converted value back; presumably in a 32-bit format -- TODO confirm operand meaning.
TTI_SFPSTORE(1,3,3,0);
// Step to the next destination position for the following pass.
dst_reg++;
}
}

} // namespace sfpu
} // namespace ckernel
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,18 @@ inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode
dst_index,
vector_mode);
}
else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float32 && OUT_DTYPE == (uint32_t)DataFormat::Int32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_int32<APPROXIMATE,8>,
dst_index,
vector_mode);
}
else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Int32 && OUT_DTYPE == (uint32_t)DataFormat::Float32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_int32_to_fp32<APPROXIMATE,8>,
dst_index,
vector_mode);
}
}

template <bool APPROXIMATE>
Expand Down
2 changes: 2 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace ckernel {
* Float32 -> Float16_b
* Float32 -> UInt16
* UInt16 -> Float32
* Float32 -> Int32
* Int32 -> Float32
*
* For input/output to be UInt32, Int32, or Float32, Dest must be in 32-bit mode.
*
Expand Down
4 changes: 3 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ std::map<string, string> get_defines(
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::BFLOAT16) ||
(input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::FLOAT32) ||
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::UINT16) ||
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::FLOAT32))){
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::FLOAT32) ||
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::INT32) ||
(input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::FLOAT32))){
TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined");

auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value()));
Expand Down

0 comments on commit 0f9a509

Please sign in to comment.