#0: Properly support trivial single core case for 1D matmuls
- For mcast in1, only mcast to cores with work (similar to mcast in0)
- For single core, skip receiver kernel setup
- TODO: For sharded in0, K must be divisible by in0_block_w due to separate bugs:
  * For mcast in0, the sharded reader doesn't support turning off mcast if it's single core
    ** The bug is that we still do a regular mcast, which mcasts to 0 cores...
  * For mcast in1, the sharded reader doesn't support slicing along shard width
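
In essence, every mcast setup path is now gated on the receiver grid actually containing more than one core. Below is a minimal, self-contained C++ sketch of that guard, not the actual tt_metal code: `McastPlan` and `plan_in1_mcast` are hypothetical stand-ins for the `create_program_mcast_in1` logic in this diff, and kernel creation is reduced to a flag.

// Hedged sketch: single-core guard for in1 mcast setup (hypothetical names;
// the real logic lives in create_program_mcast_in1 and uses tt_metal APIs).
#include <cstdint>
#include <initializer_list>
#include <iostream>
#include <map>
#include <string>

struct McastPlan {
    bool create_receiver_kernel = false;               // skipped entirely on a 1x1 grid
    std::map<std::string, std::string> sender_defines; // compile-time defines for the sender kernel
};

// num_receiver_cores: size of the mcast receiver bounding box (includes the sender core).
McastPlan plan_in1_mcast(uint32_t num_receiver_cores) {
    McastPlan plan;
    if (num_receiver_cores == 1) {
        // Trivial single-core case: mcasting would target 0 other cores,
        // so compile the sender with mcast turned off instead.
        plan.sender_defines["SKIP_MCAST"] = "1";
    } else {
        // Normal case: receivers are every core in the grid except the sender.
        plan.create_receiver_kernel = true;
    }
    return plan;
}

int main() {
    for (uint32_t cores : {1u, 4u}) {
        McastPlan plan = plan_in1_mcast(cores);
        std::cout << cores << " core(s): SKIP_MCAST="
                  << (plan.sender_defines.count("SKIP_MCAST") ? "1" : "0")
                  << ", receiver_kernel=" << (plan.create_receiver_kernel ? "yes" : "no") << "\n";
    }
    return 0;
}

This mirrors the in1 changes in the diff below; the in0 path applies the same guard via `in0_mcast_receivers` and only creates the receiver kernel when `num_cores() > 0`.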
TT-BrianLiu committed May 28, 2024
1 parent 7ab17bd commit bc19f9c
Showing 2 changed files with 161 additions and 98 deletions.
@@ -194,6 +194,64 @@ class TensorMemoryConfigs(enum.Enum):
         ttnn.L1_MEMORY_CONFIG,
         ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
     ),
+    # Matmul 1D mcast in0 (single core)
+    (
+        (1,),
+        (64, 64, 128),
+        False,
+        ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig(
+            compute_with_storage_grid_size=(1, 1),
+            in0_block_w=2,
+            out_subblock_h=1,
+            out_subblock_w=1,
+            per_core_M=2,
+            per_core_N=4,
+            fuse_batch=True,
+            fused_activation=None,
+            mcast_in0=True,
+        ),
+        ttnn.MemoryConfig(
+            memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+            buffer_type=ttnn.BufferType.L1,
+            shard_spec=ttnn.ShardSpec(
+                ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0))}),
+                (64, 64),
+                ttnn.ShardOrientation.ROW_MAJOR,
+                False,
+            ),
+        ),
+        ttnn.L1_MEMORY_CONFIG,
+        ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG,
+    ),
+    # Matmul 1D mcast in1 (single core)
+    (
+        (1,),
+        (64, 64, 128),
+        False,
+        ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCast1DProgramConfig(
+            compute_with_storage_grid_size=(1, 1),
+            in0_block_w=2,
+            out_subblock_h=1,
+            out_subblock_w=1,
+            per_core_M=2,
+            per_core_N=4,
+            fuse_batch=True,
+            fused_activation=None,
+            mcast_in0=False,
+        ),
+        ttnn.MemoryConfig(
+            memory_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
+            buffer_type=ttnn.BufferType.L1,
+            shard_spec=ttnn.ShardSpec(
+                ttnn.CoreRangeSet({ttnn.CoreRange(ttnn.CoreCoord(0, 0), ttnn.CoreCoord(0, 0))}),
+                (64, 64),
+                ttnn.ShardOrientation.ROW_MAJOR,
+                False,
+            ),
+        ),
+        ttnn.L1_MEMORY_CONFIG,
+        ttnn.L1_HEIGHT_SHARDED_MEMORY_CONFIG,
+    ),
     # Matmul 2D mcast
     (
         (1,),
@@ -135,7 +135,7 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     CoreRangeSet in0_mcast_cores_with_work_and_in_receiver_grid({});
     CoreRangeSet in0_mcast_cores_without_work_and_in_receiver_grid({});
     CoreRangeSet in0_mcast_cores_without_work_and_not_in_receiver_grid({});
-    CoreRangeSet mcast_receivers({});
+    CoreRangeSet in0_mcast_receivers({});
     if (in0_is_sharded) {
         in0_mcast_cores_with_work_and_in_receiver_grid = all_cores_with_work;
 
@@ -166,11 +166,13 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
         }
     } else {
         in0_mcast_cores_with_work_and_in_receiver_grid = CoreRangeSet({CoreRange(start_core, start_core)});
-        auto receiver_start_core = start_core.x != (compute_with_storage_grid_size.x - 1)
-                                       ? CoreCoord{start_core.x + 1, start_core.y}
-                                       : CoreCoord{start_core.x, start_core.y + 1};
-        mcast_receivers =
-            num_cores_to_corerange_set(receiver_start_core, num_cores - 1, compute_with_storage_grid_size, row_major);
+        if (in0_mcast_receiver_num_cores > 1) {
+            auto receiver_start_core = start_core.x != (compute_with_storage_grid_size.x - 1)
+                                           ? CoreCoord{start_core.x + 1, start_core.y}
+                                           : CoreCoord{start_core.x, start_core.y + 1};
+            in0_mcast_receivers = num_cores_to_corerange_set(
+                receiver_start_core, num_cores - 1, compute_with_storage_grid_size, row_major);
+        }
     }
 
     // Mcast args
@@ -379,11 +381,11 @@ operation::ProgramWithCallbacks create_program_mcast_in0(
     }
 
     KernelHandle mm_kernel_in0_receiver_id = 0;
-    if (!in0_is_sharded and mcast_receivers.num_cores() > 0) {
+    if (!in0_is_sharded and in0_mcast_receivers.num_cores() > 0) {
         mm_kernel_in0_receiver_id = tt_metal::CreateKernel(
             program,
             "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver.cpp",
-            mcast_receivers,
+            in0_mcast_receivers,
             tt_metal::DataMovementConfig{
                 .processor = tt_metal::DataMovementProcessor::RISCV_1,
                 .noc = in0_noc,
@@ -855,36 +857,36 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     CoreCoord start_core = {0, 0};
     uint32_t start_core_x = start_core.x;
     uint32_t start_core_y = start_core.y;
-    uint32_t num_cores_c = compute_with_storage_grid_size.x;
-    uint32_t num_cores_r = compute_with_storage_grid_size.y;
 
     uint32_t num_blocks_y = (M - 1) / per_core_M + 1;
     uint32_t num_blocks_x = (N - 1) / per_core_N + 1;
     uint32_t num_blocks_total = num_blocks_y * num_blocks_x;
     uint32_t num_cores = num_blocks_total;
-    uint32_t mcast_num_cores = num_cores_c * num_cores_r;  // Exclude Sender
 
     constexpr bool row_major = true;
     CoreRangeSet all_cores =
         num_cores_to_corerange_set(start_core, num_cores, compute_with_storage_grid_size, row_major);
+    CoreRange in1_mcast_receiver_cores_bounding_box = all_cores.bounding_box();
+    uint32_t in1_mcast_receiver_num_cores = in1_mcast_receiver_cores_bounding_box.size();  // always mcast to full grid
 
-    CoreRange mcast_sender(start_core, start_core);
-
-    auto receiver_start_core = start_core.x != (compute_with_storage_grid_size.x - 1)
-                                   ? CoreCoord{start_core.x + 1, start_core.y}
-                                   : CoreCoord{start_core.x, start_core.y + 1};
-    CoreRangeSet mcast_receivers =
-        num_cores_to_corerange_set(receiver_start_core, num_cores - 1, compute_with_storage_grid_size, row_major);
+    CoreRange in1_mcast_sender(start_core, start_core);
+    CoreRangeSet in1_mcast_receivers({});
+    if (in1_mcast_receiver_num_cores > 1) {
+        auto receiver_start_core = start_core.x != (compute_with_storage_grid_size.x - 1)
+                                       ? CoreCoord{start_core.x + 1, start_core.y}
+                                       : CoreCoord{start_core.x, start_core.y + 1};
+        in1_mcast_receivers =
+            num_cores_to_corerange_set(receiver_start_core, num_cores - 1, compute_with_storage_grid_size, row_major);
+    }
 
     // Mcast args
     auto in1_mcast_sender_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID);
     auto in1_mcast_receiver_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID);
     uint32_t in3_mcast_sender_semaphore = 0;
     uint32_t in3_mcast_receiver_semaphore = 0;
 
-    CoreCoord top_left_core = {(std::size_t)start_core_x, (std::size_t)start_core_y};
-    CoreCoord bottom_right_core = {
-        (std::size_t)start_core_x + num_cores_c - 1, (std::size_t)start_core_y + num_cores_r - 1};
+    CoreCoord top_left_core = in1_mcast_receiver_cores_bounding_box.start;
+    CoreCoord bottom_right_core = in1_mcast_receiver_cores_bounding_box.end;
     auto top_left_core_physical = device->worker_core_from_logical_core(top_left_core);
     auto bottom_right_core_physical = device->worker_core_from_logical_core(bottom_right_core);
 
@@ -937,8 +939,8 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
         // in1 mcast args
         (std::uint32_t)in1_mcast_sender_semaphore,
         (std::uint32_t)in1_mcast_receiver_semaphore,
-        (std::uint32_t)num_cores - 1,        // in1_mcast_num_dests
-        (std::uint32_t)mcast_num_cores - 1,  // in1_mcast_num_cores
+        (std::uint32_t)num_cores - 1,                     // in1_mcast_num_dests
+        (std::uint32_t)in1_mcast_receiver_num_cores - 1,  // in1_mcast_num_cores
         // batch args
         (std::uint32_t)K * N,  // KtNt
         (std::uint32_t)B,      // batch
@@ -1027,6 +1029,10 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
 
     mm_kernel_in0_sender_defines["SKIP_MCAST"] = "1";
 
+    if (in1_mcast_receiver_num_cores == 1) {
+        mm_kernel_in1_sender_writer_defines["SKIP_MCAST"] = "1";
+    }
+
     // in1 is the reader of weights/output writer, and we choose to make it use the optimized reader noc
     tt_metal::NOC in0_noc = detail::GetPreferredNOCForDRAMWrite(device->arch());
     tt_metal::NOC in1_noc = detail::GetPreferredNOCForDRAMRead(device->arch());
@@ -1044,23 +1050,26 @@ operation::ProgramWithCallbacks create_program_mcast_in1(
     auto mm_kernel_in1_sender_writer_id = tt_metal::CreateKernel(
         program,
         "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_writer_padding.cpp",
-        mcast_sender,
+        in1_mcast_sender,
         tt_metal::DataMovementConfig{
            .processor = tt_metal::DataMovementProcessor::RISCV_0,
            .noc = in1_noc,
            .compile_args = in1_sender_writer_compile_time_args,
            .defines = mm_kernel_in1_sender_writer_defines});
 
-    auto mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel(
-        program,
-        "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
-        "reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
-        mcast_receivers,
-        tt_metal::DataMovementConfig{
-            .processor = tt_metal::DataMovementProcessor::RISCV_0,
-            .noc = in1_noc,
-            .compile_args = in1_receiver_writer_compile_time_args,
-            .defines = mm_kernel_in1_receiver_writer_defines});
+    KernelHandle mm_kernel_in1_receiver_writer_id = 0;
+    if (in1_mcast_receivers.num_cores() > 0) {
+        mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel(
+            program,
+            "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/"
+            "reader_bmm_tile_layout_in1_receiver_writer_padding.cpp",
+            in1_mcast_receivers,
+            tt_metal::DataMovementConfig{
+                .processor = tt_metal::DataMovementProcessor::RISCV_0,
+                .noc = in1_noc,
+                .compile_args = in1_receiver_writer_compile_time_args,
+                .defines = mm_kernel_in1_receiver_writer_defines});
+    }
 
     // Compute kernel compile time args
 
@@ -1541,71 +1550,67 @@ operation::ProgramWithCallbacks matmul_multi_core_reuse_mcast_1d_optimized_(
     ////////////////////////////////////////////////////////////////////////////
     //                      Application Setup
     ////////////////////////////////////////////////////////////////////////////
-    if (compute_with_storage_grid_size.x > 1 or compute_with_storage_grid_size.y > 1) {
-        if (mcast_in0) {
-            return reuse_mcast_1d_optimized_helpers::create_program_mcast_in0(
-                a,
-                device,
-                math_fidelity,
-                fp32_dest_acc_en,
-                math_approx_mode,
-                packer_l1_acc,
-                compute_with_storage_grid_size,
-                B,
-                Mt,
-                Nt,
-                Kt,
-                bcast_batch,
-                in0_block_w,
-                out_subblock_h,
-                out_subblock_w,
-                per_core_M,
-                per_core_N,
-                fused_activation,
-                in0_buffer,
-                in1_buffer,
-                bias_buffer,
-                out_buffer,
-                in0_data_format,
-                in1_data_format,
-                bias_data_format,
-                output_data_format,
-                a.memory_config().is_sharded(),
-                output.memory_config().is_sharded(),
-                untilize_out);
-        } else {
-            return reuse_mcast_1d_optimized_helpers::create_program_mcast_in1(
-                device,
-                math_fidelity,
-                fp32_dest_acc_en,
-                math_approx_mode,
-                packer_l1_acc,
-                compute_with_storage_grid_size,
-                B,
-                Mt,
-                Nt,
-                Kt,
-                bcast_batch,
-                in0_block_w,
-                out_subblock_h,
-                out_subblock_w,
-                per_core_M,
-                per_core_N,
-                fused_activation,
-                in0_buffer,
-                in1_buffer,
-                bias_buffer,
-                out_buffer,
-                in0_data_format,
-                in1_data_format,
-                bias_data_format,
-                output_data_format,
-                a.memory_config().is_sharded(),
-                output.memory_config().is_sharded(),
-                untilize_out);
-        }
-    }
-    TT_FATAL(false, "Grid is invalid for mcast matmul!");
+    if (mcast_in0) {
+        return reuse_mcast_1d_optimized_helpers::create_program_mcast_in0(
+            a,
+            device,
+            math_fidelity,
+            fp32_dest_acc_en,
+            math_approx_mode,
+            packer_l1_acc,
+            compute_with_storage_grid_size,
+            B,
+            Mt,
+            Nt,
+            Kt,
+            bcast_batch,
+            in0_block_w,
+            out_subblock_h,
+            out_subblock_w,
+            per_core_M,
+            per_core_N,
+            fused_activation,
+            in0_buffer,
+            in1_buffer,
+            bias_buffer,
+            out_buffer,
+            in0_data_format,
+            in1_data_format,
+            bias_data_format,
+            output_data_format,
+            a.memory_config().is_sharded(),
+            output.memory_config().is_sharded(),
+            untilize_out);
+    } else {
+        return reuse_mcast_1d_optimized_helpers::create_program_mcast_in1(
+            device,
+            math_fidelity,
+            fp32_dest_acc_en,
+            math_approx_mode,
+            packer_l1_acc,
+            compute_with_storage_grid_size,
+            B,
+            Mt,
+            Nt,
+            Kt,
+            bcast_batch,
+            in0_block_w,
+            out_subblock_h,
+            out_subblock_w,
+            per_core_M,
+            per_core_N,
+            fused_activation,
+            in0_buffer,
+            in1_buffer,
+            bias_buffer,
+            out_buffer,
+            in0_data_format,
+            in1_data_format,
+            bias_data_format,
+            output_data_format,
+            a.memory_config().is_sharded(),
+            output.memory_config().is_sharded(),
+            untilize_out);
+    }
 }
 