Skip to content

Commit

Permalink
#11756: rename PreserveFP32Target to UnpackToDestMode for clarity
Browse files Browse the repository at this point in the history
  • Loading branch information
rdjogoTT committed Sep 23, 2024
1 parent cbca484 commit 8a07246
Show file tree
Hide file tree
Showing 22 changed files with 85 additions and 82 deletions.
14 changes: 10 additions & 4 deletions tt_metal/common/base_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@ struct std::hash<MathFidelity>
}
};

enum class PreserveFP32Target : uint8_t
/**
* Specifies mode of operation for unpacking directly to Dest register.
* Default mode enables all dataformats (except Float32) to be unpacked into Dest. Buffers
* with Default mode can be used to unpack to SRCA/B or Dest.
* UnpackToDestFp32 enables unpacking Float32 data to Dest with full precision, but makes
* the buffer incompatible with unpacking to SRCA/B.
*/
enum class UnpackToDestMode : uint8_t
{
SRCA_B = 0,
DEST = 1,
Disabled = 0xff,
UnpackToDestFp32,
Default
};
2 changes: 1 addition & 1 deletion tt_metal/impl/kernels/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ void ComputeKernel::set_build_options(JitBuildOptions &build_options) const {
build_options.set_hlk_math_fidelity_all_cores(this->config_.math_fidelity);
build_options.set_hlk_math_approx_mode_all_cores(this->config_.math_approx_mode);
build_options.fp32_dest_acc_en = this->config_.fp32_dest_acc_en;
build_options.preserve_fp32_precision = this->config_.preserve_fp32_precision;
build_options.unpack_to_dest_mode = this->config_.unpack_to_dest_mode;
build_options.hlk_defines = this->defines_;
}

Expand Down
2 changes: 1 addition & 1 deletion tt_metal/impl/kernels/kernel_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ struct WriterDataMovementConfig : public DataMovementConfig {
struct ComputeConfig {
MathFidelity math_fidelity = MathFidelity::HiFi4;
bool fp32_dest_acc_en = false;
std::vector<PreserveFP32Target> preserve_fp32_precision;
std::vector<UnpackToDestMode> unpack_to_dest_mode;
bool math_approx_mode = false;
std::vector<uint32_t> compile_args;
// Will cause CompileProgram to emit a file hlk_defines_generated.h
Expand Down
15 changes: 6 additions & 9 deletions tt_metal/jit_build/data_format.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,11 @@ std::vector<DataFormat> get_unpack_dst_formats(
DataFormat output_formats[NUM_OPERANDS],
DataFormat unpack_conditional_dst_format,
bool fp32_dest_acc_en,
std::vector<PreserveFP32Target> preserve_fp32_precision,
std::vector<UnpackToDestMode> unpack_to_dest_mode,
bool int_fpu_en)
{
if (!preserve_fp32_precision.empty()) {
TT_FATAL(preserve_fp32_precision.size() == NUM_CIRCULAR_BUFFERS, "preserve_fp32_precision vector must have 32 elements");
if (!unpack_to_dest_mode.empty()) {
TT_FATAL(unpack_to_dest_mode.size() == NUM_CIRCULAR_BUFFERS, "unpack_to_dest_mode vector must have 32 elements");
}

DataFormat pack_format = get_pack_data_format(output_formats, intermed_formats);
Expand All @@ -264,25 +264,22 @@ std::vector<DataFormat> get_unpack_dst_formats(
} else if (int_fpu_en) {
unpack_dst_format.push_back(src_format);
} else {
if (input_formats[i] == DataFormat::Float32 && !preserve_fp32_precision.empty() && preserve_fp32_precision[i] != PreserveFP32Target::Disabled) {
TT_FATAL(preserve_fp32_precision[i] == PreserveFP32Target::DEST, "preserve_fp32_precision is only available when unpack target is DEST register");
if (input_formats[i] == DataFormat::Float32 && !unpack_to_dest_mode.empty() && unpack_to_dest_mode[i] != UnpackToDestMode::Default) {
unpack_dst_format.push_back(get_single_unpack_dst_format(input_formats[i], pack_format, DataFormat::Float32));
} else {
unpack_dst_format.push_back(get_single_unpack_dst_format(input_formats[i], pack_format, unpack_cond_dst_format));
}
}
}
for (int i=0 ; i<NUM_OPERANDS ; i++) {
if (param_formats[i] == DataFormat::Float32 && !preserve_fp32_precision.empty() && preserve_fp32_precision[NUM_OPERANDS+i] != PreserveFP32Target::Disabled) {
TT_FATAL(preserve_fp32_precision[NUM_OPERANDS+i] == PreserveFP32Target::DEST, "preserve_fp32_precision is only available when unpack target is DEST register");
if (param_formats[i] == DataFormat::Float32 && !unpack_to_dest_mode.empty() && unpack_to_dest_mode[NUM_OPERANDS+i] != UnpackToDestMode::Default) {
unpack_dst_format.push_back(get_single_unpack_dst_format(param_formats[i], pack_format, DataFormat::Float32));
} else {
unpack_dst_format.push_back(get_single_unpack_dst_format(param_formats[i], pack_format, unpack_cond_dst_format));
}
}
for (int i=0 ; i<NUM_OPERANDS ; i++) {
if (intermed_formats[i] == DataFormat::Float32 && !preserve_fp32_precision.empty() && preserve_fp32_precision[3*NUM_OPERANDS+i] != PreserveFP32Target::Disabled) {
TT_FATAL(preserve_fp32_precision[3*NUM_OPERANDS+i] == PreserveFP32Target::DEST, "preserve_fp32_precision is only available when unpack target is DEST register");
if (intermed_formats[i] == DataFormat::Float32 && !unpack_to_dest_mode.empty() && unpack_to_dest_mode[3*NUM_OPERANDS+i] != UnpackToDestMode::Default) {
unpack_dst_format.push_back(get_single_unpack_dst_format(intermed_formats[i], pack_format, DataFormat::Float32));
} else {
unpack_dst_format.push_back(get_single_unpack_dst_format(intermed_formats[i], pack_format, unpack_cond_dst_format));
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/jit_build/data_format.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void check_valid_in_out_data_formats(DataFormat input_formats[NUM_OPERANDS], Dat
const DataFormat get_single_pack_src_format(DataFormat input_format, DataFormat output_format, DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, tt::ARCH arch);

std::vector<DataFormat> get_unpack_src_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS]);
std::vector<DataFormat> get_unpack_dst_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS], DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, std::vector<PreserveFP32Target> preserve_fp32_precision, bool int_fpu_en = false);
std::vector<DataFormat> get_unpack_dst_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS], DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, std::vector<UnpackToDestMode> unpack_to_dest_mode, bool int_fpu_en = false);
std::vector<DataFormat> get_pack_src_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS], DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, bool int_fpu_en = false, tt::ARCH arch = tt::ARCH::GRAYSKULL);
std::vector<DataFormat> get_pack_dst_formats(DataFormat input_formats[NUM_OPERANDS], DataFormat param_formats[NUM_OPERANDS], DataFormat intermed_formats[NUM_OPERANDS], DataFormat output_formats[NUM_OPERANDS]);

Expand Down
6 changes: 3 additions & 3 deletions tt_metal/jit_build/genfiles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,14 @@ static std::string create_formats_array_string(std::string array_type, std::stri
}

static std::pair<std::vector<DataFormat>, std::vector<DataFormat>>
generate_unpack_data_formats(tt_hlk_desc& desc, DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, std::vector<PreserveFP32Target> preserve_fp32_precision) {
generate_unpack_data_formats(tt_hlk_desc& desc, DataFormat unpack_conditional_dst_format, bool fp32_dest_acc_en, std::vector<UnpackToDestMode> unpack_to_dest_mode) {

vector<DataFormat> src_formats = tt::get_unpack_src_formats(
desc.input_buf_dataformat_arr, desc.param_buf_dataformat_arr, desc.intermediate_buf_dataformat_arr);

vector<DataFormat> dst_formats = tt::get_unpack_dst_formats(
desc.input_buf_dataformat_arr, desc.param_buf_dataformat_arr, desc.intermediate_buf_dataformat_arr,
desc.output_buf_dataformat_arr, unpack_conditional_dst_format, fp32_dest_acc_en, preserve_fp32_precision);
desc.output_buf_dataformat_arr, unpack_conditional_dst_format, fp32_dest_acc_en, unpack_to_dest_mode);

TT_ASSERT(src_formats.size() == 24 && dst_formats.size() == 24,
"There must be 8 unpack src/dst formats for each input, param, and intermediate operands.");
Expand Down Expand Up @@ -310,7 +310,7 @@ static void generate_data_format_descriptors(JitBuildOptions& options, const tt:
desc.intermediate_buf_dataformat_arr);

vector<DataFormat> unpack_src_formats_all_cbs, unpack_dst_formats_all_cbs;
tie(unpack_src_formats_all_cbs, unpack_dst_formats_all_cbs) = generate_unpack_data_formats(desc, unpack_conditional_dst_format, options.fp32_dest_acc_en, options.preserve_fp32_precision);
tie(unpack_src_formats_all_cbs, unpack_dst_formats_all_cbs) = generate_unpack_data_formats(desc, unpack_conditional_dst_format, options.fp32_dest_acc_en, options.unpack_to_dest_mode);

vector<DataFormat> pack_src_formats_all_cbs, pack_dst_formats_all_cbs;
tie(pack_src_formats_all_cbs, pack_dst_formats_all_cbs) = generate_pack_data_formats(desc, unpack_conditional_dst_format, options.fp32_dest_acc_en, arch);
Expand Down
2 changes: 1 addition & 1 deletion tt_metal/jit_build/settings.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class JitBuildOptions {

// We can keep for future WH support, otherwise not used in GS
bool fp32_dest_acc_en;
std::vector<PreserveFP32Target> preserve_fp32_precision;
std::vector<UnpackToDestMode> unpack_to_dest_mode;

// BRISC config
std::string brisc_kernel_file_name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_
MathFidelity math_fidelity,
bool fp32_dest_acc_en,
bool math_approx_mode,
vector<PreserveFP32Target> preserve_fp32_precision) {
vector<UnpackToDestMode> unpack_to_dest_mode) {
std::vector<KernelHandle> compute_kernel_ids{};
KernelHandle compute_kernel_id{};
for (auto arg : args) {
Expand All @@ -141,7 +141,7 @@ std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_
math_fidelity,
fp32_dest_acc_en,
math_approx_mode,
preserve_fp32_precision);
unpack_to_dest_mode);
compute_kernel_ids.push_back(compute_kernel_id);
}
return compute_kernel_ids;
Expand All @@ -155,7 +155,7 @@ std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_
MathFidelity math_fidelity,
bool fp32_dest_acc_en,
bool math_approx_mode,
vector<PreserveFP32Target> preserve_fp32_precision) {
vector<UnpackToDestMode> unpack_to_dest_mode) {
KernelHandle compute_kernel_id{0};
if (arg.num_tile_per_core_group > 0) {
compute_kernel_id = CreateKernel(
Expand All @@ -165,7 +165,7 @@ std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_
tt_metal::ComputeConfig{
.math_fidelity = math_fidelity,
.fp32_dest_acc_en = fp32_dest_acc_en,
.preserve_fp32_precision = preserve_fp32_precision,
.unpack_to_dest_mode = unpack_to_dest_mode,
.math_approx_mode = math_approx_mode,
.compile_args = arg.compile_args,
.defines = defines});
Expand Down Expand Up @@ -195,7 +195,7 @@ std::tuple<uint32_t, CoreRangeSet, CoreRangeSet, CoreRangeSet, uint32_t, uint32_
tt_metal::ComputeConfig{
.math_fidelity = config.math_fidelity,
.fp32_dest_acc_en = config.fp32_dest_acc_en,
.preserve_fp32_precision = config.preserve_fp32_precision,
.unpack_to_dest_mode = config.unpack_to_dest_mode,
.math_approx_mode = config.math_approx_mode,
.compile_args = arg.compile_args,
.defines = config.defines});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ struct ComputeKernelArg {
struct ComputeKernelConfig {
MathFidelity math_fidelity = MathFidelity::HiFi4;
bool fp32_dest_acc_en = false;
vector<PreserveFP32Target> preserve_fp32_precision;
vector<UnpackToDestMode> unpack_to_dest_mode;
bool math_approx_mode = false;
std::map<std::string, std::string> defines;
};
Expand All @@ -99,7 +99,7 @@ struct ComputeKernelConfig {
MathFidelity math_fidelity = MathFidelity::HiFi4,
bool fp32_dest_acc_en = false,
bool math_approx_mode = false,
vector<PreserveFP32Target> preserve_fp32_precision = {});
vector<UnpackToDestMode> unpack_to_dest_mode = {});

[[maybe_unused]] KernelHandle CreateComputeKernel(
Program &program,
Expand All @@ -109,7 +109,7 @@ struct ComputeKernelConfig {
MathFidelity math_fidelity = MathFidelity::HiFi4,
bool fp32_dest_acc_en = false,
bool math_approx_mode = false,
vector<PreserveFP32Target> preserve_fp32_precision = {});
vector<UnpackToDestMode> unpack_to_dest_mode = {});

[[maybe_unused]] std::vector<KernelHandle> CreateComputeKernel(
Program &program, const std::string &file_name, std::vector<ComputeKernelArg> args, ComputeKernelConfig config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,14 @@ operation::ProgramWithCallbacks moreh_sum_h_impl(const Tensor &a, const Tensor &
origin_H
};

// set preserve_fp32_precision to the same value as fp32_dest_acc_en
// bool preserve_fp32_precision = fp32_dest_acc_en;
vector<PreserveFP32Target> preserve_fp32_precision(NUM_CIRCULAR_BUFFERS, PreserveFP32Target::Disabled);
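// NOTE: unpack_to_dest_mode is not derived from fp32_dest_acc_en here (the commented-out line below is stale); every circular buffer is set to UnpackToDestMode::Default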
// bool unpack_to_dest_mode = fp32_dest_acc_en;
vector<UnpackToDestMode> unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default);
auto reduce_compute_kernel_group_1_id = tt_metal::CreateKernel(
program,
compute_kernel_name,
core_group_1,
tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .preserve_fp32_precision = preserve_fp32_precision, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_1, .defines = reduce_defines});
tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .unpack_to_dest_mode = unpack_to_dest_mode, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_1, .defines = reduce_defines});

if (!core_group_2.ranges().empty()) {
vector<uint32_t> compute_kernel_args_group_2 = {
Expand All @@ -175,7 +175,7 @@ operation::ProgramWithCallbacks moreh_sum_h_impl(const Tensor &a, const Tensor &
program,
compute_kernel_name,
core_group_2,
tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .preserve_fp32_precision = preserve_fp32_precision, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_2, .defines = reduce_defines});
tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .unpack_to_dest_mode = unpack_to_dest_mode, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_2, .defines = reduce_defines});
}

for (uint32_t i = 0, num_cols_read = 0; i < num_cores; i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ operation::ProgramWithCallbacks moreh_sum_nc_impl(const Tensor &input, const Ten
if (fp32_dest_acc_en) {
compute_defines["FP32_DEST_ACC_EN"] = "1";
}
// set preserve_fp32_precision to the same value as fp32_dest_acc_en
// bool preserve_fp32_precision = fp32_dest_acc_en;
vector<PreserveFP32Target> preserve_fp32_precision(NUM_CIRCULAR_BUFFERS, PreserveFP32Target::Disabled);
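// NOTE: unpack_to_dest_mode is not derived from fp32_dest_acc_en here (the commented-out line below is stale); every circular buffer is set to UnpackToDestMode::Default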
// bool unpack_to_dest_mode = fp32_dest_acc_en;
vector<UnpackToDestMode> unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default);
auto compute_kernel_file = "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/kernels/moreh_sum_nc.cpp";
if (device->arch() == tt::ARCH::GRAYSKULL) {
compute_kernel_file = "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/kernels/moreh_sum_nc_gs.cpp";
Expand All @@ -111,7 +111,7 @@ operation::ProgramWithCallbacks moreh_sum_nc_impl(const Tensor &input, const Ten
math_fidelity,
fp32_dest_acc_en,
math_approx_mode,
preserve_fp32_precision);
unpack_to_dest_mode);

std::optional<KernelHandle> compute_kernel_2_id = std::nullopt;
if (!core_group_2.ranges().empty()) {
Expand All @@ -124,7 +124,7 @@ operation::ProgramWithCallbacks moreh_sum_nc_impl(const Tensor &input, const Ten
math_fidelity,
fp32_dest_acc_en,
math_approx_mode,
preserve_fp32_precision);
unpack_to_dest_mode);
}

////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,14 +147,14 @@ operation::ProgramWithCallbacks moreh_sum_w_impl(const Tensor &a, const Tensor &
origin_W,
};

// set preserve_fp32_precision to the same value as fp32_dest_acc_en
// bool preserve_fp32_precision = fp32_dest_acc_en;
vector<PreserveFP32Target> preserve_fp32_precision(NUM_CIRCULAR_BUFFERS, PreserveFP32Target::Disabled);
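// NOTE: unpack_to_dest_mode is not derived from fp32_dest_acc_en here (the commented-out line below is stale); every circular buffer is set to UnpackToDestMode::Default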
// bool unpack_to_dest_mode = fp32_dest_acc_en;
vector<UnpackToDestMode> unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default);
auto reduce_compute_kernel_group_1_id = tt_metal::CreateKernel(
program,
compute_kernel_name,
core_group_1,
tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .preserve_fp32_precision = preserve_fp32_precision, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_1, .defines = reduce_defines});
tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .unpack_to_dest_mode = unpack_to_dest_mode, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_1, .defines = reduce_defines});

if (!core_group_2.ranges().empty()) {
vector<uint32_t> compute_kernel_args_group_2 = {
Expand All @@ -168,7 +168,7 @@ operation::ProgramWithCallbacks moreh_sum_w_impl(const Tensor &a, const Tensor &
program,
compute_kernel_name,
core_group_2,
tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .preserve_fp32_precision = preserve_fp32_precision, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_2, .defines = reduce_defines});
tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, .unpack_to_dest_mode = unpack_to_dest_mode, .math_approx_mode = math_approx_mode, .compile_args = compute_kernel_args_group_2, .defines = reduce_defines});
}

uint32_t out_dim_divider = Wt;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ UnaryProgramFactory::cached_program_t UnaryProgramFactory::create(
1 // per_core_block_size
};

vector<PreserveFP32Target> preserve_fp32_precision(NUM_CIRCULAR_BUFFERS, PreserveFP32Target::Disabled);
vector<UnpackToDestMode> unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default);
if (args.preserve_fp32_precision) {
preserve_fp32_precision[src0_cb_index] = PreserveFP32Target::DEST;
unpack_to_dest_mode[src0_cb_index] = UnpackToDestMode::UnpackToDestFp32;
}

bool math_approx_mode = std::all_of(
Expand All @@ -100,7 +100,7 @@ UnaryProgramFactory::cached_program_t UnaryProgramFactory::create(
tt::tt_metal::ComputeConfig{
.math_fidelity = MathFidelity::HiFi4,
.fp32_dest_acc_en = args.fp32_dest_acc_en,
.preserve_fp32_precision = preserve_fp32_precision,
.unpack_to_dest_mode = unpack_to_dest_mode,
.math_approx_mode = math_approx_mode,
.compile_args = compute_kernel_args_group_1,
.defines = unary_defines});
Expand All @@ -118,7 +118,7 @@ UnaryProgramFactory::cached_program_t UnaryProgramFactory::create(
tt::tt_metal::ComputeConfig{
.math_fidelity = MathFidelity::HiFi4,
.fp32_dest_acc_en = args.fp32_dest_acc_en,
.preserve_fp32_precision = preserve_fp32_precision,
.unpack_to_dest_mode = unpack_to_dest_mode,
.math_approx_mode = math_approx_mode,
.compile_args = compute_kernel_args_group_2,
.defines = unary_defines});
Expand Down
Loading

0 comments on commit 8a07246

Please sign in to comment.