Bfp4_b typecast support (#15574)
### Ticket
#12068

### What's changed
Add bfp4_b (bfloat4_b) to the set of data formats supported by typecast, on Grayskull (GS), Wormhole (WH), and Blackhole (BH).
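
A minimal usage sketch of the new path (assuming a standard single-device setup; the tensor shape and round-trip check are illustrative only):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

torch_input = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16)
# bfloat4_b is a block-float format, so the tensor must be tilized.
tt_bfp4 = ttnn.from_torch(torch_input, dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=device)

# New in this change: typecast to/from bfloat4_b on GS, WH, and BH.
tt_bf16 = ttnn.typecast(tt_bfp4, ttnn.bfloat16)
torch_output = ttnn.to_torch(tt_bf16)

# bfloat4_b keeps only a few mantissa bits per element, so the round trip is lossy.
print("max abs round-trip error:", (torch_input.float() - torch_output.float()).abs().max().item())

ttnn.close_device(device)
```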

### Checklist
- [x] Post commit CI passes:
https://github.com/tenstorrent/tt-metal/actions/runs/12201731185
- [x] Blackhole Post commit (if applicable):
https://github.com/tenstorrent/tt-metal/actions/runs/12201773504
- [x] New/Existing tests provide coverage for changes
rdjogoTT authored Dec 6, 2024
1 parent d422ccf commit 71a4cff
Showing 8 changed files with 147 additions and 4 deletions.
@@ -28,13 +28,15 @@
(torch.float16, ttnn.float32),
(torch.float32, ttnn.bfloat8_b),
(torch.bfloat16, ttnn.bfloat16),
(torch.bfloat16, ttnn.bfloat4_b),
(torch.int, ttnn.uint32),
),
)
@pytest.mark.parametrize(
"pt_output_dtype, tt_output_dtype",
(
(torch.bfloat16, ttnn.bfloat16),
(torch.bfloat16, ttnn.bfloat4_b),
(torch.float32, ttnn.bfloat8_b),
),
)
@@ -85,6 +87,8 @@ def test_run_typecast_op(
test_args["input_mem_config"] = [input_mem_config]
test_args.update({"output_mem_config": dst_mem_config})
comparison_func = comparison_funcs.comp_pcc
if tt_input_dtype == ttnn.bfloat4_b or tt_output_dtype == ttnn.bfloat4_b:
comparison_func = partial(comparison_funcs.comp_pcc, pcc=0.97)

run_single_pytorch_test(
"typecast",
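
The bfloat4_b cases above relax the comparison to PCC 0.97 because the format is very lossy. A rough standalone illustration of the quantization error (this emulation assumes a shared exponent per 16-element block and a sign plus 3 mantissa bits per value, which approximates bfp4_b; the hardware's exact rounding may differ):

```python
import torch

def fake_bfp4_b(x: torch.Tensor, block_size: int = 16) -> torch.Tensor:
    # Approximate block-float quantization: one shared exponent per block,
    # each value reduced to a sign plus 3 mantissa bits.
    mant_max = 2**3 - 1  # 7
    flat = x.to(torch.float32).flatten()
    pad = (-flat.numel()) % block_size
    blocks = torch.nn.functional.pad(flat, (0, pad)).view(-1, block_size)
    max_mag = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-30)
    scale = 2.0 ** torch.ceil(torch.log2(max_mag / mant_max))
    mant = torch.clamp(torch.round(blocks / scale), -mant_max, mant_max)
    return (mant * scale).flatten()[: x.numel()].view_as(x)

x = torch.randn(32, 32)
q = fake_bfp4_b(x)
print("max abs quantization error:", (x - q).abs().max().item())
```
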
32 changes: 32 additions & 0 deletions tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py
@@ -1482,6 +1482,38 @@ def eltwise_typecast(x, *args, tt_input_dtype, tt_output_dtype, **kwargs):
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.uint16 and tt_output_dtype[0] == ttnn.uint32:
return torch.clamp(x.to(torch.int32), min=0, max=65535)
elif tt_input_dtype[0] == ttnn.bfloat8_b and tt_output_dtype[0] == ttnn.bfloat16:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat16 and tt_output_dtype[0] == ttnn.bfloat8_b:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat8_b and tt_output_dtype[0] == ttnn.float32:
return x.to(torch.bfloat16).to(torch.float32)
elif tt_input_dtype[0] == ttnn.float32 and tt_output_dtype[0] == ttnn.bfloat8_b:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat4_b and tt_output_dtype[0] == ttnn.uint16:
return torch.clamp(x.to(torch.bfloat16).to(torch.int32), min=0, max=65535) # due to no uint16 support
elif tt_input_dtype[0] == ttnn.uint16 and tt_output_dtype[0] == ttnn.bfloat4_b:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat4_b and tt_output_dtype[0] == ttnn.int32:
return x.to(torch.bfloat16).to(torch.int32)
elif tt_input_dtype[0] == ttnn.int32 and tt_output_dtype[0] == ttnn.bfloat4_b:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat4_b and tt_output_dtype[0] == ttnn.uint32:
return torch.relu(x.to(torch.int32)) # due to no uint32 support
elif tt_input_dtype[0] == ttnn.uint32 and tt_output_dtype[0] == ttnn.bfloat4_b:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat4_b and tt_output_dtype[0] == ttnn.bfloat16:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat16 and tt_output_dtype[0] == ttnn.bfloat4_b:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat4_b and tt_output_dtype[0] == ttnn.float32:
return x.to(torch.bfloat16).to(torch.float32)
elif tt_input_dtype[0] == ttnn.float32 and tt_output_dtype[0] == ttnn.bfloat4_b:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat4_b and tt_output_dtype[0] == ttnn.bfloat8_b:
return x.to(torch.bfloat16)
elif tt_input_dtype[0] == ttnn.bfloat8_b and tt_output_dtype[0] == ttnn.bfloat4_b:
return x.to(torch.bfloat16)
else:
return x

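torch has no block-float dtype, so every bfloat4_b reference path above is emulated through torch.bfloat16. The same mapping could also be written as a lookup table; a sketch using the ttnn/torch names from the diff (an alternative structure, not what the commit does):

```python
import torch
import ttnn

# Table-based equivalent of the bfloat4_b branches above.
BFP4_GOLDEN = {
    (ttnn.bfloat4_b, ttnn.uint16): lambda x: torch.clamp(x.to(torch.bfloat16).to(torch.int32), min=0, max=65535),
    (ttnn.uint16, ttnn.bfloat4_b): lambda x: x.to(torch.bfloat16),
    (ttnn.bfloat4_b, ttnn.int32): lambda x: x.to(torch.bfloat16).to(torch.int32),
    (ttnn.int32, ttnn.bfloat4_b): lambda x: x.to(torch.bfloat16),
    (ttnn.bfloat4_b, ttnn.uint32): lambda x: torch.relu(x.to(torch.int32)),
    (ttnn.uint32, ttnn.bfloat4_b): lambda x: x.to(torch.bfloat16),
    (ttnn.bfloat4_b, ttnn.bfloat16): lambda x: x.to(torch.bfloat16),
    (ttnn.bfloat16, ttnn.bfloat4_b): lambda x: x.to(torch.bfloat16),
    (ttnn.bfloat4_b, ttnn.float32): lambda x: x.to(torch.bfloat16).to(torch.float32),
    (ttnn.float32, ttnn.bfloat4_b): lambda x: x.to(torch.bfloat16),
    (ttnn.bfloat4_b, ttnn.bfloat8_b): lambda x: x.to(torch.bfloat16),
    (ttnn.bfloat8_b, ttnn.bfloat4_b): lambda x: x.to(torch.bfloat16),
}

def bfp4_reference(x, tt_in, tt_out):
    # The identity fallback mirrors the final `else: return x` above.
    return BFP4_GOLDEN.get((tt_in, tt_out), lambda t: t)(x)
```
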
18 changes: 18 additions & 0 deletions tests/ttnn/unit_tests/operations/eltwise/test_eltwise_typecast.py
@@ -53,6 +53,22 @@
(torch.bfloat16, ttnn.bfloat8_b, ttnn.uint32),
(torch.int, ttnn.uint32, ttnn.bfloat8_b),
(torch.int, ttnn.uint16, ttnn.uint32),
(torch.bfloat16, ttnn.bfloat8_b, ttnn.bfloat16),
(torch.bfloat16, ttnn.bfloat16, ttnn.bfloat8_b),
(torch.bfloat16, ttnn.bfloat8_b, ttnn.float32),
(torch.float32, ttnn.float32, ttnn.bfloat8_b),
(torch.bfloat16, ttnn.bfloat4_b, ttnn.int32),
(torch.int, ttnn.int32, ttnn.bfloat4_b),
(torch.bfloat16, ttnn.bfloat4_b, ttnn.uint16),
(torch.int, ttnn.uint16, ttnn.bfloat4_b),
(torch.bfloat16, ttnn.bfloat4_b, ttnn.uint32),
(torch.int, ttnn.uint32, ttnn.bfloat4_b),
(torch.bfloat16, ttnn.bfloat4_b, ttnn.bfloat16),
(torch.bfloat16, ttnn.bfloat16, ttnn.bfloat4_b),
(torch.bfloat16, ttnn.bfloat4_b, ttnn.float32),
(torch.float32, ttnn.float32, ttnn.bfloat4_b),
(torch.bfloat16, ttnn.bfloat4_b, ttnn.bfloat8_b),
(torch.bfloat16, ttnn.bfloat8_b, ttnn.bfloat4_b),
),
)
@pytest.mark.parametrize(
@@ -95,6 +111,8 @@ def test_run_eltwise_typecast_op(
test_args["input_mem_config"] = [input_mem_config]
test_args.update({"output_mem_config": dst_mem_config})
comparison_func = comparison_funcs.comp_pcc
if tt_input_dtype == ttnn.bfloat4_b or tt_output_dtype == ttnn.bfloat4_b:
comparison_func = partial(comparison_funcs.comp_pcc, pcc=0.98)

run_single_pytorch_test(
"eltwise-typecast",
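
Both test files fall back to a PCC check when bfloat4_b is involved (0.97 in the sweep test, 0.98 here). A minimal sketch of what such a check computes; the real comparison_funcs.comp_pcc may differ in return convention and edge-case handling:

```python
import torch

def comp_pcc_sketch(golden: torch.Tensor, calculated: torch.Tensor, pcc: float = 0.98):
    # Pearson correlation coefficient between the flattened tensors.
    g = golden.flatten().to(torch.float32)
    c = calculated.flatten().to(torch.float32)
    g = g - g.mean()
    c = c - c.mean()
    corr = (g * c).sum() / (g.norm() * c.norm() + 1e-12)
    return corr.item() >= pcc, f"PCC: {corr.item():.6f}"
```
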
@@ -76,6 +76,44 @@ inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt16 && OUT_DTYPE == (uint32_t)DataFormat::UInt32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_uint16_to_uint32<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::Float16_b) {
// no SFPU kernel needed, handled by unpacker
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float16_b && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) {
// no SFPU kernel needed, handled by packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::Float32) {
// no SFPU kernel needed, handled by unpacker/packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) {
// no SFPU kernel needed, handled by packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::UInt16) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_uint16<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt16 && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_uint16_to_fp16b<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::Int32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_int32<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Int32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_int32_to_fp16b<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::UInt32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_uint32<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_uint32_to_fp16b<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::Float16_b) {
// no SFPU kernel needed, handled by unpacker
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float16_b && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
// no SFPU kernel needed, handled by packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) {
// no SFPU kernel needed, handled by unpacker
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
// no SFPU kernel needed, handled by packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::Float32) {
// no SFPU kernel needed, handled by unpacker/packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
// no SFPU kernel needed, handled by packer
}
}

@@ -76,6 +76,44 @@ inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt16 && OUT_DTYPE == (uint32_t)DataFormat::UInt32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_uint16_to_uint32<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::Float16_b) {
// no SFPU kernel needed, handled by unpacker
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float16_b && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) {
// no SFPU kernel needed, handled by packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::Float32) {
// no SFPU kernel needed, handled by unpacker/packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) {
// no SFPU kernel needed, handled by packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::UInt16) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_uint16<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt16 && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_uint16_to_fp16b<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::Int32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_int32<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Int32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_int32_to_fp16b<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::UInt32) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_fp16b_to_uint32<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
ckernel::sfpu::calculate_typecast_uint32_to_fp16b<APPROXIMATE, 8>, dst_index, vector_mode);
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::Float16_b) {
// no SFPU kernel needed, handled by unpacker
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float16_b && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
// no SFPU kernel needed, handled by packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::Bfp8_b) {
// no SFPU kernel needed, handled by unpacker
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp8_b && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
// no SFPU kernel needed, handled by packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Bfp4_b && OUT_DTYPE == (uint32_t)DataFormat::Float32) {
// no SFPU kernel needed, handled by unpacker/packer
} else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float32 && OUT_DTYPE == (uint32_t)DataFormat::Bfp4_b) {
// no SFPU kernel needed, handled by packer
}
}

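
The Bfp4_b branches mirror the existing Bfp8_b handling: conversions between block-float and other float formats need no SFPU kernel because the unpacker/packer already convert to and from Float16_b in Dest, while conversions to integer formats reuse the existing fp16b SFPU kernels. One of the SFPU-backed branches can be exercised from Python roughly as follows (device-open boilerplate and tensor shape are illustrative assumptions):

```python
import torch
import ttnn

device = ttnn.open_device(device_id=0)

# bfloat4_b -> int32 exercises the fp16b -> int32 SFPU kernel selected above.
torch_input = torch.randn(1, 1, 32, 32, dtype=torch.bfloat16) * 8
tt_bfp4 = ttnn.from_torch(torch_input, dtype=ttnn.bfloat4_b, layout=ttnn.TILE_LAYOUT, device=device)
tt_int32 = ttnn.typecast(tt_bfp4, ttnn.int32)
print(ttnn.to_torch(tt_int32)[0, 0, 0, :8])

ttnn.close_device(device)
```
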
8 changes: 8 additions & 0 deletions tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h
@@ -28,6 +28,14 @@ namespace ckernel {
* Bfp8_b <-> Int32
* Bfp8_b <-> UInt16
* Bfp8_b <-> UInt32
* Bfp8_b <-> Float16_b
* Bfp8_b <-> Float32
* Bfp4_b <-> Int32
* Bfp4_b <-> UInt16
* Bfp4_b <-> UInt32
* Bfp4_b <-> Bfp8_b
* Bfp4_b <-> Float16_b
* Bfp4_b <-> Float32
* UInt16 -> UInt32
*
* For input/output to be UInt32, Int32, or Float32, Dest must be in 32 bit mode.
@@ -15,11 +15,11 @@ void CopyDeviceOperation::validate_with_output_tensors(
const auto& input_tensor_a = input_tensors.at(0);
TT_FATAL(
input_tensor_a.get_dtype() == DataType::BFLOAT16 or input_tensor_a.get_dtype() == DataType::BFLOAT8_B or
input_tensor_a.get_dtype() == DataType::FLOAT32,
input_tensor_a.get_dtype() == DataType::FLOAT32 or input_tensor_a.get_dtype() == DataType::BFLOAT4_B,
"Typecast operation is only supported on Grayskull for float/bfloat inputs");
TT_FATAL(
this->output_dtype == DataType::BFLOAT16 or this->output_dtype == DataType::BFLOAT8_B or
this->output_dtype == DataType::FLOAT32,
this->output_dtype == DataType::FLOAT32 or this->output_dtype == DataType::BFLOAT4_B,
"Typecast operation is only supported on Grayskull for float/bfloat outputs");
TT_FATAL(input_tensor_a.storage_type() == StorageType::DEVICE, "Operands to copy need to be on device!");
TT_FATAL(input_tensor_a.buffer() != nullptr, "Operands to copy need to be allocated in buffers on device!");
@@ -109,7 +109,6 @@ std::map<std::string, std::string> get_defines(
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT16) ||
(input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT16) ||
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::BFLOAT16) ||
(input_dtype.value() == DataType::BFLOAT16 && output_dtype.value() == DataType::FLOAT32) ||
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::UINT16) ||
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::FLOAT32) ||
(input_dtype.value() == DataType::FLOAT32 && output_dtype.value() == DataType::INT32) ||
@@ -124,7 +123,13 @@
(input_dtype.value() == DataType::UINT32 && output_dtype.value() == DataType::FLOAT32) ||
(input_dtype.value() == DataType::BFLOAT8_B && output_dtype.value() == DataType::UINT32) ||
(input_dtype.value() == DataType::UINT32 && output_dtype.value() == DataType::BFLOAT8_B) ||
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::UINT32))) {
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::UINT32) ||
(input_dtype.value() == DataType::BFLOAT4_B && output_dtype.value() == DataType::UINT32) ||
(input_dtype.value() == DataType::UINT32 && output_dtype.value() == DataType::BFLOAT4_B) ||
(input_dtype.value() == DataType::BFLOAT4_B && output_dtype.value() == DataType::UINT16) ||
(input_dtype.value() == DataType::UINT16 && output_dtype.value() == DataType::BFLOAT4_B) ||
(input_dtype.value() == DataType::BFLOAT4_B && output_dtype.value() == DataType::INT32) ||
(input_dtype.value() == DataType::INT32 && output_dtype.value() == DataType::BFLOAT4_B))) {
TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined");

auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(input_dtype.value()));
