diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp
index 221e5c3011a..8e10559a870 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_receive_writer.cpp
@@ -53,18 +53,20 @@ void kernel_main() {
     constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(4);
     constexpr uint32_t ring_size = get_compile_time_arg_val(5);
     constexpr bool fuse_op = get_compile_time_arg_val(6);
 
+    constexpr uint32_t output_tile_size = get_compile_time_arg_val(7);
+
     ASSERT(half_cb_n_pages > rem_num_pages);
 
 #ifdef SHARDED_MEM_LAYOUT
-    constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(7));
-    constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(8);
-    constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(9);
-    constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(10);
-    constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(11);
-    constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(12);
-    constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(13);
-    constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(14) != 0;
+    constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(8));
+    constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(9);
+    constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(10);
+    constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(11);
+    constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(12);
+    constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(13);
+    constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(14);
+    constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(15) != 0;
 #endif
 
     constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
@@ -94,7 +96,7 @@ void kernel_main() {
 
 #ifdef INTERLEAVED_MEM_LAYOUT
     const DataFormat in0_df = get_dataformat(cb_id_in0);
-    InterleavedAddrGenFast<dst_is_dram> d = {
+    InterleavedAddrGenFast<dst_is_dram, output_tile_size> d = {
         .bank_base_address = dst_addr,
         .page_size = output_page_size,
         .data_format = in0_df
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp
index 6b6a43fc7ee..9d937029a5d 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_reader.cpp
@@ -55,25 +55,27 @@ void kernel_main() {
     uint32_t sem_addr = get_semaphore(get_compile_time_arg_val(5));
     constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(6);
     constexpr uint32_t ring_size = get_compile_time_arg_val(7);
+    constexpr uint32_t input_tile_size = get_compile_time_arg_val(8);
+    constexpr uint32_t output_tile_size = get_compile_time_arg_val(9);
 
 #ifdef SHARDED_MEM_LAYOUT
-    constexpr tt::tt_metal::TensorMemoryLayout input_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(8));
-    constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(9);
-    constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(10);
-    constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(11);
-    constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(12);
-    constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(13);
-    constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(14);
-    constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(15) != 0;
-
-    constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(16));
-    constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(17);
-    constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(18);
-    constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(19);
-    constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(20);
-    constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(21);
-    constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(22);
-    constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(23) != 0;
+    constexpr tt::tt_metal::TensorMemoryLayout input_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(10));
+    constexpr uint32_t input_tensor_shard_grid_height = get_compile_time_arg_val(11);
+    constexpr uint32_t input_tensor_shard_grid_width = get_compile_time_arg_val(12);
+    constexpr uint32_t input_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(13);
+    constexpr uint32_t input_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(14);
+    constexpr uint32_t input_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(15);
+    constexpr uint32_t input_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(16);
+    constexpr bool input_tensor_shard_grid_transposed = get_compile_time_arg_val(17) != 0;
+
+    constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(18));
+    constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(19);
+    constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(20);
+    constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(21);
+    constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(22);
+    constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(23);
+    constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(24);
+    constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(25) != 0;
 #endif
 
     ASSERT(half_cb_n_pages > rem_num_pages);
@@ -121,13 +123,13 @@ void kernel_main() {
 
 #ifdef INTERLEAVED_MEM_LAYOUT
     const DataFormat in0_df = get_dataformat(cb_id_in0);
-    const InterleavedAddrGenFast<src_is_dram> s = {
+    const InterleavedAddrGenFast<src_is_dram, input_tile_size> s = {
         .bank_base_address = src_addr,
         .page_size = page_size,
         .data_format = in0_df
     };
 
-    InterleavedAddrGenFast<dst_is_dram> d = {
+    InterleavedAddrGenFast<dst_is_dram, output_tile_size> d = {
         .bank_base_address = dst_addr,
         .page_size = output_page_size,
         .data_format = in0_df
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp
index 89796a74d6e..96ea1aaecae 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/kernels/dataflow/worker_interleaved_ring_gather_send_writer.cpp
@@ -55,18 +55,19 @@ void kernel_main() {
     volatile uint32_t *const writer_send_sem_ptr = reinterpret_cast<volatile uint32_t *const>(get_semaphore(get_compile_time_arg_val(4)));
     constexpr uint32_t half_cb_n_pages = get_compile_time_arg_val(5);
     constexpr uint32_t num_buffers_per_channel = get_compile_time_arg_val(6);
 
+    constexpr uint32_t output_tile_size = get_compile_time_arg_val(7);
     ASSERT(half_cb_n_pages > rem_num_pages);
 
 #ifdef SHARDED_MEM_LAYOUT
-    constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(7));
-    constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(8);
-    constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(9);
-    constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(10);
-    constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(11);
-    constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(12);
-    constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(13);
-    constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(14) != 0;
+    constexpr tt::tt_metal::TensorMemoryLayout output_tensor_memory_layout = static_cast<tt::tt_metal::TensorMemoryLayout>(get_compile_time_arg_val(8));
+    constexpr uint32_t output_tensor_shard_grid_height = get_compile_time_arg_val(9);
+    constexpr uint32_t output_tensor_shard_grid_width = get_compile_time_arg_val(10);
+    constexpr uint32_t output_tensor_shard_grid_start_y_logical = get_compile_time_arg_val(11);
+    constexpr uint32_t output_tensor_shard_grid_start_x_logical = get_compile_time_arg_val(12);
+    constexpr uint32_t output_tensor_shard_pages_per_shard_y = get_compile_time_arg_val(13);
+    constexpr uint32_t output_tensor_shard_pages_per_shard_x = get_compile_time_arg_val(14);
+    constexpr bool output_tensor_shard_grid_transposed = get_compile_time_arg_val(15) != 0;
 #endif
 
     constexpr uint32_t cb_id_in0 = tt::CB::c_in0;
@@ -95,7 +96,7 @@ void kernel_main() {
 
 #ifdef INTERLEAVED_MEM_LAYOUT
     const DataFormat in0_df = get_dataformat(cb_id_in0);
-    const InterleavedAddrGenFast<dst_is_dram> d = {
+    const InterleavedAddrGenFast<dst_is_dram, output_tile_size> d = {
        .bank_base_address = dst_addr,
        .page_size = output_page_size,
        .data_format = in0_df
diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp
index 01fca69cf67..6d3fdbfa69d 100644
--- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp
+++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/multi_core/all_gather_op_multi_core.cpp
@@ -290,8 +290,6 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
     tt::DataFormat df = datatype_to_dataformat_converter(input_tensor.get_dtype());
 
     std::map<string, string> worker_defines;
-    worker_defines["INPUT_TILE_SIZE"] = std::to_string(input_tensor_config->get_tile_size());
-    worker_defines["OUTPUT_TILE_SIZE"] = std::to_string(output_tensor_config->get_tile_size());
     if (rm) {
         worker_defines["ROW_MAJOR_LAYOUT"] = "1";
     } else {
@@ -371,6 +369,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
     log_trace(tt::LogOp, "input_page_size: {}", input_page_size);
     uint32_t src0_cb_index = tt::CB::c_in0;
     const uint32_t cb_n_packets = 2;
+    const uint32_t cb_size_in_pages = cb_n_packets * max_pages_per_chunk;
     const uint32_t CB_buffer_size = cb_n_packets * max_buffer_per_chunk;
     log_trace(tt::LogOp, "max_pages_per_chunk: {}", max_pages_per_chunk);
     CircularBufferConfig cb_src0_config = CircularBufferConfig(CB_buffer_size, {{src0_cb_index, df}})
@@ -398,7 +397,9 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         static_cast<uint32_t>(ring_index),
         static_cast<uint32_t>(sender_worker_reader_semaphore_id),
         static_cast<uint32_t>(max_pages_per_chunk),
-        static_cast<uint32_t>(ring_size)
+        static_cast<uint32_t>(ring_size),
+        static_cast<uint32_t>(input_tensor_config->get_tile_size()),
+        static_cast<uint32_t>(output_tensor_config->get_tile_size())
     };
     if (is_sharded) {
         emit_sharded_tensor_kernel_ct_args(device, input_tensor, worker_reader_sender_ct_args, input_pages_per_shard_y, input_pages_per_shard_x);
@@ -442,6 +443,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         static_cast<uint32_t>(sender_worker_writer_semaphore_id),
         static_cast<uint32_t>(max_pages_per_chunk),
         static_cast<uint32_t>(num_edm_buffers_per_channel),
+        static_cast<uint32_t>(output_tensor_config->get_tile_size())
     };
 
     if (is_sharded) {
@@ -503,7 +505,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         static_cast<uint32_t>(sender_worker_reader_semaphore_id),
         static_cast<uint32_t>(max_pages_per_chunk),
         static_cast<uint32_t>(ring_size),
-        static_cast<uint32_t>(fuse_op)
+        static_cast<uint32_t>(fuse_op),
+        static_cast<uint32_t>(output_tensor_config->get_tile_size())
     };
 
     if (is_sharded) {
@@ -690,7 +693,7 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers_helper(
         }
         if (rem_pages != 0) {
             rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) = rem_pages;
-            TT_ASSERT(rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) <= max_pages_per_chunk * 2);
+            TT_ASSERT(rem_pages_per_worker.at(worker_idx % all_gather_config.get_num_workers_per_link()) <= cb_size_in_pages);
         }
         { // Logging
             log_trace(tt::LogOp, "num_full_chunks, remaining pages per worker (clockwise):");