#7159: Fix softmax sharded cache hit to update sharded mask address
tt-aho committed Jun 6, 2024
1 parent cc5660c commit dbfa65a
Showing 1 changed file with 17 additions and 13 deletions.
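In brief: on a program-cache hit the op does not rebuild the program; only the override-runtime-arguments callback runs. Before this change the callback refreshed the mask address in the reader kernel's runtime args (runtime_args[2]) but not the globally allocated circular buffer (CB::c_in3) that a sharded mask is bound to, so a cache hit could leave c_in3 pointing at a stale buffer address. The fix captures cb_in3_id in the callback and calls UpdateDynamicCircularBufferAddress when the mask is sharded, mirroring what is already done for cb_in0 and cb_out0. Below is a minimal standalone sketch of that caching pattern; it is plain C++ with illustrative stand-ins (Buffer, CbHandle, update_cb_address, build_program), not tt-metal's actual types or APIs.

#include <cstdint>
#include <functional>
#include <iostream>
#include <optional>

// Illustrative stand-ins for device objects (not tt-metal types).
struct Buffer { uint32_t address; };
struct CbHandle { uint32_t bound_address = 0; };

// Stand-in for UpdateDynamicCircularBufferAddress: re-binds a dynamic
// circular buffer to the current address of a tensor's buffer.
void update_cb_address(CbHandle& cb, const Buffer& buf) { cb.bound_address = buf.address; }

struct Program {
    CbHandle cb_in0, cb_out0;
    std::optional<CbHandle> cb_in3;  // only created when the mask is sharded
    // Runs on every cache hit with the buffers of the new invocation.
    std::function<void(Program&, const Buffer&, const Buffer&, const std::optional<Buffer>&)>
        override_runtime_args;
};

Program build_program(bool mask_is_sharded) {
    Program p;
    if (mask_is_sharded) p.cb_in3 = CbHandle{};
    // The fix: the callback also re-binds the mask CB, not just in0/out0.
    p.override_runtime_args = [](Program& prog, const Buffer& in0, const Buffer& out0,
                                 const std::optional<Buffer>& mask) {
        update_cb_address(prog.cb_in0, in0);
        update_cb_address(prog.cb_out0, out0);
        if (mask.has_value() && prog.cb_in3.has_value()) {
            update_cb_address(prog.cb_in3.value(), *mask);  // previously missing
        }
    };
    return p;
}

int main() {
    Program p = build_program(/*mask_is_sharded=*/true);
    Buffer in0{0x1000}, out0{0x2000}, mask{0x3000};
    p.override_runtime_args(p, in0, out0, mask);   // first run
    Buffer mask2{0x4000};                          // mask reallocated elsewhere
    p.override_runtime_args(p, in0, out0, mask2);  // cache hit: CB follows the new address
    std::cout << std::hex << p.cb_in3->bound_address << "\n";  // prints 4000
}

With the change, a cache hit re-binds the mask's circular buffer to wherever the mask tensor is currently allocated, so a kernel reading through c_in3 sees the new data rather than the address captured when the program was first built.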
@@ -455,7 +455,7 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
 
     tt::DataFormat out0_cb_data_format = tt_metal::datatype_to_dataformat_converter(output_tensor.get_dtype());
     tt::DataFormat im_cb_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : tt::DataFormat::Float16_b;
-    tt::DataFormat mask_cb_data_format = mask.has_value() ? tt_metal::datatype_to_dataformat_converter(mask.value().get_dtype()) : tt::DataFormat::Float16_b;
+    tt::DataFormat mask_cb_data_format = mask.has_value() ? tt_metal::datatype_to_dataformat_converter(mask->get_dtype()) : tt::DataFormat::Float16_b;
     tt::DataFormat scale_cb_data_format = tt::DataFormat::Float16_b;
     tt::DataFormat scalar_cb_data_format = tt::DataFormat::Float16_b;
 
@@ -480,7 +480,7 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
 
     uint32_t mask_H = shape[2];
     if (mask.has_value()) {
-        mask_H = mask.value().get_legacy_shape()[2];
+        mask_H = mask->get_legacy_shape()[2];
     }
     uint32_t mask_Ht = mask_H/TILE_HEIGHT;
     // block
@@ -552,17 +552,17 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
     // reader compile arg
     bool is_dram_mask = 0;
     if (mask.has_value()) {
-        is_dram_mask = mask.value().buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
+        is_dram_mask = mask->buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
     }
     std::vector<uint32_t> reader_compile_time_args = {
         (std::uint32_t) block_wt,
         (std::uint32_t) is_dram_mask
     };
     std::map<string, string> softmax_defines;
     // hw_dims_only_causal_mask does not support RM Layout atm
-    bool use_row_major_kernel = (mask.has_value() and mask.value().get_layout() == Layout::ROW_MAJOR);
+    bool use_row_major_kernel = (mask.has_value() and mask->get_layout() == Layout::ROW_MAJOR);
     if (use_row_major_kernel) {
-        auto mask_stick_size = mask.value().get_legacy_shape()[3] * mask.value().element_size();
+        auto mask_stick_size = mask->get_legacy_shape()[3] * mask->element_size();
         bool mask_stick_size_is_power_of_two = is_power_of_two_at_least_32(mask_stick_size);
         reader_compile_time_args.push_back((std::uint32_t) mask_stick_size_is_power_of_two);
         if (mask_stick_size_is_power_of_two) {
@@ -648,8 +648,8 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
             .set_page_size(CB::c_in2, scale_tile_size);
         cb_in2_id = CreateCircularBuffer(program, all_device_cores, c_in2_config);
         // in3 attn mask
-        if (mask.value().is_sharded()) {
-            auto mask_buffer = mask.value().buffer();
+        if (mask->is_sharded()) {
+            auto mask_buffer = mask->buffer();
             auto c_in3_config = CircularBufferConfig(in3_CB_size, {{CB::c_in3, mask_cb_data_format}})
                 .set_page_size(CB::c_in3, mask_tile_size).set_globally_allocated_address(*mask_buffer);
             cb_in3_id = CreateCircularBuffer( program, all_device_cores, c_in3_config);
@@ -673,7 +673,7 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
     auto cb_intermed1_id = CreateCircularBuffer( program, all_device_cores, c_intermed1_config );
 
     // Runtime Args
-    uint32_t mask_addr = mask.has_value() ? mask.value().buffer()->address() : 0;
+    uint32_t mask_addr = mask.has_value() ? mask->buffer()->address() : 0;
     union { float f; uint32_t u; } s; s.f = scale.value_or(1.0f); // scale for fused scale-mask-softmax
     uint32_t mask_start_tile_id = 0;
 
@@ -712,9 +712,9 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
                 num_cores_per_batch_index = 0;
                 if (mask.has_value()) {
                     if (causal_mask) {
-                        mask_start_tile_id += mask.value().get_legacy_shape()[-1] * mask.value().get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT;
+                        mask_start_tile_id += mask->get_legacy_shape()[-1] * mask->get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT;
                     } else {
-                        mask_start_tile_id += use_row_major_kernel ? mask.value().get_legacy_shape()[-2] : mask.value().get_legacy_shape()[-1] / TILE_WIDTH;
+                        mask_start_tile_id += use_row_major_kernel ? mask->get_legacy_shape()[-2] : mask->get_legacy_shape()[-1] / TILE_WIDTH;
                     }
                 }
             }
@@ -748,9 +748,9 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
                 num_cores_per_batch_index = 0;
                 if (mask.has_value()) {
                     if (causal_mask) {
-                        mask_start_tile_id += mask.value().get_legacy_shape()[-1] * mask.value().get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT;
+                        mask_start_tile_id += mask->get_legacy_shape()[-1] * mask->get_legacy_shape()[-2] / TILE_WIDTH / TILE_HEIGHT;
                     } else {
-                        mask_start_tile_id += use_row_major_kernel ? mask.value().get_legacy_shape()[-2] : mask.value().get_legacy_shape()[-1] / TILE_WIDTH;
+                        mask_start_tile_id += use_row_major_kernel ? mask->get_legacy_shape()[-2] : mask->get_legacy_shape()[-1] / TILE_WIDTH;
                     }
                 }
             }
@@ -763,6 +763,7 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
             reader_kernels_id,
             cb_in0_id,
             cb_out0_id,
+            cb_in3_id,
             num_cores,
             grid_size
         ]
@@ -779,12 +779,15 @@ operation::ProgramWithCallbacks scale_mask_softmax_sharded_multi_core(
 
         UpdateDynamicCircularBufferAddress(program, cb_in0_id, *in0_buffer);
         UpdateDynamicCircularBufferAddress(program, cb_out0_id, *out_buffer);
+        if (mask_tensor.has_value() && mask_tensor->is_sharded()) {
+            UpdateDynamicCircularBufferAddress(program, cb_in3_id.value(), *mask_tensor->buffer());
+        }
 
         if (mask_tensor.has_value()) {
             for (uint32_t i = 0; i < num_cores; ++i) {
                 CoreCoord core = {i % grid_size.x, i / grid_size.x};
                 auto &runtime_args = GetRuntimeArgs(program, reader_kernels_id, core);
-                runtime_args[2] = mask_tensor.value().buffer()->address();
+                runtime_args[2] = mask_tensor->buffer()->address();
             }
         }
     };