#8837: Optimize cache hit updating RTAs for some ops
tt-aho committed May 25, 2024
1 parent 2f4a4e2 commit ce0fb7a
Showing 10 changed files with 1,754 additions and 1,492 deletions.
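For context on the diff below: on a program-cache hit, the override callback previously walked every core, pulled that core's kernel id out of reader_kernel_ids / writer_kernel_ids, and called GetRuntimeArgs(program, kernel_id, core) once per core. A minimal sketch of that old shape follows; the names mirror this diff, and the argument update is a placeholder.

    // Old cache-hit path (simplified): one stored kernel id and one
    // GetRuntimeArgs(program, kernel_id, core) lookup per core.
    for (uint32_t i = 0; i < cores.size(); ++i) {
        const CoreCoord& core = cores[i];
        auto& reader_runtime_args = GetRuntimeArgs(program, reader_kernel_ids.at(i), core);
        reader_runtime_args[0] = src_buffer_a->address();  // placeholder update
    }

The commit drops the per-core kernel-id vectors and instead captures each kernel handle together with its own core list, fetching the whole runtime-args table for a kernel once (see the callback changes near the end of this diff).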
@@ -408,11 +408,13 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
.compile_args = in1_sender_writer_compile_time_args,
.defines = mm_kernel_in1_sender_writer_defines});

auto in1_receiver =
(CoreRangeSet)(std::set<CoreRange>){in0_sender_in1_receiver, in0_receiver_in1_receiver_left_half};
auto mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel(
program,
"tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
/* in0_sender_in1_receiver, // If not using half-half noc setup */
(CoreRangeSet)(std::set<CoreRange>){in0_sender_in1_receiver, in0_receiver_in1_receiver_left_half},
in1_receiver,
tt_metal::DataMovementConfig{
.processor = tt_metal::DataMovementProcessor::RISCV_0,
.noc = in1_noc,
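The hunk above names the in1 receiver core-range set once so the same object can be passed to CreateKernel and, later in this file, flattened into a core list. A condensed, hedged sketch of that pattern; kernel_path and config stand in for the full arguments shown above.

    // Build the receiver core-range set once and reuse it.
    auto in1_receiver =
        (CoreRangeSet)(std::set<CoreRange>){in0_sender_in1_receiver, in0_receiver_in1_receiver_left_half};
    auto mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel(
        program, kernel_path, in1_receiver, config);  // kernel_path/config abbreviate the args above
    // Further down: derive the row-major core list used on the cache-hit path.
    const auto& in1_receiver_cores = corerange_to_cores(in1_receiver, std::nullopt, true);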
@@ -639,9 +641,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
uint32_t last_block_padded_block_tiles_h_skip =
(per_core_M / out_subblock_h - last_block_num_nonzero_subblocks_h) * (per_core_N * out_subblock_h);

std::vector<KernelHandle> reader_kernel_ids;
std::vector<KernelHandle> writer_kernel_ids;

uint32_t diff_start_coord;
uint32_t diff_end_coord;
std::vector<uint32_t> in0_mcast_noc_x;
@@ -668,6 +667,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
}

const auto& cores = grid_to_cores(all_cores.start, all_cores.end, true);
const auto& in0_sender_cores = grid_to_cores(in0_sender.start, in0_sender.end, true);
const auto& in1_sender_cores = grid_to_cores(in1_sender.start, in1_sender.end, true);
const auto& in1_receiver_cores = corerange_to_cores(in1_receiver, std::nullopt, true);
const auto& in1_receiver_other_cores =
grid_to_cores(in0_receiver_in1_receiver_right_half.start, in0_receiver_in1_receiver_right_half.end, true);
for (const auto& core : cores) {
CoreCoord left_core = {(std::size_t)start_core_x, (std::size_t)core.y};
CoreCoord left_core_plus_one = {(std::size_t)start_core_x + 1, (std::size_t)core.y};
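The core lists added above are computed once while the program is built and are later captured by value in the override callback, so a cache hit iterates precomputed vectors instead of re-deriving core ranges. A hedged sketch of how such a list pairs with a per-kernel runtime-args table (the two halves live at program-creation time and inside the callback, respectively); the indexing convention matches the callback body near the end of this diff.

    // One flat, row-major core list per kernel (true = row-wise ordering).
    const auto& in0_sender_cores = grid_to_cores(in0_sender.start, in0_sender.end, true);

    // On a cache hit: fetch the kernel's whole runtime-args table once,
    // then index it by core coordinates ([x][y] -> args vector).
    auto& reader_args_by_core = GetRuntimeArgs(program, mm_kernel_in0_sender_id);
    for (const auto& core : in0_sender_cores) {
        reader_args_by_core[core.x][core.y][0] = src_buffer_a->address();
    }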
@@ -733,7 +737,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
mm_in0_sender_args.push_back(worker_shard_same_coord);
}
tt_metal::SetRuntimeArgs(program, mm_kernel_in0_sender_id, core, mm_in0_sender_args); // RISCV_0_default
reader_kernel_ids.push_back(mm_kernel_in0_sender_id);
} else if (in1_idx == 0) {
std::vector<uint32_t> mm_in0_sender_args = {
// in0 tensor args
@@ -753,7 +756,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
}

tt_metal::SetRuntimeArgs(program, mm_kernel_in0_sender_id, core, mm_in0_sender_args); // RISCV_0_default
reader_kernel_ids.push_back(mm_kernel_in0_sender_id);

// in0 receiver
} else {
@@ -765,13 +767,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
// left half
if (core.x <= half_core || (!transpose_mcast and core.y == start_core_y)) {
tt_metal::SetRuntimeArgs(program, mm_kernel_in0_receiver_id, core, mm_in0_receiver_args);
reader_kernel_ids.push_back(mm_kernel_in0_receiver_id);
}
// right half
else {
tt_metal::SetRuntimeArgs(
program, mm_kernel_in0_receiver_other_noc_setup_id, core, mm_in0_receiver_args);
reader_kernel_ids.push_back(mm_kernel_in0_receiver_other_noc_setup_id);
}
}

@@ -826,7 +826,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
}
tt_metal::SetRuntimeArgs(
program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); // RISCV_1_default
writer_kernel_ids.push_back(mm_kernel_in1_sender_writer_id);

// in1 receiver
} else {
@@ -883,20 +882,24 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
// left half
if (core.x <= half_core || (transpose_mcast and core.y == start_core_y)) {
tt_metal::SetRuntimeArgs(program, mm_kernel_in1_receiver_writer_id, core, mm_in1_receiver_writer_args);
writer_kernel_ids.push_back(mm_kernel_in1_receiver_writer_id);
}
// right half
else {
tt_metal::SetRuntimeArgs(
program, mm_kernel_in1_receiver_writer_other_noc_setup_id, core, mm_in1_receiver_writer_args);
writer_kernel_ids.push_back(mm_kernel_in1_receiver_writer_other_noc_setup_id);
}
}
}

auto override_runtime_arguments_callback =
[reader_kernel_ids,
writer_kernel_ids,
[mm_kernel_in0_sender_id,
in0_sender_cores,
mm_kernel_in1_sender_writer_id,
in1_sender_cores,
mm_kernel_in1_receiver_writer_id,
in1_receiver_cores,
mm_kernel_in1_receiver_writer_other_noc_setup_id,
in1_receiver_other_cores,
cb_src2,
cb_output,
num_cores_r,
@@ -910,56 +913,58 @@ operation::ProgramWithCallbacks create_program_mcast_in0_in1(
const std::vector<Tensor>& input_tensors,
const std::vector<std::optional<const Tensor>>& optional_input_tensors,
const std::vector<Tensor>& output_tensors) {
TT_FATAL(input_tensors.size() + optional_input_tensors.size() == 3);
TT_FATAL(output_tensors.size() == 1);
TT_ASSERT(input_tensors.size() + optional_input_tensors.size() == 3);
TT_ASSERT(output_tensors.size() == 1);

auto src_buffer_a = input_tensors.at(0).buffer();
auto src_buffer_b = input_tensors.at(1).buffer();
auto bias_tensor = optional_input_tensors.at(0);

auto dst_buffer = output_tensors.at(0).buffer();

bool src0_sharded = input_tensors.at(0).memory_config().is_sharded();
bool out_sharded = output_tensors.at(0).memory_config().is_sharded();

for (uint32_t i = 0; i < cores.size(); ++i) {
const CoreCoord& core = cores[i];

auto reader_kernel_id = reader_kernel_ids.at(i);
auto& reader_runtime_args = GetRuntimeArgs(program, reader_kernel_id, core);

auto writer_kernel_id = writer_kernel_ids.at(i);
auto& writer_runtime_args = GetRuntimeArgs(program, writer_kernel_id, core);

uint32_t in0_idx = core.y - start_core_y;
uint32_t in1_idx = core.x - start_core_x;
bool src0_sharded = input_tensors[0].memory_config().is_sharded();
bool out_sharded = output_tensors[0].memory_config().is_sharded();

if (transpose_mcast) {
std::swap(in0_idx, in1_idx);
}
std::optional<Buffer*> bias_buffer;
if (bias_tensor.has_value()) {
bias_buffer = bias_tensor.value().buffer();
}

// in0 sender
if (!src0_sharded && in1_idx == 0) {
// in0 sender
if (src0_sharded) {
UpdateDynamicCircularBufferAddress(program, cb_src2, *src_buffer_a);
} else {
auto& reader_sender_runtime_args_by_core = GetRuntimeArgs(program, mm_kernel_in0_sender_id);
for (const auto& core : in0_sender_cores) {
auto& reader_runtime_args = reader_sender_runtime_args_by_core[core.x][core.y];
reader_runtime_args[0] = src_buffer_a->address();
// in0 receiver
} else {
}
}

// in1 sender
if (in0_idx == 0) {
writer_runtime_args[0] = src_buffer_b->address();
writer_runtime_args[6] = dst_buffer->address();
if (bias_tensor.has_value()) {
writer_runtime_args[16] = bias_tensor.value().buffer()->address();
}
// in1 receiver
} else {
writer_runtime_args[2] = dst_buffer->address();
// in1 sender
auto& sender_writer_runtime_args_by_core = GetRuntimeArgs(program, mm_kernel_in1_sender_writer_id);
for (const auto& core : in1_sender_cores) {
auto& writer_runtime_args = sender_writer_runtime_args_by_core[core.x][core.y];
writer_runtime_args[0] = src_buffer_b->address();
writer_runtime_args[6] = dst_buffer->address();
if (bias_tensor.has_value()) {
writer_runtime_args[16] = (*bias_buffer)->address();
}
}

if (src0_sharded) {
UpdateDynamicCircularBufferAddress(program, cb_src2, *src_buffer_a);
// in1 receiver
auto& receiver_writer_runtime_args_by_core = GetRuntimeArgs(program, mm_kernel_in1_receiver_writer_id);
for (const auto& core : in1_receiver_cores) {
auto& writer_runtime_args = receiver_writer_runtime_args_by_core[core.x][core.y];
writer_runtime_args[2] = dst_buffer->address();
}
if (mm_kernel_in1_receiver_writer_id != mm_kernel_in1_receiver_writer_other_noc_setup_id) {
auto& receiver_writer_runtime_args_by_core =
GetRuntimeArgs(program, mm_kernel_in1_receiver_writer_other_noc_setup_id);
for (const auto& core : in1_receiver_other_cores) {
auto& writer_runtime_args = receiver_writer_runtime_args_by_core[core.x][core.y];
writer_runtime_args[2] = dst_buffer->address();
}
}

if (out_sharded) {