Skip to content

Commit

Permalink
#4858: add typecast bfp8_b
Browse files (browse the repository at this point in the history)
  • Loading branch information
rdjogoTT committed Jun 27, 2024
1 parent b3b22b7 commit 07337e6
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 1 deletion.
8 changes: 8 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1437,6 +1437,14 @@ def eltwise_typecast(x, *args, tt_input_dtype, tt_output_dtype, **kwargs):
return x.to(torch.int32)
elif tt_input_dtype[0] == ttl.tensor.DataType.INT32 and tt_output_dtype[0] == ttl.tensor.DataType.FLOAT32:
return x.to(torch.float32)
elif tt_input_dtype[0] == ttl.tensor.DataType.BFLOAT8_B and tt_output_dtype[0] == ttl.tensor.DataType.UINT16:
return torch.clamp(x.to(torch.bfloat16).to(torch.int32), min=0, max=65535) # due to no uint16 support
elif tt_input_dtype[0] == ttl.tensor.DataType.UINT16 and tt_output_dtype[0] == ttl.tensor.DataType.BFLOAT8_B:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttl.tensor.DataType.BFLOAT8_B and tt_output_dtype[0] == ttl.tensor.DataType.INT32:
return x.to(torch.bfloat16).to(torch.int32)
elif tt_input_dtype[0] == ttl.tensor.DataType.INT32 and tt_output_dtype[0] == ttl.tensor.DataType.BFLOAT8_B:
return x.to(torch.bfloat16)
else:
return x

Expand Down
4 changes: 4 additions & 0 deletions tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ namespace tt::tt_metal::detail {
FLOAT32 -> UINT16
INT32 -> FLOAT32
FLOAT32 -> INT32
UINT16 -> BFLOAT8_B
BFLOAT8_B -> UINT16
INT32 -> BFLOAT8_B
BFLOAT8_B -> INT32
Input tensor must have tt_input_dtype data type.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,30 @@ inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode
dst_index,
vector_mode);
}
else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::UInt16) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_uint16<APPROXIMATE,8>,
dst_index,
vector_mode);
}
else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt16 && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_uint16_to_fp16b<APPROXIMATE,8>,
dst_index,
vector_mode);
}
else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::Int32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_int32<APPROXIMATE,8>,
dst_index,
vector_mode);
}
else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Int32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_int32_to_fp16b<APPROXIMATE,8>,
dst_index,
vector_mode);
}
}

template <bool APPROXIMATE>
Expand Down
4 changes: 4 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ namespace ckernel {
* UInt16 -> Float32
* Float32 -> Int32
* Int32 -> Float32
* Bfp8_b -> UInt16
* UInt16 -> Bfp8_b
* Bfp8_b -> Int32
* Int32 -> Bfp8_b
*
* For input/output to be UInt32, Int32, or Float32, Dest must be in 32-bit mode.
*
Expand Down
6 changes: 5 additions & 1 deletion ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,11 @@ std::map<string, string> get_defines(
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::UINT16) ||
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::FLOAT32) ||
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::INT32) ||
(input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::FLOAT32))){
(input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::FLOAT32) ||
(input_dtype.value() == DataType::BFLOAT8_B && output_dtype.value() == DataType::UINT16) ||
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT8_B) ||
(input_dtype.value() == DataType::BFLOAT8_B && output_dtype.value() == DataType::INT32) ||
(input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT8_B))){
TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined");

auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value()));
Expand Down

0 comments on commit 07337e6

Please sign in to comment.