Typecast int32->fp16b #9317

Merged: 5 commits, Jun 11, 2024
Changes from 4 commits
tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py (6 additions, 4 deletions)

```diff
@@ -1374,12 +1374,14 @@ def eltwise_identity(x, *args, **kwargs):
     return x


-def eltwise_typecast(x, *args, tt_output_dtype, **kwargs):
-    if tt_output_dtype[0] == ttl.tensor.DataType.UINT16:
+def eltwise_typecast(x, *args, tt_input_dtype, tt_output_dtype, **kwargs):
+    if tt_input_dtype[0] == ttl.tensor.DataType.BFLOAT16 and tt_output_dtype[0] == ttl.tensor.DataType.UINT16:
         return torch.clamp(x.to(torch.int32), min=0, max=65535)  # due to no uint16 support
-    elif tt_output_dtype[0] == ttl.tensor.DataType.UINT32:
+    elif tt_input_dtype[0] == ttl.tensor.DataType.BFLOAT16 and tt_output_dtype[0] == ttl.tensor.DataType.UINT32:
         return torch.relu(x.to(torch.int32))  # due to no uint32 support
-    elif tt_output_dtype[0] == ttl.tensor.DataType.BFLOAT16:
+    elif tt_input_dtype[0] == ttl.tensor.DataType.UINT16 and tt_output_dtype[0] == ttl.tensor.DataType.BFLOAT16:
+        return x.to(torch.bfloat16)
+    elif tt_input_dtype[0] == ttl.tensor.DataType.INT32 and tt_output_dtype[0] == ttl.tensor.DataType.BFLOAT16:
         return x.to(torch.bfloat16)
     else:
         return x
```
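The clamp/relu branches exist because torch has no unsigned 16/32-bit dtypes, so the golden function approximates them in int32. A minimal sketch of those reference semantics, with values chosen to be exactly representable in bfloat16:

```python
# Minimal sketch of the golden semantics above; assumes only torch.
import torch

x = torch.tensor([-2.0, 65536.0], dtype=torch.bfloat16)

# BFLOAT16 -> UINT16: no torch.uint16, so saturate within int32 instead.
print(torch.clamp(x.to(torch.int32), min=0, max=65535))  # tensor([    0, 65535], dtype=torch.int32)

# BFLOAT16 -> UINT32: no torch.uint32 either; relu folds negatives to zero.
print(torch.relu(x.to(torch.int32)))                     # tensor([    0, 65536], dtype=torch.int32)
```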
tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py (3 additions, 3 deletions)

```diff
@@ -2331,15 +2331,15 @@ def eltwise_typecast(
     x,
     *args,
     device,
-    dtype,
+    tt_input_dtype,
     tt_output_dtype,
     layout,
     input_mem_config,
     output_mem_config,
     **kwargs,
 ):
-    t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0])
-    t1 = ttl.tensor.eltwise_typecast(t0, tt_output_dtype[0], output_mem_config=output_mem_config)
+    t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], tt_input_dtype[0])
+    t1 = ttl.tensor.eltwise_typecast(t0, tt_input_dtype[0], tt_output_dtype[0], output_mem_config=output_mem_config)

     return tt2torch_tensor(t1)
```
tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp (9 additions, 4 deletions)

```diff
@@ -19,7 +19,11 @@ namespace eltwise_binary_op_utils {
 using namespace tt::tt_metal;

 std::map<string, string> get_defines(
-    BinaryOpType op_type, const std::optional<DataType> output_dtype, const std::optional<std::vector<UnaryWithParam>> fused_activations) {
+    BinaryOpType op_type,
+    const std::optional<DataType> in_dtype,
+    const std::optional<DataType> output_dtype,
+    const std::optional<std::vector<UnaryWithParam>> fused_activations) {
+
     std::map<string, string> defines;
     string op_name = "sub_tiles";
     string op_binary_type = "EltwiseBinaryType::ELWSUB";
@@ -106,12 +110,13 @@ std::map<string, string> get_defines(
         default: TT_ASSERT(false && "Undefined op type");
     }

-    if(output_dtype.has_value() && (output_dtype.value() == DataType::UINT32 || output_dtype.value() == DataType::UINT16)){
+    if(in_dtype.has_value() && output_dtype.has_value() && (output_dtype.value() == DataType::UINT32 || output_dtype.value() == DataType::UINT16)){
         TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined");

-        auto dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(output_dtype.value()));
+        auto in_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(in_dtype.value()));
+        auto out_dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(output_dtype.value()));
         defines.insert({"SFPU_OP_CHAIN_0",
-            fmt::format("typecast_tile_init(); typecast_tile<{0}u>(i);", dataformat)});
+            fmt::format("typecast_tile_init(); typecast_tile<{0}u, {1}u>(i);", in_dataformat, out_dataformat)});
         defines.insert({"SFPU_OP_TYPECAST_INCLUDE", "1"});
     }
```
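To see what this produces: get_defines stringifies both data formats into the SFPU_OP_CHAIN_0 macro body that the compute kernel later compiles. A hedged reconstruction of the formatting step; the numeric codes below are placeholders, not the real DataFormat enum values:

```python
# Hypothetical reconstruction of the generated define; the real numeric
# values come from datatype_to_dataformat_converter on the C++ side.
in_dataformat, out_dataformat = "14", "5"   # placeholder DataFormat codes

define = "typecast_tile_init(); typecast_tile<{0}u, {1}u>(i);".format(
    in_dataformat, out_dataformat)
print(define)  # typecast_tile_init(); typecast_tile<14u, 5u>(i);
```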
The get_defines declaration in the corresponding header (file path not shown in this view) gains the matching defaulted parameter:

```diff
@@ -41,7 +41,9 @@ enum class BinaryOpType {

 namespace eltwise_binary_op_utils {

-std::map<string, string> get_defines(BinaryOpType op_type, const std::optional<DataType> out_dtype = std::nullopt,
+std::map<string, string> get_defines(BinaryOpType op_type,
+                                     const std::optional<DataType> in_dtype = std::nullopt,
+                                     const std::optional<DataType> out_dtype = std::nullopt,
                                      const std::optional<std::vector<UnaryWithParam>> fused_activations = std::nullopt);

 }  // namespace eltwise_binary_op_utils
```
The multi-core eltwise binary program factory now forwards the input dtype:

```diff
@@ -312,7 +312,7 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const Tensor &a, const
     }
     auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src1_config);

-    std::map<string, string> eltwise_defines = eltwise_binary_op_utils::get_defines(op_type, output.get_dtype(), fused_activations);
+    std::map<string, string> eltwise_defines = eltwise_binary_op_utils::get_defines(op_type, a.get_dtype(), output.get_dtype(), fused_activations);

     if (eltwise_defines.find("SFPU_OP_INIT_PRE_IN0_0") != eltwise_defines.end()) {
         tt_metal::CircularBufferConfig cb_interm_config = tt_metal::CircularBufferConfig(1 * src0_single_tile_size, {{CB::c_intermed0, src0_cb_data_format}})
```
The unary op code generator now validates and emits both typecast parameters:

```diff
@@ -191,9 +191,14 @@ std::pair<string, string> get_op_init_and_func_parameterized(
             break;
         }
         case UnaryOpType::TYPECAST:
+            TT_ASSERT(params.size() == 2, "Expected eltwise_typecast to take 2 parameters");
             op_init_and_name = {
                 "typecast_tile_init();",
-                fmt::format("typecast_tile<{1}u>({0});", idst, std::to_string((uint32_t)datatype_to_dataformat_converter((DataType)param0)))};
+                fmt::format(
+                    "typecast_tile<{1}u, {2}u>({0});",
+                    idst,
+                    std::to_string((uint32_t)datatype_to_dataformat_converter((DataType)params[0])),
+                    std::to_string((uint32_t)datatype_to_dataformat_converter((DataType)params[1])))};
             break;
         default: TT_ASSERT(false && "unexpected parameterized type");
     };
```
tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp (14 additions, 3 deletions)

```diff
@@ -201,7 +201,7 @@ inline Tensor run_eltwise_unary(
     const std::vector<UnaryWithParam>& ops_chain,
     const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
     TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified");
-    DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast<DataType>(ops_chain[0].params[0]) : input_tensor.get_dtype();
+    DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast<DataType>(ops_chain[0].params[1]) : input_tensor.get_dtype();
     bool fp32_dest_acc_en =
         output_dtype == DataType::UINT32 or
         input_tensor.get_dtype() == DataType::UINT32 or
@@ -247,7 +247,7 @@ inline Tensor run_eltwise_unary(
     const std::vector<UnaryWithParam>& ops_chain,
     const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
     TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified");
-    DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast<DataType>(ops_chain[0].params[0]) : input_tensor.get_dtype();
+    DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast<DataType>(ops_chain[0].params[1]) : input_tensor.get_dtype();
     bool fp32_dest_acc_en =
         output_dtype == DataType::UINT32 or
         input_tensor.get_dtype() == DataType::UINT32 or
@@ -378,7 +378,6 @@ constexpr auto rsub = make_eltwise_unary_with_param<UnaryOpType::RSUB>{};
 constexpr auto silu = make_eltwise_unary<UnaryOpType::SILU>{};
 constexpr auto identity = make_eltwise_unary<UnaryOpType::IDENTITY>{};
 constexpr auto identity_uint32 = make_eltwise_unary<UnaryOpType::IDENTITY_UINT32>{};
-constexpr auto eltwise_typecast = make_eltwise_unary_with_param<UnaryOpType::TYPECAST, uint32_t>{};
 constexpr auto add_unary_sfpu = make_eltwise_symmetric_binop_unary_with_param<UnaryOpType::ADD_UNARY_SFPU>{};
 constexpr auto mul_unary_sfpu = make_eltwise_symmetric_binop_unary_with_param<UnaryOpType::MUL_UNARY_SFPU>{};
 constexpr auto unary_gt = make_eltwise_unary_with_param<UnaryOpType::UNARY_GT>{};
@@ -452,6 +451,18 @@ inline Tensor softplus(
         input_tensor, {UnaryWithParam(UnaryOpType::SOFTPLUS, {beta, threshold})}, output_mem_config);
 }

+inline Tensor eltwise_typecast(
+    const Tensor& input_tensor,
+    uint32_t tt_input_dtype,
+    uint32_t tt_output_dtype,
+    const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
+    TT_ASSERT(input_tensor.device()->arch() != tt::ARCH::GRAYSKULL, "eltwise_typecast is not currently supported on Grayskull");
+    return run_eltwise_unary(
+        input_tensor,
+        {UnaryWithParam(UnaryOpType::TYPECAST, {static_cast<float>(tt_input_dtype), static_cast<float>(tt_output_dtype)})},
+        output_mem_config);
+}
+
 inline Tensor unary_chain(
     const Tensor& input_tensor,
     std::vector<UnaryWithParam> ops_chain,
```
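A note on the params change running through this header (and ttnn/cpp/ttnn/operations/unary.hpp below): a TYPECAST op now carries two params, index 0 for the input dtype and index 1 for the output dtype, both stored as floats inside UnaryWithParam. A small Python stand-in for the convention; the numeric codes are illustrative, not the real DataType enum values:

```python
# Illustrative stand-in for the two-element params convention.
INT32, BFLOAT16 = 7, 0   # hypothetical numeric DataType codes

# eltwise_typecast (added above) packs both dtypes into the op's params:
params = [float(INT32), float(BFLOAT16)]   # [input_dtype, output_dtype]

# run_eltwise_unary now reads index 1 for the output dtype, where the old
# single-parameter API kept the only dtype at index 0.
output_dtype = int(params[1])
assert output_dtype == BFLOAT16
```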
tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp (21 additions, 6 deletions)

```diff
@@ -110,12 +110,27 @@ namespace tt::tt_metal::detail {
     detail::bind_unary_op(m_tensor, "silu", silu, R"doc(Returns tensor with the silu all of elements of the input tensor ``{0}``.)doc");
     detail::bind_unary_op(m_tensor, "neg", neg, R"doc(Returns tensor with the negate all of elements of the input tensor ``{0}``.)doc");

-    detail::bind_unary_op_with_param(
-        m_tensor, "eltwise_typecast", eltwise_typecast,
-        py::arg("tt_output_dtype"),
-        R"doc(Returns tensor with all of the elements of the input tensor ``{0}`` typecasted from bfloat16 to uint32, bfloat16 to uint16, or uint16 to bfloat16.)doc",
-        R"doc("Indicates output dtype of typecast", "ttl.tensor.DataType", "")doc"
-    );
+    m_tensor.def("eltwise_typecast", &eltwise_typecast,
+        py::arg("input").noconvert(), py::arg("tt_input_dtype"), py::arg("tt_output_dtype"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+        Returns tensor with all elements of the input tensor ``{0}`` typecasted.
+        Supported typecasts:
+            BFLOAT16 -> UINT32
+            BFLOAT16 -> UINT16
+            UINT16 -> BFLOAT16
+            INT32 -> BFLOAT16
+
+        Input tensor must have tt_input_dtype data type.
+
+        Output tensor will have tt_output_dtype data type.
+
+        .. csv-table::
+            :header: "Argument", "Description", "Data type", "Valid range", "Required"
+
+            "input", "Tensor typecast is applied to", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
+            "tt_input_dtype", "Input tensor DataType", "DataType", "One of supported input DataTypes", "Yes"
+            "tt_output_dtype", "Desired output tensor DataType", "DataType", "One of supported output DataTypes", "Yes"
+            "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+    )doc");

     detail::bind_unary_op_with_param(
         m_tensor, "exp", py::overload_cast<const Tensor&, bool, const MemoryConfig&>(&exp),
```

Review thread on this hunk:

> Contributor: Minor note: Can it be documented either here or in the compute kernel api that GS is not supported for this op?
>
> Author: Yes
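For reference, a hedged usage sketch of the rebound op from Python, mirroring the sweep-test wrapper earlier in this diff; device and tensor setup are elided, and the module alias follows the tests in this PR:

```python
# Hedged usage sketch; assumes t_int32 is a tile-layout device tensor
# already created with DataType.INT32, as in the sweep tests above.
import tt_lib as ttl

t_bf16 = ttl.tensor.eltwise_typecast(
    t_int32,
    ttl.tensor.DataType.INT32,      # tt_input_dtype
    ttl.tensor.DataType.BFLOAT16,   # tt_output_dtype
)
```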
In the SFPU typecast kernels, the uint16 -> fp16b path gains an explicit rounding instruction before the store, and a new int32 -> fp16b kernel is added:

```diff
@@ -74,7 +74,21 @@ inline void calculate_typecast_uint16_to_fp16b()
     for (int d = 0; d < ITERATIONS; d++) {
         TTI_SFPLOAD(0,6,3,0);
         TTI_SFPCAST(0,1,0);
-        TTI_SFPSTORE(1,2,3,0);
+        TTI_SFP_STOCH_RND(0,0,3,1,2,1);
+        TTI_SFPSTORE(2,2,3,0);
         dst_reg++;
     }
 }

+template <bool APPROXIMATION_MODE, int ITERATIONS>
+inline void calculate_typecast_int32_to_fp16b()
+{
+    #pragma GCC unroll 0
+    for (int d = 0; d < ITERATIONS; d++) {
+        TTI_SFPLOAD(0,12,3,0);
+        TTI_SFPCAST(0,1,0);
+        TTI_SFP_STOCH_RND(0,0,3,1,2,1);
+        TTI_SFPSTORE(2,2,3,0);
+        dst_reg++;
+    }
+}
```
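Why the cast is followed by a rounding instruction: bfloat16 keeps only an 8-bit significand, so most int32 values of large magnitude are not exactly representable and the converted value must be rounded before the store. The kernel uses TTI_SFP_STOCH_RND, which applies the hardware's rounding; the torch snippet below is only an illustration of the underlying precision loss (torch rounds to nearest):

```python
import torch

v = torch.tensor([1000001], dtype=torch.int32)
# 1000001 lies between the representable bfloat16 neighbors 999424 and
# 1003520 (the spacing near 10**6 is 4096), so the cast must round.
print(v.to(torch.bfloat16))  # tensor([999424.], dtype=torch.bfloat16)
```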
The LLK SFPU dispatch now selects a kernel based on the input/output format pair:

```diff
@@ -12,26 +12,32 @@ namespace ckernel {

 // New LLK SFPU APIs

-template <bool APPROXIMATE, uint32_t OUT_DTYPE>
+template <bool APPROXIMATE, uint32_t IN_DTYPE, uint32_t OUT_DTYPE>
 inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode = (int)VectorMode::RC) {
-    if constexpr (OUT_DTYPE == (uint32_t)DataFormat::UInt16) {
+    if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float16_b && OUT_DTYPE == (uint32_t)DataFormat::UInt16) {
         llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
             ckernel::sfpu::calculate_typecast_fp16b_to_uint16<APPROXIMATE,8>,
             dst_index,
             vector_mode);
     }
-    else if constexpr (OUT_DTYPE == (uint32_t)DataFormat::UInt32) {
+    else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Float16_b && OUT_DTYPE == (uint32_t)DataFormat::UInt32) {
         llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
             ckernel::sfpu::calculate_typecast_fp16b_to_uint32<APPROXIMATE,8>,
             dst_index,
             vector_mode);
     }
-    else if constexpr (OUT_DTYPE == (uint32_t)DataFormat::Float16_b) {
+    else if constexpr (IN_DTYPE == (uint32_t)DataFormat::UInt16 && OUT_DTYPE == (uint32_t)DataFormat::Float16_b) {
         llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
             ckernel::sfpu::calculate_typecast_uint16_to_fp16b<APPROXIMATE,8>,
             dst_index,
             vector_mode);
     }
+    else if constexpr (IN_DTYPE == (uint32_t)DataFormat::Int32 && OUT_DTYPE == (uint32_t)DataFormat::Float16_b) {
+        llk_math_eltwise_unary_sfpu_params<APPROXIMATE>(
+            ckernel::sfpu::calculate_typecast_int32_to_fp16b<APPROXIMATE,8>,
+            dst_index,
+            vector_mode);
+    }
 }

 template <bool APPROXIMATE>
```
The compute kernel API documentation and signature for typecast_tile:

```diff
@@ -24,18 +24,21 @@ namespace ckernel {
  * Float16_b -> UInt32
  * Float16_b -> UInt16
  * UInt16 -> Float16_b
+ * Int32 -> Float16_b
  *
  * For output to be UInt32, Dest must be in 32 bit mode.
  *
  * Return value: None
  *
  * | Argument   | Description                                                                 | Type     | Valid Range                                            | Required |
  * |------------|-----------------------------------------------------------------------------|----------|--------------------------------------------------------|----------|
  * | tile_index | The index of the tile in DST register buffer to perform typecast operation | uint32_t | Must be less than the size of the DST register buffer | True     |
+ * | IN_DTYPE   | Input data format                                                           | uint32_t | Must be valid tt::DataFormat                           | True     |
+ * | OUT_DTYPE  | Desired output data format                                                  | uint32_t | Must be valid tt::DataFormat                           | True     |
  */
-template <uint32_t OUT_DTYPE>
+template <uint32_t IN_DTYPE, uint32_t OUT_DTYPE>
 ALWI void typecast_tile(uint32_t idst) {
-    MATH(( llk_math_eltwise_unary_sfpu_typecast<APPROX, OUT_DTYPE>(idst) ));
+    MATH(( llk_math_eltwise_unary_sfpu_typecast<APPROX, IN_DTYPE, OUT_DTYPE>(idst) ));
 }

 /**
```
ttnn/cpp/ttnn/operations/copy.hpp (3 additions, 2 deletions)

```diff
@@ -46,9 +46,10 @@ struct Typecast {
             TT_FATAL(output_dtype == optional_output_tensor.value().get_dtype(), "If both output dtype and output tensor provided dtype should match");
         }

+        DataType input_dtype = input.get_dtype();
         auto memory_config = memory_config_arg.value_or(input.memory_config());
-        bool fp32_dest_acc_en = output_dtype == DataType::UINT32;
-        auto unary_op = UnaryWithParam{UnaryOpType::TYPECAST, static_cast<float>(output_dtype)};
+        bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_dtype == DataType::INT32;
+        auto unary_op = UnaryWithParam{UnaryOpType::TYPECAST, {static_cast<float>(input_dtype), static_cast<float>(output_dtype)}};
         auto eltwise_op = EltwiseUnary{{unary_op}, memory_config, fp32_dest_acc_en, output_dtype};
         return operation::run(eltwise_op, {input}, {}, {optional_output_tensor}, queue_id).at(0);
     }
```
ttnn/cpp/ttnn/operations/unary.hpp (1 addition, 1 deletion)

```diff
@@ -44,7 +44,7 @@ inline Tensor execute_on_worker_thread(
     const std::vector<tt::tt_metal::UnaryWithParam>& op_chain,
     const std::optional<MemoryConfig>& memory_config = std::nullopt,
     const std::optional<Tensor>& optional_output_tensor = std::nullopt) {
-    DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast<DataType>(op_chain[0].params[0]) : input_tensor.get_dtype();
+    DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast<DataType>(op_chain[0].params[1]) : input_tensor.get_dtype();
     bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or
                             input_tensor.get_dtype() == DataType::UINT32 or
                             input_tensor.get_dtype() == DataType::INT32;  // MT: Currently only uint32/int32 is moved to
```