From 5bbf96fb29f354861a062c13203b408658b47a1e Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Tue, 4 Jun 2024 02:18:22 +0000 Subject: [PATCH 01/53] #0: optimize allgather for small tensor sizes For smaller tensor sizes, there is only a single packet sent through the erisc data mover channel and that packet may be smaller in size than the channel buffer. For those cases, the datamover channel buffer is shrunk to be the same size as the packet. This can save a large amount of time. For LLama, there are smaller allgathers of size 32x1024 which are allgathered on dim=3. In those cases we can get up to a 2x improvement for bfp8 and slightly less for fp16. --- .../multi_core/all_gather_op_multi_core.cpp | 24 +++++++++++++++++++ .../ccl/ccl_host_datastructures.hpp | 22 ++++++++++++++--- 2 files changed, 43 insertions(+), 3 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp index 2c7f486cd81d..9ffcba874adc 100644 --- a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp @@ -318,6 +318,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& TT_ASSERT(rem_pages < pages_per_chunk || num_full_chunks == 0); TT_ASSERT(rem_pages <= max_pages_per_chunk); std::vector num_full_chunks_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), num_full_chunks / all_gather_config.get_num_eth_buffers_per_edm()); + std::vector is_channel_shrinkable(all_gather_config.get_num_eth_buffers_per_edm(), false); + std::vector largest_packets_per_channel(all_gather_config.get_num_eth_buffers_per_edm(), 0); std::vector rem_pages_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), 0); { uint32_t worker_idx = 0; @@ -355,10 +357,22 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& ); uint32_t max_shards_per_eth_buffer = std::min(all_gather_config.get_eth_buffer_size() / input_tensor_shard_arg_generator.args_struct.shard_size_in_bytes, input_tensor_shard_arg_generator.args_struct.num_dest_cores); TT_ASSERT(max_shards_per_eth_buffer > 0, "Codepath needs further generalization to support computing multiple sends per shard. Shard size: {}", input_tensor_shard_arg_generator.args_struct.shard_size_in_bytes); + log_info(tt::LogOp, "max_shards_per_eth_buffer: {}", max_shards_per_eth_buffer); num_full_chunks_per_worker.at(b) = input_tensor_shard_arg_generator.args_struct.num_dest_cores < max_shards_per_eth_buffer ? 1 : input_tensor_shard_arg_generator.args_struct.num_dest_cores / max_shards_per_eth_buffer; rem_pages_per_worker.at(b) = max_shards_per_eth_buffer > input_tensor_shard_arg_generator.args_struct.num_dest_cores ? 
0 : input_tensor_shard_arg_generator.args_struct.num_dest_cores - (num_full_chunks_per_worker.at(b) * max_shards_per_eth_buffer); TT_ASSERT(rem_pages_per_worker.at(b) == 0 || input_tensor_shard_arg_generator.args_struct.num_dest_cores >= num_full_chunks_per_worker.at(b) * max_shards_per_eth_buffer); TT_ASSERT(input_tensor_shard_arg_generator.args_struct.num_dest_cores == rem_pages_per_worker.at(b) + num_full_chunks_per_worker.at(b) * max_shards_per_eth_buffer); + + uint32_t full_chunk_size_bytes = max_shards_per_eth_buffer * input_tensor_shard_arg_generator.args_struct.shard_size_in_bytes; + bool shrinkable = num_full_chunks_per_worker.at(b) == 1 && all_gather_config.get_eth_buffer_size() > full_chunk_size_bytes; + is_channel_shrinkable.at(b) = shrinkable; + largest_packets_per_channel.at(b) = shrinkable ? full_chunk_size_bytes : all_gather_config.get_eth_buffer_size(); + } + } else { + for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + bool shrinkable = num_full_chunks_per_worker.at(b) == 0; + is_channel_shrinkable.at(b) = shrinkable; + largest_packets_per_channel.at(b) = shrinkable ? rem_pages_per_worker.at(b) * input_page_size : all_gather_config.get_eth_buffer_size(); } } for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { @@ -412,6 +426,11 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& log_trace(tt::LogOp, "Adding sender EDM channel"); EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = sender_edm_builder.add_sender_channel(sender_worker_writer_semaphore_addr, clockwise_link_buffer_num_messages_to_send.at(b), sender_worker_coords); + if (is_channel_shrinkable.at(b)) { + TT_ASSERT(largest_packets_per_channel.at(b) > 0); + log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b); + sender_edm_builder.set_max_message_size_bytes(sender_channel_buffer_info.channel, largest_packets_per_channel.at(b)); + } sender_eth_sem_addrs.push_back(sender_channel_buffer_info.eth_semaphore_l1_address); sender_eth_buffer_addrs.push_back(sender_channel_buffer_info.eth_buffer_l1_address); } @@ -422,6 +441,11 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& log_trace(tt::LogOp, "Adding receiver EDM channel"); EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = receiver_edm_builder.add_receiver_channel(receiver_worker_semaphore_addr, counter_clockwise_link_buffer_num_messages_to_send.at(b), receiver_worker_coords); + if (is_channel_shrinkable.at(b)) { + TT_ASSERT(largest_packets_per_channel.at(b) > 0); + log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b); + receiver_edm_builder.set_max_message_size_bytes(receiver_channel_buffer_info.channel, largest_packets_per_channel.at(b)); + } receiver_eth_sem_addrs.push_back(receiver_channel_buffer_info.eth_semaphore_l1_address); receiver_eth_buffer_addrs.push_back(receiver_channel_buffer_info.eth_buffer_l1_address); } diff --git a/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp b/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp index 193046b8c549..89c237a6cb60 100644 --- a/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp +++ b/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp @@ -7,6 +7,7 @@ #include "eth_l1_address_map.h" #include "tensor/tensor_impl.hpp" #include 
"tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include namespace tt { namespace tt_metal { @@ -130,19 +131,25 @@ class EriscDatamoverBuilder { worker_semaphore_address(worker_semaphore_address), num_eth_messages_to_forward(num_eth_messages_to_forward), channel(channel), + largest_message_size_bytes(0), is_sender(is_sender) {} std::vector const worker_coords; uint32_t worker_semaphore_address; uint32_t num_eth_messages_to_forward; uint32_t channel; + uint32_t largest_message_size_bytes; bool is_sender; }; void push_back_channel_args(std::vector& args, ChannelBufferSpec const& channel) const { args.push_back(this->local_buffer_addresses.at(channel.channel)); args.push_back(channel.num_eth_messages_to_forward); - args.push_back(this->eth_buffer_size_bytes); + if (channel.largest_message_size_bytes > 0) { + args.push_back(std::min(channel.largest_message_size_bytes, this->eth_buffer_size_bytes)); + } else { + args.push_back(this->eth_buffer_size_bytes); + } args.push_back(this->local_semaphore_addresses.at(channel.channel)); args.push_back(channel.worker_semaphore_address); args.push_back(channel.worker_coords.size()); @@ -167,6 +174,7 @@ class EriscDatamoverBuilder { public: struct ChannelBufferInterface { + std::size_t channel; uint32_t eth_buffer_l1_address; uint32_t eth_semaphore_l1_address; }; @@ -224,8 +232,16 @@ class EriscDatamoverBuilder { log_trace(tt::LogOp, "\tbuffer_address: {}", local_buffer_addresses.at(channel)); log_trace(tt::LogOp, "\tsemaphore_address: {}", local_semaphore_addresses.at(channel)); - return ChannelBufferInterface{local_buffer_addresses.at(channel), local_semaphore_addresses.at(channel)}; + return ChannelBufferInterface{channel, local_buffer_addresses.at(channel), local_semaphore_addresses.at(channel)}; + } + + // This function is used to set the maximum message size for a given channel. If the maximum + // message size is < EDM channel buffer size, then the buffer size passed to the EDM for this channel + // will be trimmed be no larger than the largest message to save on unnecessary eth bandwidth. + void set_max_message_size_bytes(std::size_t channel, std::size_t max_message_size_bytes) { + active_channels.at(channel).largest_message_size_bytes = std::max(active_channels.at(channel).largest_message_size_bytes, max_message_size_bytes); } + [[nodiscard]] ChannelBufferInterface add_receiver_channel( uint32_t worker_semaphore_address, @@ -241,7 +257,7 @@ class EriscDatamoverBuilder { log_trace(tt::LogOp, "\tnum_eth_messages_to_forward: {}", active_channels.back().num_eth_messages_to_forward); log_trace(tt::LogOp, "\tchannel: {}", active_channels.back().channel); log_trace(tt::LogOp, "\tis_sender: {}", active_channels.back().is_sender ? 
1 : 0); - return ChannelBufferInterface{local_buffer_addresses.at(channel), local_semaphore_addresses.at(channel)}; + return ChannelBufferInterface{channel, local_buffer_addresses.at(channel), local_semaphore_addresses.at(channel)}; } [[nodiscard]] From 5770143179b9250cb590f402fb22a7ad87899851 Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Fri, 31 May 2024 12:15:55 +0000 Subject: [PATCH 02/53] #0: Enable weight caching for long running Mamba tests --- models/demos/mamba/demo/demo.py | 26 ++++++++----- models/demos/mamba/tests/test_full_model.py | 38 ++++++++++++------- .../demos/mamba/tests/test_full_model_loop.py | 5 ++- models/demos/mamba/tests/test_mamba_block.py | 11 +----- models/demos/mamba/tests/test_mamba_demo.py | 15 ++++++-- models/demos/mamba/tests/test_mamba_perf.py | 11 +++++- models/demos/mamba/tests/test_mamba_ssm.py | 11 +----- .../demos/mamba/tests/test_residual_block.py | 11 +----- 8 files changed, 70 insertions(+), 58 deletions(-) diff --git a/models/demos/mamba/demo/demo.py b/models/demos/mamba/demo/demo.py index fb95f1ececda..e798f2973348 100644 --- a/models/demos/mamba/demo/demo.py +++ b/models/demos/mamba/demo/demo.py @@ -28,13 +28,8 @@ def get_tt_metal_model( from models.demos.mamba.tt import model_config reference_model = get_cpu_reference_model(version, batch_size=batch_size) - if cache_dir: - cache_path = model_config.get_weights_cache_path(version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch_size, reference_model.args.d_model) - model = MambaTT(reference_model, device, config, tt_cache_path=cache_path) + model = MambaTT(reference_model, device, config, tt_cache_path=cache_dir) return model @@ -89,6 +84,7 @@ def run_mamba_demo( assert batch_size == len(prompts), "32 prompts are required" logger.info(f"Running Mamba demo (weights='{model_version}') with batch={batch_size}") + logger.info(f"Using tensor cache at '{cache_dir}'") model = get_tt_metal_model(model_version, device, cache_dir, batch_size) @@ -129,8 +125,18 @@ def run_mamba_demo( @pytest.mark.parametrize( - "max_gen_len", - ([100]), + "model_version, max_gen_len", + ( + ( + "state-spaces/mamba-2.8b-slimpj", + 100, + ), + ), ) -def test_demo(user_input, device, use_program_cache, max_gen_len): - return run_mamba_demo(prompts=user_input, device=device, generated_sequence_length=max_gen_len) +def test_demo(user_input, device, use_program_cache, get_tt_cache_path, model_version, max_gen_len): + return run_mamba_demo( + prompts=user_input, + device=device, + cache_dir=get_tt_cache_path(model_version), + generated_sequence_length=max_gen_len, + ) diff --git a/models/demos/mamba/tests/test_full_model.py b/models/demos/mamba/tests/test_full_model.py index afbdca353e80..585c18fcb4db 100644 --- a/models/demos/mamba/tests/test_full_model.py +++ b/models/demos/mamba/tests/test_full_model.py @@ -46,9 +46,9 @@ def run_inference( model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], num_layers: int, iterations: int, + cache_dir: Optional[str], ): torch.manual_seed(10) @@ -64,13 +64,8 @@ def run_inference( with torch.no_grad(): reference_output = mamba_model_pytorch(input_ids) - if cache_dir: - cache_path = model_config.get_weights_cache_path(model_version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch, reference_model.args.d_model) - mamba_model_tt = MambaTT(reference_model, device, config, tt_cache_path=cache_path, num_layers=num_layers) + mamba_model_tt = MambaTT(reference_model, device, 
config, tt_cache_path=cache_dir, num_layers=num_layers) for _ in range(iterations): tt_output = mamba_model_tt(input_ids) @@ -87,13 +82,12 @@ def run_inference( @skip_for_grayskull("Not supported on Grayskull") @pytest.mark.parametrize( - "model_version, batch, pcc, cache_dir, num_layers, iterations", + "model_version, batch, pcc, num_layers, iterations", ( ( "state-spaces/mamba-2.8b", 32, 0.985, - None, 64, 1, ), @@ -102,14 +96,23 @@ def run_inference( def test_inference( device: ttnn.Device, use_program_cache, + get_tt_cache_path, model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], num_layers: int, iterations: int, ): - run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations) + run_inference( + device, + use_program_cache, + model_version, + batch, + pcc, + num_layers, + iterations, + cache_dir=get_tt_cache_path(model_version), + ) @skip_for_grayskull("Not supported on Grayskull") @@ -120,11 +123,20 @@ def test_inference( def test_device_perf( device: ttnn.Device, use_program_cache, + get_tt_cache_path, iterations, model_version="state-spaces/mamba-2.8b", batch=32, pcc=0.97, - cache_dir=None, num_layers=1, ): - run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations) + run_inference( + device, + use_program_cache, + model_version, + batch, + pcc, + num_layers, + iterations, + cache_dir=get_tt_cache_path(model_version), + ) diff --git a/models/demos/mamba/tests/test_full_model_loop.py b/models/demos/mamba/tests/test_full_model_loop.py index 532e9f509cf5..1fc0f91c6d41 100644 --- a/models/demos/mamba/tests/test_full_model_loop.py +++ b/models/demos/mamba/tests/test_full_model_loop.py @@ -12,11 +12,12 @@ def test_inference_loop( device: ttnn.Device, use_program_cache, + get_tt_cache_path, model_version="state-spaces/mamba-2.8b", batch=32, pcc=0.88, - cache_dir=None, num_layers=64, iterations=10, ): - run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations) + cache_dir = get_tt_cache_path(model_version) + run_inference(device, use_program_cache, model_version, batch, pcc, num_layers, iterations, cache_dir) diff --git a/models/demos/mamba/tests/test_mamba_block.py b/models/demos/mamba/tests/test_mamba_block.py index 0589e551d2ea..a141f57450a9 100644 --- a/models/demos/mamba/tests/test_mamba_block.py +++ b/models/demos/mamba/tests/test_mamba_block.py @@ -30,13 +30,12 @@ def forward(self, x): @pytest.mark.parametrize( - "model_version, batch, pcc, cache_dir", + "model_version, batch, pcc", ( ( "state-spaces/mamba-2.8b", 32, 0.99, - None, ), ), ) @@ -46,7 +45,6 @@ def test_mamba_block_inference( model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], ): torch.manual_seed(0) @@ -63,14 +61,9 @@ def test_mamba_block_inference( residual_block = reference_model.layers[LAYER_NUM] assert not isinstance(residual_block, torch.Tensor), "Expected torch.Module" - if cache_dir: - cache_path = model_config.get_weights_cache_path(model_version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch, d_model) - loader = TtTensorLoader(reference_model.state_dict(), device, tt_cache_path=cache_path) + loader = TtTensorLoader(reference_model.state_dict(), device) transformer = MambaSsmBlockTransformer( device, batch, reference_model.args.d_inner, reference_model.args.d_state * 2 ) diff --git a/models/demos/mamba/tests/test_mamba_demo.py 
b/models/demos/mamba/tests/test_mamba_demo.py index d14b07571eb3..21a8ed6734b9 100644 --- a/models/demos/mamba/tests/test_mamba_demo.py +++ b/models/demos/mamba/tests/test_mamba_demo.py @@ -7,8 +7,15 @@ @pytest.mark.parametrize( - "user_input, max_gen_len", - ((["Hello World"], 2),), + "user_input, model_version, max_gen_len", + ((["Hello World"], "state-spaces/mamba-2.8b-slimpj", 2),), ) -def test_demo(user_input, device, use_program_cache, max_gen_len): - return run_mamba_demo(prompts=user_input, device=device, generated_sequence_length=max_gen_len, display=False) +def test_demo(user_input, model_version, device, use_program_cache, get_tt_cache_path, max_gen_len): + return run_mamba_demo( + prompts=user_input, + model_version=model_version, + device=device, + generated_sequence_length=max_gen_len, + display=False, + cache_dir=get_tt_cache_path(model_version), + ) diff --git a/models/demos/mamba/tests/test_mamba_perf.py b/models/demos/mamba/tests/test_mamba_perf.py index 1563a29d00bc..e83e3ac4976c 100644 --- a/models/demos/mamba/tests/test_mamba_perf.py +++ b/models/demos/mamba/tests/test_mamba_perf.py @@ -27,7 +27,14 @@ ((32, 10, 12.5, 0.40),), # Issue 7816 Compile time ) def test_mamba_e2e_perf( - device, batch, iterations, expected_compile_time, expected_inference_time, use_program_cache, reset_seeds + device, + batch, + iterations, + expected_compile_time, + expected_inference_time, + use_program_cache, + reset_seeds, + get_tt_cache_path, ): model_version = "state-spaces/mamba-2.8b-slimpj" display_decoded_seq = False @@ -46,7 +53,7 @@ def test_mamba_e2e_perf( profiler.end("pytorch_ref_model_setup") profiler.start("tt_model_setup") - tt_model = get_tt_metal_model(model_version, device, cache_dir=None, batch_size=batch) + tt_model = get_tt_metal_model(model_version, device, cache_dir=get_tt_cache_path(model_version), batch_size=batch) profiler.end("tt_model_setup") sequences: torch.Tensor = tokenizer(prompts, return_tensors="pt", padding=True).input_ids diff --git a/models/demos/mamba/tests/test_mamba_ssm.py b/models/demos/mamba/tests/test_mamba_ssm.py index 43d5b66ac3ef..22898760ec9d 100644 --- a/models/demos/mamba/tests/test_mamba_ssm.py +++ b/models/demos/mamba/tests/test_mamba_ssm.py @@ -30,13 +30,12 @@ def forward(self, x): @pytest.mark.parametrize( - "model_version, batch, pcc, cache_dir", + "model_version, batch, pcc", ( ( "state-spaces/mamba-2.8b", 32, 0.99, - None, ), ), ) @@ -46,7 +45,6 @@ def test_mamba_ssm_inference( model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], ): torch.manual_seed(0) @@ -63,14 +61,9 @@ def test_mamba_ssm_inference( residual_block = reference_model.layers[LAYER_NUM] assert not isinstance(residual_block, torch.Tensor), "Expected torch.Module" - if cache_dir: - cache_path = model_config.get_weights_cache_path(model_version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch, reference_model.args.d_model) - loader = TtTensorLoader(reference_model.state_dict(), device, tt_cache_path=cache_path) + loader = TtTensorLoader(reference_model.state_dict(), device) transformer = MambaSsmBlockTransformer( device, batch, reference_model.args.d_inner, reference_model.args.d_state * 2 ) diff --git a/models/demos/mamba/tests/test_residual_block.py b/models/demos/mamba/tests/test_residual_block.py index 16e521c70717..47267f426351 100644 --- a/models/demos/mamba/tests/test_residual_block.py +++ b/models/demos/mamba/tests/test_residual_block.py @@ -29,13 +29,12 @@ def forward(self, x): 
@pytest.mark.parametrize( - "model_version, batch, pcc, cache_dir", + "model_version, batch, pcc", ( ( "state-spaces/mamba-2.8b", 32, 0.99, - None, ), ), ) @@ -45,7 +44,6 @@ def test_mamba_residual_block_inference( model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], ): torch.manual_seed(0) @@ -62,14 +60,9 @@ def forward(self, x): residual_block = reference_model.layers[LAYER_NUM] assert not isinstance(residual_block, torch.Tensor), "Expected torch.Module" - if cache_dir: - cache_path = model_config.get_weights_cache_path(model_version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch, d_model) - loader = TtTensorLoader(reference_model.state_dict(), device, tt_cache_path=cache_path) + loader = TtTensorLoader(reference_model.state_dict(), device) transformer = MambaSsmBlockTransformer( device, batch, reference_model.args.d_inner, reference_model.args.d_state * 2 ) From d20ad35cf715d1f086c17ffce22bdb2415bbbe97 Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Tue, 4 Jun 2024 14:01:39 +0000 Subject: [PATCH 03/53] #0: Remove MambaSsmBlockTransformer because it is no longer used --- models/demos/mamba/tests/test_mamba_block.py | 6 +- models/demos/mamba/tests/test_mamba_ssm.py | 6 +- .../demos/mamba/tests/test_residual_block.py | 7 +- models/demos/mamba/tests/test_transforms.py | 93 ------------------- models/demos/mamba/tt/full_model.py | 7 +- models/demos/mamba/tt/mamba_block.py | 5 +- models/demos/mamba/tt/mamba_one_step_ssm.py | 5 +- models/demos/mamba/tt/residual_block.py | 5 +- models/demos/mamba/tt/transforms.py | 62 ------------- 9 files changed, 10 insertions(+), 186 deletions(-) delete mode 100644 models/demos/mamba/tests/test_transforms.py delete mode 100644 models/demos/mamba/tt/transforms.py diff --git a/models/demos/mamba/tests/test_mamba_block.py b/models/demos/mamba/tests/test_mamba_block.py index a141f57450a9..8d118a26b26b 100644 --- a/models/demos/mamba/tests/test_mamba_block.py +++ b/models/demos/mamba/tests/test_mamba_block.py @@ -10,7 +10,6 @@ from models.demos.mamba.tt.full_model import TtTensorLoader from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName from models.demos.mamba.tt.mamba_block import TtMambaBlock -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer from models.demos.mamba.tt import model_config from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_allclose, @@ -64,11 +63,8 @@ def test_mamba_block_inference( config = model_config.create_model_config(batch, d_model) loader = TtTensorLoader(reference_model.state_dict(), device) - transformer = MambaSsmBlockTransformer( - device, batch, reference_model.args.d_inner, reference_model.args.d_state * 2 - ) - model = TtMambaBlock(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM), transformer) + model = TtMambaBlock(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM)) tt_input = input.view(1, 1, batch, d_model) tt_input = ttnn.to_device( ttnn.from_torch(tt_input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), diff --git a/models/demos/mamba/tests/test_mamba_ssm.py b/models/demos/mamba/tests/test_mamba_ssm.py index 22898760ec9d..bc489d5b7be9 100644 --- a/models/demos/mamba/tests/test_mamba_ssm.py +++ b/models/demos/mamba/tests/test_mamba_ssm.py @@ -10,7 +10,6 @@ from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName from models.demos.mamba.tt.full_model import
TtTensorLoader from models.demos.mamba.tt.mamba_one_step_ssm import TtMambaSSM -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer from models.demos.mamba.tt import model_config from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_allclose, @@ -64,11 +63,8 @@ def test_mamba_ssm_inference( config = model_config.create_model_config(batch, reference_model.args.d_model) loader = TtTensorLoader(reference_model.state_dict(), device) - transformer = MambaSsmBlockTransformer( - device, batch, reference_model.args.d_inner, reference_model.args.d_state * 2 - ) - model = TtMambaSSM(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM), transformer) + model = TtMambaSSM(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM)) tt_input = input.view(1, 1, batch, d_in) tt_input = ttnn.to_device( ttnn.from_torch(tt_input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), diff --git a/models/demos/mamba/tests/test_residual_block.py b/models/demos/mamba/tests/test_residual_block.py index 47267f426351..005eba21ed13 100644 --- a/models/demos/mamba/tests/test_residual_block.py +++ b/models/demos/mamba/tests/test_residual_block.py @@ -7,7 +7,7 @@ from loguru import logger from typing import Optional import ttnn -from models.demos.mamba.tt.full_model import TtTensorLoader, MambaSsmBlockTransformer +from models.demos.mamba.tt.full_model import TtTensorLoader from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName from models.demos.mamba.tt.residual_block import TtResidualBlock from models.demos.mamba.tt import model_config @@ -63,11 +63,8 @@ def test_mamba_residual_block_inference( config = model_config.create_model_config(batch, d_model) loader = TtTensorLoader(reference_model.state_dict(), device) - transformer = MambaSsmBlockTransformer( - device, batch, reference_model.args.d_inner, reference_model.args.d_state * 2 - ) - model = TtResidualBlock(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM), transformer) + model = TtResidualBlock(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM)) tt_input = input.view(1, 1, batch, d_model) tt_input = ttnn.to_device( ttnn.from_torch(tt_input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), diff --git a/models/demos/mamba/tests/test_transforms.py b/models/demos/mamba/tests/test_transforms.py deleted file mode 100644 index 0e94ec769081..000000000000 --- a/models/demos/mamba/tests/test_transforms.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import torch -import pytest - -import ttnn -import tt_lib as ttl - -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( - comp_allclose, - comp_pcc, -) - -N = 32 -HIDDEN_SIZE = 2560 - - -@pytest.mark.parametrize( - "batch, pcc", - ( - ( - 32, - 0.99, - ), - ), -) -def test_mamba_ssm_block_repeat_interleave( - device: ttnn.Device, - use_program_cache, - batch: int, - pcc: float, -): - input = torch.rand(1, 1, batch, HIDDEN_SIZE * 2) - - expected = torch.repeat_interleave(input, N, dim=3) - - transformer = MambaSsmBlockTransformer(device, batch, HIDDEN_SIZE * 2, N) - input = ttnn.to_device( - ttnn.from_torch(input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), - device=device, - memory_config=ttnn.L1_MEMORY_CONFIG, - ) - actual = transformer.repeat_interleave( - input, - memory_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) - - assert list(actual.get_legacy_shape()) == [1, 1, batch, 2 * HIDDEN_SIZE * N] - - actual = ttnn.to_torch(actual) - passing_pcc, output_pcc = comp_pcc(actual, expected, 0.9999) - assert passing_pcc - - -@pytest.mark.parametrize( - "batch, pcc", - ( - ( - 32, - 0.99, - ), - ), -) -def test_mamba_ssm_block_repeat( - device: ttnn.Device, - batch: int, - pcc: float, - use_program_cache, -): - input = torch.rand(1, 1, batch, N) - - # (1, 1, B, n) -> (1, 1, B, hidden * 2 * n) - expected = input.repeat((1, 1, 1, HIDDEN_SIZE * 2)) - - transformer = MambaSsmBlockTransformer(device, batch, HIDDEN_SIZE * 2, N) - input = ttnn.to_device( - ttnn.from_torch(input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), - device=device, - memory_config=ttnn.L1_MEMORY_CONFIG, - ) - actual = transformer.repeat( - input, - memory_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) - - assert list(actual.get_legacy_shape()) == [1, 1, batch, 2 * HIDDEN_SIZE * N] - - actual = ttnn.to_torch(actual) - passing_pcc, output_pcc = comp_pcc(actual, expected, 0.9999) - assert passing_pcc diff --git a/models/demos/mamba/tt/full_model.py b/models/demos/mamba/tt/full_model.py index 0c3c3438ac97..a06ad6b9f800 100644 --- a/models/demos/mamba/tt/full_model.py +++ b/models/demos/mamba/tt/full_model.py @@ -11,7 +11,6 @@ from typing import Callable, Optional from models.demos.mamba.tt.residual_block import TtResidualBlock -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer class TtTensorLoader: @@ -81,13 +80,9 @@ def __init__( self.embedding = reference_model.embedding loader = TtTensorLoader(reference_model.state_dict(), self.device, tt_cache_path=tt_cache_path) - transformer = MambaSsmBlockTransformer( - self.device, self.args.batch_size, self.args.d_inner, configs["latent_size"] - ) self.layers = [ - TtResidualBlock(self.args, device, configs, loader.get_tensor_loader(i), transformer) - for i in range(self.num_layers) + TtResidualBlock(self.args, device, configs, loader.get_tensor_loader(i)) for i in range(self.num_layers) ] load_fn = loader.get_tensor_loader() diff --git a/models/demos/mamba/tt/mamba_block.py b/models/demos/mamba/tt/mamba_block.py index 5dd3ab55ec30..d5fe4adffde9 100644 --- a/models/demos/mamba/tt/mamba_block.py +++ b/models/demos/mamba/tt/mamba_block.py @@ -10,11 +10,10 @@ from models.demos.mamba.reference.args import ModelArgs from models.demos.mamba.tt.mamba_one_step_ssm import TtMambaSSM -from models.demos.mamba.tt.transforms 
import MambaSsmBlockTransformer class TtMambaBlock(torch.nn.Module): - def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transformer: MambaSsmBlockTransformer): + def __init__(self, args: ModelArgs, device, configs, load_fn: Callable): super().__init__() self.device = device @@ -76,7 +75,7 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transfor ) ) - self.tt_ssm = TtMambaSSM(self.args, self.device, configs, load_fn, transformer) + self.tt_ssm = TtMambaSSM(self.args, self.device, configs, load_fn) self.compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( math_fidelity=ttl.tensor.MathFidelity.HiFi3, diff --git a/models/demos/mamba/tt/mamba_one_step_ssm.py b/models/demos/mamba/tt/mamba_one_step_ssm.py index 5cf769e75aed..f5d07996c783 100644 --- a/models/demos/mamba/tt/mamba_one_step_ssm.py +++ b/models/demos/mamba/tt/mamba_one_step_ssm.py @@ -9,15 +9,12 @@ from typing import Callable from models.demos.mamba.reference.args import ModelArgs -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer class TtMambaSSM(torch.nn.Module): - def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transformer: MambaSsmBlockTransformer): + def __init__(self, args: ModelArgs, device, configs, load_fn: Callable): super().__init__() - self.transformer = transformer - self.device = device self.args = args diff --git a/models/demos/mamba/tt/residual_block.py b/models/demos/mamba/tt/residual_block.py index a1cf33f2d70a..dbe3ff1236a4 100644 --- a/models/demos/mamba/tt/residual_block.py +++ b/models/demos/mamba/tt/residual_block.py @@ -10,11 +10,10 @@ from models.demos.mamba.reference.args import ModelArgs from models.demos.mamba.tt.mamba_block import TtMambaBlock -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer class TtResidualBlock(torch.nn.Module): - def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transformer: MambaSsmBlockTransformer): + def __init__(self, args: ModelArgs, device, configs, load_fn: Callable): super().__init__() self.device = device @@ -24,7 +23,7 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transfor rms_norm_weight_name = "norm.weight" self.rms_norm_weights = load_fn(rms_norm_weight_name) - self.tt_mamba_block = TtMambaBlock(self.args, self.device, configs, load_fn, transformer) + self.tt_mamba_block = TtMambaBlock(self.args, self.device, configs, load_fn) def forward(self, x): assert len(x.shape) == 4, "Mamba residual block expects inputs to be rank 4" diff --git a/models/demos/mamba/tt/transforms.py b/models/demos/mamba/tt/transforms.py deleted file mode 100644 index 8978da096c24..000000000000 --- a/models/demos/mamba/tt/transforms.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -import tt_lib as ttl -import torch - - -class MambaSsmBlockTransformer: - def __init__(self, device, batch_size, hidden_size, latent_size): - self.device = device - self.batch_size = batch_size - self.hidden_size = hidden_size - self.latent_size = latent_size - repeat_interleave_mask = torch.ones(1, 1, batch_size, latent_size) - self.repeat_interleave_mask = ttnn.from_torch( - repeat_interleave_mask, - layout=ttnn.TILE_LAYOUT, - device=device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=ttnn.bfloat16, - ) - - repeat_mask = torch.ones(1, 1, batch_size, hidden_size) - self.repeat_mask = ttnn.from_torch( - repeat_mask, - layout=ttnn.TILE_LAYOUT, - device=device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=ttnn.bfloat16, - ) - - def repeat_interleave(self, x, memory_config): - """ - This function implements an SSM-specific repeat_interleave operation needed to transform - the SSM block input (X) from (B, 2E) to (B, 2EN) so that it can be multiplied with delta*B. - - """ - assert x.shape == ( - 1, - 1, - self.batch_size, - self.hidden_size, - ), f"Expected repeat_interleave input to be (1, 1, B, 2E) (was {x.shape})" - return ttl.operations.primary.transformers.ssm_eltwise_mul( - self.repeat_interleave_mask, x, output_mem_config=memory_config - ) - - def repeat(self, x, memory_config): - """ - This function implements an SSM-specific repeat operation needed to transform the C - value from (B, N) to (B, 2EN) where N is the latent size (32) and E is the - up project size (2560). - """ - assert x.shape == ( - 1, - 1, - self.batch_size, - self.latent_size, - ), f"Expected repeat input to be (1, 1, B, N) (was {x.shape})" - return ttl.operations.primary.transformers.ssm_eltwise_mul(x, self.repeat_mask, output_mem_config=memory_config) From 8833ad5d41c094848187f67479bb56ff63d15f19 Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Tue, 4 Jun 2024 14:13:46 +0000 Subject: [PATCH 04/53] #0: Remove redundant Mamba model loop test --- .../demos/mamba/tests/test_full_model_loop.py | 23 ------------------- .../single_card/nightly/run_wh_b0_only.sh | 7 +++--- 2 files changed, 3 insertions(+), 27 deletions(-) delete mode 100644 models/demos/mamba/tests/test_full_model_loop.py diff --git a/models/demos/mamba/tests/test_full_model_loop.py b/models/demos/mamba/tests/test_full_model_loop.py deleted file mode 100644 index 1fc0f91c6d41..000000000000 --- a/models/demos/mamba/tests/test_full_model_loop.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import ttnn - -from models.demos.mamba.tests.test_full_model import run_inference -from models.utility_functions import skip_for_grayskull - - -@skip_for_grayskull("Not supported on Grayskull") -def test_inference_loop( - device: ttnn.Device, - use_program_cache, - get_tt_cache_path, - model_version="state-spaces/mamba-2.8b", - batch=32, - pcc=0.88, - num_layers=64, - iterations=10, -): - cache_dir = get_tt_cache_path(model_version) - run_inference(device, use_program_cache, model_version, batch, pcc, num_layers, iterations, cache_dir) diff --git a/tests/scripts/single_card/nightly/run_wh_b0_only.sh b/tests/scripts/single_card/nightly/run_wh_b0_only.sh index 163ed499c4a3..5af44887070a 100755 --- a/tests/scripts/single_card/nightly/run_wh_b0_only.sh +++ b/tests/scripts/single_card/nightly/run_wh_b0_only.sh @@ -14,13 +14,12 @@ SLOW_MATMULS=1 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml env pytest tes env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b/tests/ci/test_falcon_end_to_end_prefill.py +env pytest models/demos/mamba/tests/test_benchmarks.py +env pytest models/demos/mamba/tests/test_reference_model.py env pytest models/demos/mamba/tests/test_mamba_ssm.py env pytest models/demos/mamba/tests/test_mamba_block.py env pytest models/demos/mamba/tests/test_residual_block.py -env pytest models/demos/mamba/tests/test_full_model_loop.py -env pytest models/demos/mamba/tests/test_benchmarks.py -env pytest models/demos/mamba/tests/test_reference_model.py -env pytest models/demos/mamba/tests/test_transforms.py +env pytest models/demos/mamba/tests/test_full_model.py env pytest models/demos/mamba/tests/test_mamba_demo.py env pytest models/demos/wormhole/mistral7b/tests/test_mistral_embedding.py From eb194f25ac0700663ba57c31788f4af4d69e75fb Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Tue, 4 Jun 2024 14:43:33 +0000 Subject: [PATCH 05/53] #0: Lower expected PCC in Mamba full model tests by 0.001 This is required since commit 0598421 lowered overall model PCC. 
--- models/demos/mamba/tests/test_full_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/demos/mamba/tests/test_full_model.py b/models/demos/mamba/tests/test_full_model.py index 585c18fcb4db..6790dd186524 100644 --- a/models/demos/mamba/tests/test_full_model.py +++ b/models/demos/mamba/tests/test_full_model.py @@ -87,7 +87,7 @@ def run_inference( ( "state-spaces/mamba-2.8b", 32, - 0.985, + 0.984, 64, 1, ), From afb1672a070397cb22cc0fbf937bfa90ae5847b2 Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Thu, 30 May 2024 17:10:07 +0000 Subject: [PATCH 06/53] #5389: removed early return from validate when enable_fast_runtime_mode was set to true --- tests/ttnn/integration_tests/mistral/test_mistral_attention.py | 3 +++ tt_eager/tt_dnn/op_library/operation.hpp | 3 --- ttnn/ttnn/__init__.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/ttnn/integration_tests/mistral/test_mistral_attention.py b/tests/ttnn/integration_tests/mistral/test_mistral_attention.py index efc2dc36a8bb..c3a516d12df9 100644 --- a/tests/ttnn/integration_tests/mistral/test_mistral_attention.py +++ b/tests/ttnn/integration_tests/mistral/test_mistral_attention.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +import pytest + import torch import ttnn import tt_lib @@ -19,6 +21,7 @@ from tests.ttnn.utils_for_testing import assert_with_pcc +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @skip_for_wormhole_b0() def test_mistral_attention_inference(model_location_generator, device, reset_seeds): model_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral") diff --git a/tt_eager/tt_dnn/op_library/operation.hpp b/tt_eager/tt_dnn/op_library/operation.hpp index 6ef4b8fc33dc..26285d0b5e80 100644 --- a/tt_eager/tt_dnn/op_library/operation.hpp +++ b/tt_eager/tt_dnn/op_library/operation.hpp @@ -528,9 +528,6 @@ struct DeviceOperation final { const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, const OptionalTensors& optional_output_tensors) -> void { - if (ttnn::CONFIG.enable_fast_runtime_mode) { - return; - } const auto& operation = *reinterpret_cast*>(&storage); if constexpr ( (detail::implements_validate() or diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index 889a517af461..ea52b8fb3864 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -57,7 +57,7 @@ def validate(self, name): if self.enable_fast_runtime_mode: if self.enable_logging: logger.warning( - "Running in fast runtime mode without logging. Please disable fast runtime mode if you want to enable logging." + "Logging cannot be enabled in fast runtime mode. Please disable fast runtime mode if you want to enable logging." 
) if name in { From 8ca86c4ac65351999f1a949da017358e071e4698 Mon Sep 17 00:00:00 2001 From: yugaoT Date: Mon, 3 Jun 2024 22:18:56 +0000 Subject: [PATCH 07/53] #0: fix matmul dram sharded validation --- tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp index 3100d466520d..29cbae919476 100644 --- a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp @@ -1041,7 +1041,7 @@ void Matmul::validate( // subbblock constraint TT_FATAL(program_config.out_subblock_w == per_core_N || program_config.out_subblock_h == 1); // tensor in1 - TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED); + TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED); } else if constexpr (std::is_same_v) { if (input_tensor_a.memory_config().is_sharded()) { auto tensor_a_memory_layout = input_tensor_a.memory_config().memory_layout; From ad7c3a22309ff60ebb4d21e56a0a2953b8c07965 Mon Sep 17 00:00:00 2001 From: mtairum Date: Tue, 4 Jun 2024 09:42:43 +0000 Subject: [PATCH 08/53] #5337: Removed unnecessary ttnn.to_device() from mixtral code --- models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py | 1 - .../demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py | 2 +- models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py | 6 ------ models/demos/t3000/mixtral8x7b/tt/mixtral_common.py | 7 ++----- models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py | 3 --- models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py | 1 - models/demos/t3000/mixtral8x7b/tt/mixtral_rms_norm.py | 1 - 7 files changed, 3 insertions(+), 18 deletions(-) diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py index c4428ed36366..3db26429c2e6 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py @@ -69,7 +69,6 @@ def test_mixtral_mlp_inference(t3k_device_mesh, use_program_cache, reset_seeds): layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(t3k_device_mesh), ) - tt_input = ttnn.to_device(tt_input, t3k_device_mesh) tt_output = tt_model(tt_input) tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ConcatMeshToTensor(t3k_device_mesh, dim=0))[0] diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py index b50abc7a3e98..6557af40fab2 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py @@ -55,7 +55,7 @@ def test_mistral_rms_norm_inference(t3k_device_mesh, use_program_cache, reset_se layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(t3k_device_mesh), ) - tt_input = ttnn.to_device(tt_input, t3k_device_mesh) + tt_output = tt_model(tt_input) tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ConcatMeshToTensor(t3k_device_mesh, dim=0))[0] passing, pcc_message = comp_pcc(reference_output, tt_output_torch) diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index 332db2bbfb03..d22af394cf0e 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -81,7 +81,6 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype):
cache_file_name=cache_name(f"wqkv_multidevice_4d"), ) - self.wqkv = ttnn.to_device(self.wqkv, self.device_mesh) self.wo = ttnn.as_tensor( torch.transpose( self.state_dict[wo_str], @@ -98,8 +97,6 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype): cache_file_name=cache_name(f"wo_multidevice4d"), ) - self.wo = ttnn.to_device(self.wo, self.device_mesh) - cache_k = torch.zeros( ( self.n_kv_heads, @@ -130,8 +127,6 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype): for lp in layer_past ] - self.layer_past = [ttnn.to_device(lp, self.device_mesh) for lp in self.layer_past] - self.scale = self.head_dim**-0.5 reduce_mask_torch = torch.zeros(1, 1, self.max_batch_size, self.max_batch_size * 8) @@ -145,7 +140,6 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype): layout=ttnn.TILE_LAYOUT, ) - self.reduce_mask = ttnn.to_device(self.reduce_mask, self.device_mesh) self.compute_kernel = self.model_args.get_compute_kernel_config() self.compute_kernel_attn = self.model_args.get_compute_kernel_attn_config() diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py index 83e35f0a0aaa..d3cb5e9f677c 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py @@ -81,7 +81,6 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, current_pos, sliding_window, device_ memory_config=ttnn.L1_MEMORY_CONFIG, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - xs_1SBH = ttnn.to_device(xs_1SBH, device_mesh) # Attention mask padded_layer_past_len = min(nearest_32(current_pos + 1), sliding_window) @@ -108,7 +107,7 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, current_pos, sliding_window, device_ memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - attn_mask = ttnn.to_device(attn_mask, device_mesh) + ATTN_MASK_MEMCFG = ttnn.create_sharded_memory_config( shape=(32, padded_layer_past_len), core_grid=ttnn.CoreGrid(y=4, x=8), @@ -137,7 +136,6 @@ def prepare_rotation_mat_ttnn(head_dim, max_seq_len, device_mesh): ) for rot_mat_i in rot_mat ] - rot_mats = [ttnn.to_device(rot_mat, device_mesh) for rot_mat in rot_mats] return rot_mats @@ -178,7 +176,6 @@ def cache_attention(device_mesh, state_dict, model_args, rot_emb_matrix_list, se memory_config=ttnn.L1_MEMORY_CONFIG, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - attention_inputs = ttnn.to_device(attention_inputs, device_mesh) tt_attn = TtMixtralAttention( device_mesh, @@ -201,7 +198,7 @@ def cache_attention(device_mesh, state_dict, model_args, rot_emb_matrix_list, se memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - attn_mask = ttnn.to_device(attn_mask, device_mesh) + ATTN_MASK_MEMCFG = ttnn.create_sharded_memory_config( shape=(32, padded_layer_past_len), core_grid=ttnn.CoreGrid(y=4, x=8), diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py index 665ef5d9fd30..f3c2002d4d83 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py @@ -43,11 +43,8 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtypes): ) self.w1 = as_tensor("w1") - self.w1 = ttnn.to_device(self.w1, device_mesh) self.w2 = as_tensor("w2") - self.w2 = ttnn.to_device(self.w2, device_mesh) self.w3 = as_tensor("w3") - self.w3 = ttnn.to_device(self.w3, device_mesh) def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: """ diff --git 
a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py index 598f9663bc08..6664ad227e2e 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py @@ -48,7 +48,6 @@ def __init__(self, device_mesh, state_dict, experts, args, layer_num, dtype): device=self.device_mesh, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - self.reduce_mask = ttnn.to_device(self.reduce_mask, device_mesh) self.expert_mask_11BB = ttnn.from_torch( torch.cat([torch.full((1, 1, 32, 32), fill_value=i + 1) for i in range(8)], dim=3), dtype=ttnn.uint16, diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_rms_norm.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_rms_norm.py index 4c29ee50ae0d..4957d4d6d1ef 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_rms_norm.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_rms_norm.py @@ -88,7 +88,6 @@ def __init__( cache_file_name=cache_name, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - self.weight = ttnn.to_device(self.weight, device_mesh) def forward(self, x: ttnn.Tensor, out_sharded=False) -> ttnn.Tensor: x = ttnn.experimental.tensor.interleaved_to_sharded( From 556567258cdd0af03d84006a1b81285d47d571eb Mon Sep 17 00:00:00 2001 From: mtairum Date: Tue, 4 Jun 2024 14:17:57 +0100 Subject: [PATCH 09/53] #5337: Update Mixtral perf CI times --- models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py | 8 ++++---- tests/scripts/t3000/run_t3000_model_perf_tests.sh | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py index 043666dd8ce9..174ff0c5b235 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py @@ -44,10 +44,10 @@ def forward(self, x): @pytest.mark.parametrize( "generation_start_pos, expected_compile_time, expected_inference_time", ( - (32, 150, 7.5), - (128, 150, 7.5), - (1024, 150, 7.5), - (2048, 150, 7.5), + (32, 150, 0.025), + (128, 150, 0.025), + (1024, 150, 0.025), + (2048, 150, 0.025), ), ) def test_mixtral_model_perf( diff --git a/tests/scripts/t3000/run_t3000_model_perf_tests.sh b/tests/scripts/t3000/run_t3000_model_perf_tests.sh index abff688f6487..c8fc186f9bca 100755 --- a/tests/scripts/t3000/run_t3000_model_perf_tests.sh +++ b/tests/scripts/t3000/run_t3000_model_perf_tests.sh @@ -22,7 +22,7 @@ run_t3000_mixtral_tests() { echo "LOG_METAL: Running run_t3000_mixtral_tests" - env pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py::test_mixtral_model_perf[wormhole_b0-True-2048-150-7.5] -m "model_perf_t3000" + env pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py::test_mixtral_model_perf[wormhole_b0-True-2048-150-0.025] -m "model_perf_t3000" # Record the end time end_time=$(date +%s) From 61920c84a074214bacc26160e92e7659137a6a06 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Mon, 3 Jun 2024 13:15:47 +0000 Subject: [PATCH 10/53] #8837: Resnet multi cq write/program overlap --- models/demos/resnet/tests/test_perf_resnet.py | 51 +++++++++++++++---- models/demos/resnet/tt/metalResnetBlock50.py | 29 +++++------ tt_eager/tt_lib/csrc/tt_lib_bindings.cpp | 24 +++++++++ 3 files changed, 77 insertions(+), 27 deletions(-) diff --git a/models/demos/resnet/tests/test_perf_resnet.py b/models/demos/resnet/tests/test_perf_resnet.py index f7bc7368ed2b..d572f544a229 100644 --- a/models/demos/resnet/tests/test_perf_resnet.py +++ 
b/models/demos/resnet/tests/test_perf_resnet.py @@ -9,9 +9,7 @@ import pytest import tt_lib -from models.utility_functions import is_e75 -from models.utility_functions import profiler -from models.utility_functions import disable_persistent_kernel_cache, skip_for_wormhole_b0 +from models.utility_functions import is_e75, profiler, divup, disable_persistent_kernel_cache, skip_for_wormhole_b0 from models.perf.perf_utils import prep_perf_report from loguru import logger @@ -76,21 +74,54 @@ def run_perf_resnet( profiler.end(cpu_key) tt_inputs = tt_resnet50.preprocessing(inputs) + input_shape = tt_inputs.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_inputs.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_inputs.shape, tt_inputs.dtype, tt_inputs.layout, device, sharded_mem_config_DRAM + ) + op_event = tt_lib.device.CreateEvent() + write_event = tt_lib.device.CreateEvent() + # Initialize the op event so we can write + tt_lib.device.RecordEvent(device, 0, op_event) warmup_end = 5 for iter in range(0, warmup_end): profiler.start(f"{iter}_key") - _ = tt_resnet50(tt_inputs).cpu(blocking=True) + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) profiler.end(f"{iter}_key") tt_lib.device.DumpDeviceProfiler(device) - num_warm_iterations = 15 + num_warm_iterations = 10 warm_start = warmup_end warm_end = warm_start + num_warm_iterations outputs = [] profiler.start(f"run") for iter in range(warm_start, warm_end): - outputs.append(tt_resnet50(tt_inputs).cpu(blocking=False)) + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + outputs.append(tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=False)) tt_lib.device.Synchronize(device) profiler.end(f"run") tt_lib.device.DumpDeviceProfiler(device) @@ -120,14 +151,14 @@ def run_perf_resnet( @skip_for_wormhole_b0(reason_str="Not tested on single WH") -@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2}], indirect=True) @pytest.mark.models_performance_bare_metal @pytest.mark.parametrize( "batch_size, expected_inference_time, expected_compile_time", ( - (1, 0.001, 1), - (2, 0.001, 1), - (16, 0.007, 7), + # (1, 0.001, 1), + # (2, 0.001, 1), + # (16, 0.007, 7), (20, 0.007, 7), ), ) diff --git a/models/demos/resnet/tt/metalResnetBlock50.py b/models/demos/resnet/tt/metalResnetBlock50.py index 16f8fb01ffb1..32e3f913c314 100644 --- a/models/demos/resnet/tt/metalResnetBlock50.py +++ b/models/demos/resnet/tt/metalResnetBlock50.py @@ -2101,7 +2101,7 @@ def preprocessing_with_fold(self, x: torch.Tensor) -> tt_lib.tensor: return x - def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: + def forward(self, x: tt_lib.tensor, write_event=None, op_event=None) -> tt_lib.tensor: if not self.sharded: original_A_cl_host_shape = 
x.get_legacy_shape() x = x.reshape( @@ -2116,7 +2116,7 @@ def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: original_A_cl_host_shape[2], original_A_cl_host_shape[3], ) - elif x.storage_type() != tt_lib.tensor.StorageType.DEVICE: + else: x_shape = x.get_legacy_shape() shard_spec = tt_lib.tensor.ShardSpec( self.shard_grid, @@ -2130,21 +2130,16 @@ def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: mem_config = tt_lib.tensor.MemoryConfig( tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec ) - x = x.to(self.device, mem_config) - else: - shard_spec = tt_lib.tensor.ShardSpec( - self.shard_grid, - [ - x.get_legacy_shape()[2] // self.first_conv_num_cores_nhw, - x.get_legacy_shape()[3], - ], - tt_lib.tensor.ShardOrientation.ROW_MAJOR, - False, - ) - mem_config = tt_lib.tensor.MemoryConfig( - tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec - ) - x = tt_lib.tensor.interleaved_to_sharded(x, mem_config) + if write_event is not None: + tt_lib.device.WaitForEvent(self.device, 0, write_event) + if x.storage_type() != tt_lib.tensor.StorageType.DEVICE: + x = x.to(self.device, mem_config) + elif x.memory_config().is_sharded(): + x = tt_lib.tensor.reshard(x, mem_config) + else: + x = tt_lib.tensor.interleaved_to_sharded(x, mem_config) + if op_event is not None: + tt_lib.device.RecordEvent(self.device, 0, op_event) x = self.conv1(x) # Relu is fused with conv1 diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp index fc3710689212..2d15d354531f 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp @@ -251,6 +251,30 @@ void DeviceModule(py::module &m_device) { Release captured Trace on Device handle )doc"); + auto pyEvent = py::class_>(m_device, "Event", "Event class"); + m_device.def("CreateEvent", + [] () { + return std::make_shared(); + }, R"doc( + Create new event + )doc"); + m_device.def("RecordEvent", + [] (Device* device, const uint8_t cq_id, std::shared_ptr event) { + device->push_work([device, cq_id, event] { + EnqueueRecordEvent(device->command_queue(cq_id), event); + }); + }, R"doc( + Record an event + )doc"); + m_device.def("WaitForEvent", + [] (Device* device, const uint8_t cq_id, std::shared_ptr event) { + device->push_work([device, cq_id, event] { + EnqueueWaitForEvent(device->command_queue(cq_id), event); + }); + }, R"doc( + Wait for an event + )doc"); + m_device.attr("DEFAULT_L1_SMALL_SIZE") = py::int_(DEFAULT_L1_SMALL_SIZE); } From a57868a28667bdea8c6da466dcb57cf86e172f5c Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Mon, 3 Jun 2024 13:16:01 +0000 Subject: [PATCH 11/53] #8837: Use a different noc for each cq for dispatch --- tt_metal/impl/device/device.cpp | 76 ++++++++++++++----- tt_metal/impl/device/device.hpp | 6 +- tt_metal/impl/dispatch/command_queue.cpp | 70 ++++++++++------- tt_metal/impl/dispatch/command_queue.hpp | 26 +++++-- .../impl/dispatch/kernels/cq_dispatch.cpp | 9 +-- .../impl/dispatch/kernels/cq_prefetch.cpp | 3 +- .../impl/dispatch/kernels/cq_prefetch.hpp | 2 +- tt_metal/impl/program/program.cpp | 16 ++-- tt_metal/impl/program/program.hpp | 9 +-- tt_metal/impl/program/program_device_map.hpp | 6 +- 10 files changed, 144 insertions(+), 79 deletions(-) diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index e73c9efbdbba..6e9892c130c4 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -16,6 +16,7 @@ #include "common/utils.hpp" #include 
"llrt/llrt.hpp" #include "dev_msgs.h" +#include "noc/noc_parameters.h" namespace tt { @@ -344,16 +345,19 @@ void Device::configure_kernel_variant( CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, std::map defines_in, + NOC noc_index, bool is_active_eth_core) { + const auto& grid_size = tt::Cluster::instance().get_soc_desc(this->id()).grid_size; + std::map defines = { {"DISPATCH_KERNEL", "1"}, - {"MY_NOC_X", std::to_string(kernel_physical_core.x)}, - {"MY_NOC_Y", std::to_string(kernel_physical_core.y)}, - {"UPSTREAM_NOC_X", std::to_string(upstream_physical_core.x)}, - {"UPSTREAM_NOC_Y", std::to_string(upstream_physical_core.y)}, - {"DOWNSTREAM_NOC_X", std::to_string(downstream_physical_core.x)}, - {"DOWNSTREAM_NOC_Y", std::to_string(downstream_physical_core.y)}, + {"MY_NOC_X", std::to_string(NOC_0_X(noc_index, grid_size.x, kernel_physical_core.x))}, + {"MY_NOC_Y", std::to_string(NOC_0_Y(noc_index, grid_size.y, kernel_physical_core.y))}, + {"UPSTREAM_NOC_X", std::to_string(NOC_0_X(noc_index, grid_size.x, upstream_physical_core.x))}, + {"UPSTREAM_NOC_Y", std::to_string(NOC_0_Y(noc_index, grid_size.y, upstream_physical_core.y))}, + {"DOWNSTREAM_NOC_X", std::to_string(NOC_0_X(noc_index, grid_size.x, downstream_physical_core.x))}, + {"DOWNSTREAM_NOC_Y", std::to_string(NOC_0_Y(noc_index, grid_size.y, downstream_physical_core.y))}, }; defines.insert(defines_in.begin(), defines_in.end()); @@ -364,7 +368,7 @@ void Device::configure_kernel_variant( kernel_core, tt::tt_metal::DataMovementConfig { .processor = tt::tt_metal::DataMovementProcessor::RISCV_1, - .noc = NOC::NOC_0, + .noc = noc_index, .compile_args = compile_args, .defines = defines } @@ -376,7 +380,7 @@ void Device::configure_kernel_variant( kernel_core, tt::tt_metal::EthernetConfig{ .eth_mode = is_active_eth_core ? 
Eth::SENDER : Eth::IDLE, - .noc = NOC::NOC_0, + .noc = noc_index, .compile_args = compile_args, .defines = defines } @@ -420,6 +424,8 @@ void Device::compile_command_queue_programs() { CoreCoord prefetch_physical_core = get_physical_core_coordinate(prefetch_core, dispatch_core_type); CoreCoord dispatch_physical_core = get_physical_core_coordinate(dispatch_core, dispatch_core_type); + NOC noc_index = this->hw_command_queues_[cq_id]->noc_index; + log_debug(LogDevice, "Dispatching out of {} cores", magic_enum::enum_name(dispatch_core_type)); log_debug(LogDevice, "Prefetch HD logical location: {} physical core: {}", prefetch_core.str(), prefetch_physical_core.str()); log_debug(LogDevice, "Dispatch HD logical location: {} physical core {}", dispatch_core.str(), dispatch_physical_core.str()); @@ -465,7 +471,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, CoreCoord{0, 0}, dispatch_physical_core, - std::map {} + std::map {}, + noc_index ); tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, prefetch_core, 0, dispatch_core_type); // prefetch_sync_sem @@ -501,7 +508,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, prefetch_physical_core, CoreCoord{0, 0}, - std::map {} + std::map {}, + noc_index ); tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, dispatch_core, 0, dispatch_core_type); // dispatch_sem @@ -517,7 +525,7 @@ void Device::compile_command_queue_programs() { Device *mmio_device = tt::tt_metal::detail::GetDeviceHandle(mmio_device_id); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_id); uint32_t cq_size = mmio_device->sysmem_manager().get_cq_size(); - + NOC noc_index = this->hw_command_queues_[cq_id]->noc_index; CoreType dispatch_core_type = dispatch_core_manager::get(num_hw_cqs).get_dispatch_core_type(mmio_device_id); tt_cxy_pair prefetch_core = dispatch_core_manager::get(num_hw_cqs).prefetcher_core(device_id, channel, cq_id); @@ -610,7 +618,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, CoreCoord{0, 0}, mux_physical_core, - std::map {} + std::map {}, + noc_index ); log_debug(LogDevice, "run prefetch_h {}", prefetch_core.str()); @@ -671,7 +680,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, CoreCoord{0, 0}, CoreCoord{0, 0}, - std::map {{"SKIP_NOC_LOGGING", "1"}} + std::map {{"SKIP_NOC_LOGGING", "1"}}, + noc_index ); std::vector tunneler_l_compile_args = @@ -715,6 +725,7 @@ void Device::compile_command_queue_programs() { CoreCoord{0, 0}, CoreCoord{0, 0}, std::map {{"SKIP_NOC_LOGGING", "1"}}, + noc_index, true ); @@ -782,7 +793,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, CoreCoord{0, 0}, CoreCoord{0, 0}, - std::map {{"SKIP_NOC_LOGGING", "1"}} + std::map {{"SKIP_NOC_LOGGING", "1"}}, + noc_index ); log_debug(LogDevice, "run dispatch demux at {}", demux_core.str()); @@ -816,7 +828,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, demux_physical_core, CoreCoord{0xffffffff, 0xffffffff}, - std::map {} + std::map {}, + noc_index ); log_debug(LogDevice, "run dispatch_h at {}", dispatch_core.str()); @@ -895,6 +908,7 @@ void Device::compile_command_queue_programs() { CoreCoord{0, 0}, CoreCoord{0, 0}, std::map {{"SKIP_NOC_LOGGING", "1"}}, + noc_index, true ); @@ -959,7 +973,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, CoreCoord{0, 0}, CoreCoord{0, 0}, - std::map {{"SKIP_NOC_LOGGING", "1"}} + std::map {{"SKIP_NOC_LOGGING", "1"}}, + noc_index ); log_debug(LogDevice, "run demux at 
{}", demux_d_core.str()); @@ -1007,7 +1022,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, demux_d_physical_core, dispatch_physical_core, - std::map {} + std::map {}, + noc_index ); log_debug(LogDevice, "run prefertch_d at {}", prefetch_d_core.str()); @@ -1041,7 +1057,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, prefetch_d_physical_core, mux_d_physical_core, - std::map {} + std::map {}, + noc_index ); log_debug(LogDevice, "run dispatch at {}", dispatch_core.str()); @@ -1100,7 +1117,8 @@ void Device::compile_command_queue_programs() { dispatch_core_type, CoreCoord{0, 0}, CoreCoord{0, 0}, - std::map {{"SKIP_NOC_LOGGING", "1"}} + std::map {{"SKIP_NOC_LOGGING", "1"}}, + noc_index ); log_debug(LogDevice, "run mux at {}", mux_d_core.str()); @@ -1194,7 +1212,7 @@ void Device::initialize_command_queue() { this->sysmem_manager_ = std::make_unique(this->id_, this->num_hw_cqs()); hw_command_queues_.resize(num_hw_cqs()); for (size_t cq_id = 0; cq_id < num_hw_cqs(); cq_id++) { - hw_command_queues_[cq_id] = std::make_unique(this, cq_id); + hw_command_queues_[cq_id] = std::make_unique(this, cq_id, static_cast(cq_id)); // Need to do this since CommandQueue constructor is private sw_command_queues_.push_back(std::unique_ptr(new CommandQueue(this, cq_id))); } @@ -1530,6 +1548,24 @@ std::vector Device::ethernet_cores_from_logical_cores(const std::vect return ethernet_cores; } +uint32_t Device::get_noc_unicast_encoding(uint8_t noc_index, const CoreCoord& physical_core) const { + const auto& grid_size = tt::Cluster::instance().get_soc_desc(this->id()).grid_size; + return NOC_XY_ENCODING( + NOC_0_X(noc_index, grid_size.x, physical_core.x), + NOC_0_Y(noc_index, grid_size.y, physical_core.y) + ); +} + +uint32_t Device::get_noc_multicast_encoding(uint8_t noc_index, const CoreRange& physical_cores) const { + const auto& grid_size = tt::Cluster::instance().get_soc_desc(this->id()).grid_size; + return NOC_MULTICAST_ENCODING( + NOC_0_X(noc_index, grid_size.x, physical_cores.start.x), + NOC_0_Y(noc_index, grid_size.y, physical_cores.start.y), + NOC_0_X(noc_index, grid_size.x, physical_cores.end.x), + NOC_0_Y(noc_index, grid_size.y, physical_cores.end.y) + ); +} + void Device::check_allocator_is_initialized() const { if (this->allocator_ == nullptr) { TT_THROW("No memory allocator! 
Device has not been initialized, did you forget to call InitializeDevice?"); diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index 7b054f030688..12df80a6bee1 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -11,6 +11,7 @@ #include "impl/dispatch/work_executor.hpp" #include "tt_metal/impl/allocator/basic_allocator.hpp" #include "tt_metal/impl/allocator/l1_banking_allocator.hpp" +#include "tt_metal/impl/kernels/data_types.hpp" #include "tt_metal/impl/trace/trace_buffer.hpp" #include "tt_metal/jit_build/build.hpp" #include "llrt/tt_cluster.hpp" @@ -192,6 +193,9 @@ class Device { // core.y represents different channels along one const std::set ðernet_cores() const { return this->ethernet_cores_; } + uint32_t get_noc_unicast_encoding(uint8_t noc_index, const CoreCoord& physical_core) const; + uint32_t get_noc_multicast_encoding(uint8_t noc_index, const CoreRange& physical_cores) const; + void deallocate_buffers(); // machine epsilon @@ -229,7 +233,7 @@ class Device { void initialize_command_queue(); void initialize_synchronous_sw_cmd_queue(); void configure_kernel_variant(Program& program, string path, std::vector compile_args, CoreCoord kernel_core, CoreCoord Kernel_physical_core, - CoreType dispatch_core_type, CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, std::map defines_in , bool is_active_eth_core = false); + CoreType dispatch_core_type, CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, std::map defines_in, NOC noc_index, bool is_active_eth_core = false); void compile_command_queue_programs(); void configure_command_queue_programs(); void clear_l1_state(); diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 5df863d7b3bc..59cf23af4f46 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -43,16 +43,12 @@ namespace tt::tt_metal { thread_local std::unordered_map EnqueueProgramCommand::cached_program_command_sequences = {}; -uint32_t get_noc_unicast_encoding(const CoreCoord& coord) { return NOC_XY_ENCODING(NOC_X(coord.x), NOC_Y(coord.y)); } -uint32_t get_noc_multicast_encoding(const CoreCoord& start, const CoreCoord& end) { - return NOC_MULTICAST_ENCODING(start.x, start.y, end.x, end.y); -} - // EnqueueReadBufferCommandSection EnqueueReadBufferCommand::EnqueueReadBufferCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, Buffer& buffer, void* dst, SystemMemoryManager& manager, @@ -60,6 +56,7 @@ EnqueueReadBufferCommand::EnqueueReadBufferCommand( uint32_t src_page_index, std::optional pages_to_read) : command_queue_id(command_queue_id), + noc_index(noc_index), dst(dst), manager(manager), buffer(buffer), @@ -89,7 +86,7 @@ void EnqueueReadShardedBufferCommand::add_prefetch_relay(HugepageDeviceCommand& const CoreCoord physical_core = this->buffer.device()->physical_core_from_logical_core(this->core, this->buffer.core_type()); command.add_prefetch_relay_linear( - get_noc_unicast_encoding(physical_core), padded_page_size * this->pages_to_read, this->bank_base_address); + this->device->get_noc_unicast_encoding(this->noc_index, physical_core), padded_page_size * this->pages_to_read, this->bank_base_address); } void EnqueueReadBufferCommand::process() { @@ -125,6 +122,7 @@ void EnqueueReadBufferCommand::process() { EnqueueWriteBufferCommand::EnqueueWriteBufferCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, const Buffer& buffer, const void* src, 
SystemMemoryManager& manager, @@ -135,6 +133,7 @@ EnqueueWriteBufferCommand::EnqueueWriteBufferCommand( uint32_t dst_page_index, std::optional pages_to_write) : command_queue_id(command_queue_id), + noc_index(noc_index), manager(manager), issue_wait(issue_wait), src(src), @@ -211,7 +210,7 @@ void EnqueueWriteShardedBufferCommand::add_dispatch_write(HugepageDeviceCommand& this->buffer.device()->physical_core_from_logical_core(this->core, this->buffer.core_type()); bool flush_prefetch = true; command_sequence.add_dispatch_write_linear( - flush_prefetch, 0, get_noc_unicast_encoding(physical_core), this->bank_base_address, data_size_bytes); + flush_prefetch, 0, this->device->get_noc_unicast_encoding(this->noc_index, physical_core), this->bank_base_address, data_size_bytes); } void EnqueueWriteShardedBufferCommand::add_buffer_data(HugepageDeviceCommand& command_sequence) { @@ -287,10 +286,12 @@ void EnqueueWriteBufferCommand::process() { EnqueueProgramCommand::EnqueueProgramCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, Program& program, SystemMemoryManager& manager, uint32_t expected_num_workers_completed) : command_queue_id(command_queue_id), + noc_index(noc_index), manager(manager), expected_num_workers_completed(expected_num_workers_completed), program(program) { @@ -462,13 +463,12 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { // can make a vector of unicast encodings here CoreCoord physical_core = device->physical_core_from_logical_core(core_coord, kernel->get_kernel_core_type()); - uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core); const auto& runtime_args_data = kernel->runtime_args(core_coord); unique_rt_args_data[processor_idx].emplace_back(kernel->runtime_args_data(core_coord)); // 2, 17, could be differnet len here unique_sub_cmds[processor_idx].emplace_back( - CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding}); + CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, physical_core)}); unique_rt_data_and_sizes[processor_idx].emplace_back( runtime_args_data.data(), runtime_args_data.size() * sizeof(uint32_t)); unique_max_runtime_args_len[processor_idx] = @@ -496,12 +496,11 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { for (auto& core_coord : kernel->logical_cores()) { // can make a vector of unicast encodings here CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord); - uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core); unicast_sub_cmd.emplace_back( - CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding}); + CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, physical_core)}); } } else { - vector> dst_noc_multicast_info = + vector> dst_noc_multicast_info = extract_dst_noc_multicast_info>( device, kernel->logical_coreranges(), kernel->get_kernel_core_type()); common_sub_cmds[kernel_id].emplace>( @@ -511,7 +510,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { multicast_sub_cmd.reserve(dst_noc_multicast_info.size()); for (const auto& mcast_dests : dst_noc_multicast_info) { multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second}); + .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, std::get(mcast_dests.first)), .num_mcast_dests = mcast_dests.second}); } } } @@ -634,7 +633,6 @@ void 
EnqueueProgramCommand::assemble_device_commands() { for (const CoreRange& core_range : circular_buffers_unique_coreranges) { const CoreCoord physical_start = device->worker_core_from_logical_core(core_range.start); const CoreCoord physical_end = device->worker_core_from_logical_core(core_range.end); - const uint32_t dst_noc_multicast_encoding = get_noc_multicast_encoding(physical_start, physical_end); const uint32_t num_receivers = core_range.size(); auto& cb_config_payload = cb_config_payloads[i]; @@ -659,7 +657,7 @@ void EnqueueProgramCommand::assemble_device_commands() { } } multicast_cb_config_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = dst_noc_multicast_encoding, .num_mcast_dests = (uint32_t)core_range.size()}); + .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, CoreRange(physical_start, physical_end)), .num_mcast_dests = (uint32_t)core_range.size()}); multicast_cb_config_data.emplace_back( cb_config_payload.data(), (max_base_index + UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG) * sizeof(uint32_t)); @@ -683,7 +681,7 @@ void EnqueueProgramCommand::assemble_device_commands() { for (int buffer_idx = 0; buffer_idx < program.program_transfer_info.kernel_bins.size(); buffer_idx++) { const auto& kg_transfer_info = program.program_transfer_info.kernel_bins[buffer_idx]; for (int kernel_idx = 0; kernel_idx < kg_transfer_info.dst_base_addrs.size(); kernel_idx++) { - for (const pair& dst_noc_info : kg_transfer_info.dst_noc_info) { + for (const pair& dst_noc_info : kg_transfer_info.dst_noc_info) { cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE; cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE; } @@ -709,9 +707,8 @@ void EnqueueProgramCommand::assemble_device_commands() { CoreCoord physical_end = device->physical_core_from_logical_core(core_range.end, kernel_group.get_core_type()); - uint32_t dst_noc_multicast_encoding = get_noc_multicast_encoding(physical_start, physical_end); multicast_go_signal_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = dst_noc_multicast_encoding, .num_mcast_dests = (uint32_t)core_range.size()}); + .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, CoreRange(physical_start, physical_end)), .num_mcast_dests = (uint32_t)core_range.size()}); multicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB); } } @@ -733,9 +730,8 @@ void EnqueueProgramCommand::assemble_device_commands() { for (auto y = core_range.start.y; y <= core_range.end.y; y++) { CoreCoord physical_coord = device->physical_core_from_logical_core(CoreCoord({x, y}), kernel_group.get_core_type()); - uint32_t dst_noc_unicast_encoding = get_noc_unicast_encoding(physical_coord); unicast_go_signal_sub_cmds.emplace_back( - CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = dst_noc_unicast_encoding}); + CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, physical_coord)}); unicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB); } } @@ -768,7 +764,7 @@ void EnqueueProgramCommand::assemble_device_commands() { for (const auto& dst_noc_info : transfer_info.dst_noc_info) { num_packed_cmds += 1; multicast_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = dst_noc_info.first, .num_mcast_dests = dst_noc_info.second}); + .noc_xy_addr =this->device->get_noc_multicast_encoding(this->noc_index, std::get(dst_noc_info.first)), .num_mcast_dests = dst_noc_info.second}); 
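The encodings built above can no longer be baked in when the program is first assembled, because each hardware command queue now dispatches over its own NOC and the two NOCs address the same physical core with mirrored coordinates. A minimal sketch of that translation, assuming the usual grid_size - 1 - coord mirroring convention behind NOC_0_X/NOC_0_Y (the real macros come from noc_parameters.h, and the packed bit layout below is purely illustrative):

def translate_to_noc(noc_index, grid_size, core):
    # NOC 0 uses physical coordinates as-is; NOC 1 sees the grid mirrored in both axes.
    x = core[0] if noc_index == 0 else grid_size[0] - 1 - core[0]
    y = core[1] if noc_index == 0 else grid_size[1] - 1 - core[1]
    return x, y

def unicast_encoding(noc_index, grid_size, core, x_bits=6):
    # Illustrative packing only: the real NOC_XY_ENCODING layout is hardware-defined.
    x, y = translate_to_noc(noc_index, grid_size, core)
    return (y << x_bits) | x

# Example: on an 8x10 worker grid, physical core (2, 3) encodes differently per NOC.
print(unicast_encoding(0, (8, 10), (2, 3)))  # NOC 0 keeps (2, 3)
print(unicast_encoding(1, (8, 10), (2, 3)))  # NOC 1 sees the same core as (5, 6)

This is why the sub-commands above call get_noc_unicast_encoding / get_noc_multicast_encoding with this->noc_index at assembly time instead of reusing a cached encoding.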
sem_data.emplace_back(transfer_info.data.data(), transfer_info.data.size() * sizeof(uint32_t)); } } @@ -796,7 +792,7 @@ void EnqueueProgramCommand::assemble_device_commands() { for (const auto& dst_noc_info : transfer_info.dst_noc_info) { num_packed_cmds += 1; unicast_sub_cmds.emplace_back( - CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = dst_noc_info.first}); + CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr =this->device->get_noc_unicast_encoding(this->noc_index, std::get(dst_noc_info.first))}); sem_data.emplace_back(transfer_info.data.data(), transfer_info.data.size() * sizeof(uint32_t)); } } @@ -828,11 +824,22 @@ void EnqueueProgramCommand::assemble_device_commands() { for (int buffer_idx = 0; buffer_idx < program.program_transfer_info.kernel_bins.size(); buffer_idx++) { const auto& kg_transfer_info = program.program_transfer_info.kernel_bins[buffer_idx]; for (int kernel_idx = 0; kernel_idx < kg_transfer_info.dst_base_addrs.size(); kernel_idx++) { - for (const pair& dst_noc_info : kg_transfer_info.dst_noc_info) { + for (const pair& dst_noc_info : kg_transfer_info.dst_noc_info) { + uint32_t noc_encoding; + std::visit( + [&](auto&& cores) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + noc_encoding = this->device->get_noc_multicast_encoding(this->noc_index, cores); + } else { + noc_encoding = this->device->get_noc_unicast_encoding(this->noc_index, cores); + } + }, + dst_noc_info.first); program_command_sequence.add_dispatch_write_linear( false, // flush_prefetch dst_noc_info.second, // num_mcast_dests - dst_noc_info.first, // noc_xy_addr + noc_encoding, // noc_xy_addr kg_transfer_info.dst_base_addrs[kernel_idx], align(kg_transfer_info.lengths[kernel_idx], NOC_DRAM_ALIGNMENT_BYTES)); // Difference between prefetch total relayed pages and dispatch write linear @@ -1026,12 +1033,14 @@ void EnqueueProgramCommand::process() { EnqueueRecordEventCommand::EnqueueRecordEventCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, SystemMemoryManager& manager, uint32_t event_id, uint32_t expected_num_workers_completed, bool clear_count) : command_queue_id(command_queue_id), device(device), + noc_index(noc_index), manager(manager), event_id(event_id), expected_num_workers_completed(expected_num_workers_completed), @@ -1080,7 +1089,7 @@ void EnqueueRecordEventCommand::process() { CoreCoord dispatch_physical_core = get_physical_core_coordinate(dispatch_location, core_type); unicast_sub_cmds[cq_id] = - CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = get_noc_unicast_encoding(dispatch_physical_core)}; + CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, dispatch_physical_core)}; event_payloads[cq_id] = {event_payload.data(), event_payload.size() * sizeof(uint32_t)}; } @@ -1209,11 +1218,12 @@ void EnqueueTerminateCommand::process() { } // HWCommandQueue section -HWCommandQueue::HWCommandQueue(Device* device, uint32_t id) : +HWCommandQueue::HWCommandQueue(Device* device, uint32_t id, NOC noc_index) : manager(device->sysmem_manager()), completion_queue_thread{} { ZoneScopedN("CommandQueue_constructor"); this->device = device; this->id = id; + this->noc_index = noc_index; this->num_entries_in_completion_q = 0; this->num_completed_completion_q_reads = 0; @@ -1340,6 +1350,7 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin auto command = EnqueueReadShardedBufferCommand( this->id, this->device, + this->noc_index, buffer, dst, this->manager, @@ -1376,6 +1387,7 @@ void 
HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin auto command = EnqueueReadInterleavedBufferCommand( this->id, this->device, + this->noc_index, buffer, dst, this->manager, @@ -1514,6 +1526,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src, auto command = EnqueueWriteShardedBufferCommand( this->id, this->device, + this->noc_index, buffer, src, this->manager, @@ -1605,6 +1618,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src, auto command = EnqueueWriteInterleavedBufferCommand( this->id, this->device, + this->noc_index, buffer, src, this->manager, @@ -1646,7 +1660,7 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) { // Snapshot of expected workers from previous programs, used for dispatch_wait cmd generation. uint32_t expected_workers_completed = this->manager.get_bypass_mode() ? this->trace_ctx->num_completion_worker_cores : this->expected_num_workers_completed; - auto command = EnqueueProgramCommand(this->id, this->device, program, this->manager, expected_workers_completed); + auto command = EnqueueProgramCommand(this->id, this->device, this->noc_index, program, this->manager, expected_workers_completed); this->enqueue_command(command, blocking); log_trace( @@ -1677,7 +1691,7 @@ void HWCommandQueue::enqueue_record_event(std::shared_ptr event, bool cle event->ready = true; // what does this mean??? auto command = EnqueueRecordEventCommand( - this->id, this->device, this->manager, event->event_id, this->expected_num_workers_completed, clear_count); + this->id, this->device, this->noc_index, this->manager, event->event_id, this->expected_num_workers_completed, clear_count); this->enqueue_command(command, false); if (clear_count) { diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp index 578724880f00..9809824eab57 100644 --- a/tt_metal/impl/dispatch/command_queue.hpp +++ b/tt_metal/impl/dispatch/command_queue.hpp @@ -55,9 +55,6 @@ string EnqueueCommandTypeToString(EnqueueCommandType ctype); #define NOC_X(x) x #define NOC_Y(y) y -uint32_t get_noc_unicast_encoding(const CoreCoord& coord); -uint32_t get_noc_multicast_encoding(const CoreCoord& start, const CoreCoord& end); - class CommandQueue; class CommandInterface; @@ -74,13 +71,14 @@ class EnqueueReadBufferCommand : public Command { private: SystemMemoryManager& manager; void* dst; - uint32_t command_queue_id; CoreType dispatch_core_type; virtual void add_prefetch_relay(HugepageDeviceCommand& command) = 0; protected: Device* device; + uint32_t command_queue_id; + NOC noc_index; uint32_t expected_num_workers_completed; uint32_t src_page_index; uint32_t pages_to_read; @@ -90,6 +88,7 @@ class EnqueueReadBufferCommand : public Command { EnqueueReadBufferCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, Buffer& buffer, void* dst, SystemMemoryManager& manager, @@ -112,6 +111,7 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand { EnqueueReadInterleavedBufferCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, Buffer& buffer, void* dst, SystemMemoryManager& manager, @@ -121,6 +121,7 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand { EnqueueReadBufferCommand( command_queue_id, device, + noc_index, buffer, dst, manager, @@ -139,6 +140,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand { EnqueueReadShardedBufferCommand( uint32_t command_queue_id, Device* device, + NOC 
noc_index, Buffer& buffer, void* dst, SystemMemoryManager& manager, @@ -150,6 +152,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand { EnqueueReadBufferCommand( command_queue_id, device, + noc_index, buffer, dst, manager, @@ -165,7 +168,6 @@ class EnqueueWriteInterleavedBufferCommand; class EnqueueWriteBufferCommand : public Command { private: SystemMemoryManager& manager; - uint32_t command_queue_id; CoreType dispatch_core_type; virtual void add_dispatch_write(HugepageDeviceCommand& command) = 0; @@ -173,6 +175,8 @@ class EnqueueWriteBufferCommand : public Command { protected: Device* device; + uint32_t command_queue_id; + NOC noc_index; const void* src; const Buffer& buffer; uint32_t expected_num_workers_completed; @@ -186,6 +190,7 @@ class EnqueueWriteBufferCommand : public Command { EnqueueWriteBufferCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, const Buffer& buffer, const void* src, SystemMemoryManager& manager, @@ -212,6 +217,7 @@ class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand { EnqueueWriteInterleavedBufferCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, const Buffer& buffer, const void* src, SystemMemoryManager& manager, @@ -224,6 +230,7 @@ class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand { EnqueueWriteBufferCommand( command_queue_id, device, + noc_index, buffer, src, manager, @@ -249,6 +256,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand { EnqueueWriteShardedBufferCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, const Buffer& buffer, const void* src, SystemMemoryManager& manager, @@ -263,6 +271,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand { EnqueueWriteBufferCommand( command_queue_id, device, + noc_index, buffer, src, manager, @@ -282,6 +291,7 @@ class EnqueueProgramCommand : public Command { private: uint32_t command_queue_id; Device* device; + NOC noc_index; Program& program; SystemMemoryManager& manager; CoreType dispatch_core_type; @@ -302,6 +312,7 @@ class EnqueueProgramCommand : public Command { EnqueueProgramCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, Program& program, SystemMemoryManager& manager, uint32_t expected_num_workers_completed); @@ -321,6 +332,7 @@ class EnqueueRecordEventCommand : public Command { private: uint32_t command_queue_id; Device* device; + NOC noc_index; SystemMemoryManager& manager; uint32_t event_id; uint32_t expected_num_workers_completed; @@ -330,6 +342,7 @@ class EnqueueRecordEventCommand : public Command { EnqueueRecordEventCommand( uint32_t command_queue_id, Device* device, + NOC noc_index, SystemMemoryManager& manager, uint32_t event_id, uint32_t expected_num_workers_completed, @@ -474,11 +487,12 @@ struct RuntimeArgsMetadata { class HWCommandQueue { public: - HWCommandQueue(Device* device, uint32_t id); + HWCommandQueue(Device* device, uint32_t id, NOC noc_index); ~HWCommandQueue(); CoreCoord completion_queue_writer_core; + NOC noc_index; volatile bool is_dprint_server_hung(); volatile bool is_noc_hung(); diff --git a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp index a506c16df3ea..75b525d0a91b 100644 --- a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp @@ -43,7 +43,7 @@ constexpr uint32_t is_h_variant = get_compile_time_arg_val(16); constexpr uint32_t upstream_noc_xy = 
uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y)); constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y)); constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y)); -constexpr uint32_t pcie_noc_xy_encoding = uint32_t(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y)); +constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_ENCODING(NOC_0_X(static_cast(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast(NOC_INDEX), noc_size_y, PCIE_NOC_Y))); constexpr uint32_t dispatch_cb_page_size = 1 << dispatch_cb_log_page_size; constexpr uint32_t completion_queue_end_addr = completion_queue_base_addr + completion_queue_size; @@ -141,7 +141,7 @@ void completion_queue_reserve_back(uint32_t num_pages) { FORCE_INLINE void notify_host_of_completion_queue_write_pointer() { uint64_t completion_queue_write_ptr_addr = command_queue_base_addr + HOST_CQ_COMPLETION_WRITE_PTR; - uint64_t pcie_address = get_noc_addr_helper(pcie_noc_xy_encoding, completion_queue_write_ptr_addr); // For now, we are writing to host hugepages at offset + uint64_t pcie_address = get_noc_addr_helper(pcie_noc_xy, completion_queue_write_ptr_addr); // For now, we are writing to host hugepages at offset uint32_t completion_wr_ptr_and_toggle = cq_write_interface.completion_fifo_wr_ptr | (cq_write_interface.completion_fifo_wr_toggle << 31); volatile tt_l1_ptr uint32_t* completion_wr_ptr_addr = get_cq_completion_write_ptr(); completion_wr_ptr_addr[0] = completion_wr_ptr_and_toggle; @@ -208,7 +208,7 @@ void process_write_host_h() { uint32_t npages = (xfer_size + completion_queue_page_size - 1) / completion_queue_page_size; completion_queue_reserve_back(npages); uint32_t completion_queue_write_addr = cq_write_interface.completion_fifo_wr_ptr << 4; - uint64_t host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy_encoding, completion_queue_write_addr); + uint64_t host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy, completion_queue_write_addr); // completion_queue_write_addr will never be equal to completion_queue_end_addr due to completion_queue_push_back // wrap logic so we don't need to handle this case explicitly to avoid 0 sized transactions if (completion_queue_write_addr + xfer_size > completion_queue_end_addr) { @@ -218,7 +218,7 @@ void process_write_host_h() { data_ptr += last_chunk_size; length -= last_chunk_size; xfer_size -= last_chunk_size; - host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy_encoding, completion_queue_write_addr); + host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy, completion_queue_write_addr); block_noc_writes_to_clear[rd_block_idx]+=(last_chunk_size + NOC_MAX_BURST_SIZE - 1) / NOC_MAX_BURST_SIZE; // XXXXX maybe just write the noc internal api counter } noc_async_write(data_ptr, host_completion_queue_write_addr, xfer_size); @@ -783,7 +783,6 @@ static inline bool process_cmd_d(uint32_t& cmd_ptr) { DPRINT << "cmd_wait" << ENDL(); process_wait(); break; - case CQ_DISPATCH_CMD_GO: DPRINT << "cmd_go" << ENDL(); break; diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp index f990132a60c8..0ee658ad1c2b 100644 --- a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp @@ -52,6 +52,7 @@ constexpr uint32_t is_h_variant = get_compile_time_arg_val(22); constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y)); constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, 
UPSTREAM_NOC_Y)); constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y)); +constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_ENCODING(NOC_0_X(static_cast(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast(NOC_INDEX), noc_size_y, PCIE_NOC_Y))); constexpr uint32_t downstream_cb_page_size = 1 << downstream_cb_log_page_size; constexpr uint32_t downstream_cb_end = downstream_cb_base + (1 << downstream_cb_log_page_size) * downstream_cb_pages; constexpr uint32_t prefetch_q_end = prefetch_q_base + prefetch_q_size; @@ -146,7 +147,7 @@ void read_from_pcie(volatile tt_l1_ptr prefetch_q_entry_type *& prefetch_q_rd_pt pcie_read_ptr = pcie_base; } - uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y), pcie_read_ptr); + uint64_t host_src_addr = get_noc_addr_helper(pcie_noc_xy, pcie_read_ptr); DPRINT << "read_from_pcie: " << fence + preamble_size << " " << pcie_read_ptr << ENDL(); noc_async_read(host_src_addr, fence + preamble_size, size); pending_read_size = size + preamble_size; diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp index f77c26d9f330..036316ee43a0 100644 --- a/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp +++ b/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp @@ -64,7 +64,7 @@ void read_from_pcie(volatile tt_l1_ptr uint16_t *& prefetch_q_rd_ptr, pcie_read_ptr = pcie_base; } - uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y), pcie_read_ptr); + uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(NOC_X(PCIE_NOC_X), NOC_Y(PCIE_NOC_Y)), pcie_read_ptr); noc_async_read(host_src_addr, fence + preamble_size, size); pending_read_size = size + preamble_size; pcie_read_ptr += size; diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index 1edcca12168f..a507e2e2337f 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -590,16 +590,14 @@ void Program::populate_dispatch_data(Device *device) { {RISCV::ERISC, eth_l1_mem::address_map::FIRMWARE_BASE}}; auto extract_dst_noc_unicast_info = - [&device](const set &ranges, const CoreType core_type) -> vector> { + [&device](const set &ranges, const CoreType core_type) -> vector> { // This API extracts all the pairs of noc multicast encodings given a set of core ranges - vector> dst_noc_unicast_info; + vector> dst_noc_unicast_info; for (const CoreRange &core_range : ranges) { for (auto x = core_range.start.x; x <= core_range.end.x; x++) { for (auto y = core_range.start.y; y <= core_range.end.y; y++) { CoreCoord physical_coord = device->physical_core_from_logical_core(CoreCoord({x, y}), core_type); - uint32_t dst_noc_unicast_encoding = - NOC_XY_ENCODING(NOC_X(physical_coord.x), NOC_Y(physical_coord.y)); - dst_noc_unicast_info.push_back(std::make_pair(dst_noc_unicast_encoding, /*num_mcast_dests=*/0)); + dst_noc_unicast_info.push_back(std::make_pair(physical_coord, /*num_mcast_dests=*/0)); } } } @@ -613,7 +611,7 @@ void Program::populate_dispatch_data(Device *device) { // TODO: use semaphore.core_type from main if (semaphore.core_type() == CoreType::WORKER) { - vector> dst_noc_multicast_info = + vector> dst_noc_multicast_info = extract_dst_noc_multicast_info>( device, semaphore.core_range_set().ranges(), semaphore.core_type()); transfer_info_2 transfer_info = { @@ -623,7 +621,7 @@ void Program::populate_dispatch_data(Device *device) { .data = semaphore_data}; 
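Both extraction helpers in populate_dispatch_data now record physical core destinations rather than pre-computed NOC encodings, so the cached transfer info stays valid no matter which NOC the dispatching command queue ends up using. A rough Python restatement of what the two helpers collect, assuming a hypothetical to_physical() stand-in for device->physical_core_from_logical_core and ranges given as ((x0, y0), (x1, y1)) inclusive pairs:

def extract_multicast_info(ranges, to_physical):
    # One entry per core range: ((physical start, physical end), number of receiving cores).
    info = []
    for (x0, y0), (x1, y1) in ranges:
        num_receivers = (x1 - x0 + 1) * (y1 - y0 + 1)
        info.append(((to_physical((x0, y0)), to_physical((x1, y1))), num_receivers))
    return info

def extract_unicast_info(ranges, to_physical):
    # One entry per individual core: (physical coordinate, 0), since nothing is multicast.
    info = []
    for (x0, y0), (x1, y1) in ranges:
        for x in range(x0, x1 + 1):
            for y in range(y0, y1 + 1):
                info.append((to_physical((x, y)), 0))
    return info

# Example with an identity logical-to-physical mapping over a 2x2 range:
print(extract_unicast_info([((0, 0), (1, 1))], lambda c: c))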
this->program_transfer_info.multicast_semaphores[semaphore.address()].push_back(transfer_info); } else if (semaphore.core_type() == CoreType::ETH) { - vector> dst_noc_unicast_info = + vector> dst_noc_unicast_info = extract_dst_noc_unicast_info(semaphore.core_range_set().ranges(), semaphore.core_type()); transfer_info_2 transfer_info = { .dst_base_addr = semaphore.address(), @@ -640,7 +638,7 @@ void Program::populate_dispatch_data(Device *device) { // Program Binaries and Go Signals // TODO: cleanup put the WORKERS and ETH logic together.. for (KernelGroup &kernel_group : this->get_kernel_groups(CoreType::WORKER)) { - vector> dst_noc_multicast_info = extract_dst_noc_multicast_info>( + vector> dst_noc_multicast_info = extract_dst_noc_multicast_info>( device, kernel_group.core_ranges.ranges(), kernel_group.get_core_type()); // So far, we don't support linking optimizations for kernel groups @@ -710,7 +708,7 @@ void Program::populate_dispatch_data(Device *device) { } } for (KernelGroup &kernel_group : this->get_kernel_groups(CoreType::ETH)) { - vector> dst_noc_unicast_info = + vector> dst_noc_unicast_info = extract_dst_noc_unicast_info(kernel_group.core_ranges.ranges(), kernel_group.get_core_type()); vector kernel_ids; diff --git a/tt_metal/impl/program/program.hpp b/tt_metal/impl/program/program.hpp index 868a9c711e18..10e33f55591c 100644 --- a/tt_metal/impl/program/program.hpp +++ b/tt_metal/impl/program/program.hpp @@ -54,19 +54,16 @@ struct KernelGroup { }; template -vector> extract_dst_noc_multicast_info(Device* device, const CoreRangeContainer& ranges, const CoreType core_type) { +vector> extract_dst_noc_multicast_info(Device* device, const CoreRangeContainer& ranges, const CoreType core_type) { // This API extracts all the pairs of noc multicast encodings given a set of core ranges - vector> dst_noc_multicast_info; + vector> dst_noc_multicast_info; dst_noc_multicast_info.reserve(ranges.size()); for (const CoreRange& core_range : ranges) { CoreCoord physical_start = device->physical_core_from_logical_core(core_range.start, core_type); CoreCoord physical_end = device->physical_core_from_logical_core(core_range.end, core_type); - uint32_t dst_noc_multicast_encoding = - NOC_MULTICAST_ENCODING(physical_start.x, physical_start.y, physical_end.x, physical_end.y); - uint32_t num_receivers = core_range.size(); - dst_noc_multicast_info.push_back(std::make_pair(dst_noc_multicast_encoding, num_receivers)); + dst_noc_multicast_info.push_back(std::make_pair(CoreRange(physical_start, physical_end), num_receivers)); } return dst_noc_multicast_info; } diff --git a/tt_metal/impl/program/program_device_map.hpp b/tt_metal/impl/program/program_device_map.hpp index e5c6d5cfd5a5..dc648887b133 100644 --- a/tt_metal/impl/program/program_device_map.hpp +++ b/tt_metal/impl/program/program_device_map.hpp @@ -16,9 +16,11 @@ struct transfer_info { bool linked; }; +using transfer_info_cores = std::variant; + struct transfer_info_2 { std::uint32_t dst_base_addr; - vector> dst_noc_info; // noc_encoding, num_mcast_dests + vector> dst_noc_info; // noc_encoding, num_mcast_dests bool linked; vector data; }; @@ -26,7 +28,7 @@ struct kernel_bins_transfer_info { vector dst_base_addrs; // BRISC, NCRISC, TRISC etc.. 
vector page_offsets; // offsets into paged buffer in DRAM vector lengths; // WriteLinear lengths - vector> dst_noc_info; // noc_encoding, num_mcast_dests + vector> dst_noc_info; // noc_encoding, num_mcast_dests bool linked; vector data; // all binaries' data for kernel group }; From b51aafb51a564776c0085a3b58f48d592c27d4d5 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Mon, 3 Jun 2024 13:16:34 +0000 Subject: [PATCH 12/53] #0: Allow reuse of event objects for EnqueueRecordEvent --- tt_metal/impl/dispatch/command_queue.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 59cf23af4f46..8b5ca124ab49 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -2309,9 +2309,6 @@ void EnqueueProgramImpl( } void EnqueueRecordEvent(CommandQueue& cq, std::shared_ptr event) { - TT_ASSERT(event->device == nullptr, "EnqueueRecordEvent expected to be given an uninitialized event"); - TT_ASSERT(event->event_id == -1, "EnqueueRecordEvent expected to be given an uninitialized event"); - TT_ASSERT(event->cq_id == -1, "EnqueueRecordEvent expected to be given an uninitialized event"); detail::DispatchStateCheck(true); cq.run_command(CommandInterface{ From 8cadabdc7188b91b4481cad371a209b5f2707ae7 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Mon, 3 Jun 2024 13:16:47 +0000 Subject: [PATCH 13/53] #8837: Add 2cq implementation of Resnet and add to ci --- .../demos/resnet/tests/test_metal_resnet50.py | 308 ++++++++------- .../resnet/tests/test_perf_accuracy_resnet.py | 1 + models/demos/resnet/tests/test_perf_resnet.py | 355 +++++++++--------- tests/scripts/run_performance.sh | 5 +- .../single_card/nightly/run_gs_only.sh | 2 + .../bert/test_performance.py | 2 +- .../whisper/test_performance.py | 2 +- 7 files changed, 358 insertions(+), 317 deletions(-) diff --git a/models/demos/resnet/tests/test_metal_resnet50.py b/models/demos/resnet/tests/test_metal_resnet50.py index b24297caab86..ad332a641c2b 100644 --- a/models/demos/resnet/tests/test_metal_resnet50.py +++ b/models/demos/resnet/tests/test_metal_resnet50.py @@ -8,7 +8,7 @@ import pytest import tt_lib -from models.utility_functions import is_e75, skip_for_wormhole_b0 +from models.utility_functions import is_e75, skip_for_wormhole_b0, divup from models.demos.resnet.tt.metalResnetBlock50 import ResNet, Bottleneck from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( @@ -117,26 +117,107 @@ } -@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") -@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) -@pytest.mark.parametrize("batch_size", [1, 2, 16, 20], ids=["batch_1", "batch_2", "batch_16", "batch_20"]) -@pytest.mark.parametrize( - "weights_dtype", - [tt_lib.tensor.DataType.BFLOAT16, tt_lib.tensor.DataType.BFLOAT8_B], - ids=["weights_BFLOAT16", "weights_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "activations_dtype", - [tt_lib.tensor.DataType.BFLOAT16, tt_lib.tensor.DataType.BFLOAT8_B], - ids=["activations_BFLOAT16", "activations_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "math_fidelity", - [tt_lib.tensor.MathFidelity.HiFi4, tt_lib.tensor.MathFidelity.HiFi2, tt_lib.tensor.MathFidelity.LoFi], - ids=["HiFi4", "HiFi2", "LoFi"], -) -def test_run_resnet50_inference( - device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +def run_model(device, tt_image, tt_resnet50): + tt_output = 
tt_resnet50(tt_image) + return tt_output.cpu(blocking=True) + + +def run_2cq_model(device, tt_image, tt_resnet50): + input_shape = tt_image.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_image.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_image.shape, tt_image.dtype, tt_image.layout, device, sharded_mem_config_DRAM + ) + op_event = tt_lib.device.CreateEvent() + write_event = tt_lib.device.CreateEvent() + # Initialize the op event so we can write + tt_lib.device.RecordEvent(device, 0, op_event) + + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_image, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) + + # Test overlapping write + outputs = [] + for iter in range(0, 2): + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_image, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + outputs.append(tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=False)) + tt_lib.device.Synchronize(device) + return outputs[1] + + +def run_trace_model(device, tt_image, tt_resnet50): + input_shape = tt_image.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_image.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_image.shape, tt_image.dtype, tt_image.layout, device, sharded_mem_config_DRAM + ) + tt_lib.tensor.write_tensor(tt_image, tt_image_res) + + # Compile + tt_resnet50(tt_image_res) + # Trace + tid = tt_lib.device.BeginTraceCapture(device, 0, 1500000) + tt_output_res = tt_resnet50(tt_image_res) + tt_lib.device.EndTraceCapture(device, 0, tid) + + tt_lib.tensor.write_tensor(tt_image, tt_image_res) + tt_lib.device.ReplayTrace(device, 0, tid, True) + + # Done with the trace, can deallocate the buffers now. 
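The double-buffered hand-off in run_2cq_model above is the core of this patch: command queue 1 only writes the next input once the previous model run has consumed the buffer (op_event), and command queue 0 only starts the model once the fresh input has landed (write_event). A condensed sketch of that ping-pong using the event API added earlier in this series; run_pipelined and its device_input / host_batches / model arguments are illustrative stand-ins for tt_image_res, the host inputs, and tt_resnet50:

import tt_lib

def run_pipelined(device, model, device_input, host_batches):
    # Double-buffered 2-CQ loop: CQ1 writes inputs, CQ0 runs the model.
    op_event = tt_lib.device.CreateEvent()
    write_event = tt_lib.device.CreateEvent()
    tt_lib.device.RecordEvent(device, 0, op_event)         # input buffer starts out free
    outputs = []
    for host_batch in host_batches:
        tt_lib.device.WaitForEvent(device, 1, op_event)    # CQ1 waits until the model has read the input
        tt_lib.tensor.write_tensor(host_batch, device_input, 1)
        tt_lib.device.RecordEvent(device, 1, write_event)  # tell CQ0 a fresh input is on device
        # model() is expected to WaitForEvent(device, 0, write_event) before its first op
        # and RecordEvent(device, 0, op_event) once the input is consumed, as forward() does above.
        outputs.append(model(device_input, write_event, op_event).cpu(blocking=False))
    tt_lib.device.Synchronize(device)
    return outputs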
+ tt_lib.device.ReleaseTrace(device, tid) + + return tt_output_res.cpu(blocking=True) + + +def run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_fn, ): if is_e75(device): pytest.skip("Resnet50 is not supported on E75") @@ -159,8 +240,6 @@ def test_run_resnet50_inference( with torch.no_grad(): torch.manual_seed(1234) - tt_lib.device.EnableMemoryReports() - torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) torch_resnet50.eval() @@ -185,17 +264,8 @@ def test_run_resnet50_inference( torch_output = torch_resnet50(image).unsqueeze(1).unsqueeze(1) tt_image = tt_resnet50.preprocessing(image) - tt_output = tt_resnet50(tt_image) - tt_output = tt_output.cpu().to_torch().to(torch.float) - - # # run again to measure end to end perf - # start_time = datetime.now() - # tt_output = tt_resnet50(image) - # end_time = datetime.now() - # diff = end_time - start_time - # logger.info("End to end time (microseconds))", diff.microseconds) - # throughput_fps = (float) (1000000 / diff.microseconds) - # logger.info("Throughput (fps)", throughput_fps) + tt_output = run_fn(device, tt_image, tt_resnet50) + tt_output = tt_output.to_torch().to(torch.float) _, _, _, info = get_atol_rtol_pcc(torch_output, tt_output) logger.info(info) @@ -239,6 +309,72 @@ def test_run_resnet50_inference( [tt_lib.tensor.MathFidelity.HiFi4, tt_lib.tensor.MathFidelity.HiFi2, tt_lib.tensor.MathFidelity.LoFi], ids=["HiFi4", "HiFi2", "LoFi"], ) +def test_run_resnet50_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def test_run_resnet50_2cqs_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_2cq_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) @pytest.mark.parametrize("enable_async", [True, False]) def test_run_resnet50_trace_inference( device, @@ -250,101 +386,17 @@ def 
test_run_resnet50_trace_inference( imagenet_sample_input, enable_async, ): - if is_e75(device): - pytest.skip("Resnet50 is not supported on E75") device.enable_async(enable_async) - if batch_size > 8 and ( - activations_dtype != tt_lib.tensor.DataType.BFLOAT8_B or weights_dtype != tt_lib.tensor.DataType.BFLOAT8_B - ): - pytest.skip("Batch > 8 must be run fully bfp8") - if batch_size <= 2: - pytest.skip("batch 1 and 2 are not supported with sharded data") - image1 = imagenet_sample_input - image = image1 - model_config = { - "MATH_FIDELITY": math_fidelity, - "WEIGHTS_DTYPE": weights_dtype, - "ACTIVATIONS_DTYPE": activations_dtype, - } - for i in range(batch_size - 1): - image = torch.cat((image, image1), dim=0) - with torch.no_grad(): - torch.manual_seed(1234) - - tt_lib.device.EnableMemoryReports() - - torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) - torch_resnet50.eval() - - state_dict = torch_resnet50.state_dict() - storage_in_dram = False - sharded = False - if batch_size >= 8: - sharded = True - # run once to compile ops - tt_resnet50 = ResNet( - Bottleneck, - [3, 4, 6, 3], - device=device, - state_dict=state_dict, - base_address="", - fold_batchnorm=True, - storage_in_dram=storage_in_dram, - batch_size=batch_size, - model_config=model_config, - sharded=sharded, - ) - - torch_output = torch_resnet50(image).unsqueeze(1).unsqueeze(1) - interleaved_mem_config_DRAM = tt_lib.tensor.MemoryConfig( - memory_layout=tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=tt_lib.tensor.BufferType.DRAM, - ) - - tt_image_res = tt_resnet50.preprocessing(image).to(device, interleaved_mem_config_DRAM) - # Compile - tt_resnet50(tt_image_res) - # Trace - tid = tt_lib.device.BeginTraceCapture(device, 0, 1334880) - tt_output_res = tt_resnet50(tt_image_res) - tt_lib.device.EndTraceCapture(device, 0, tid) + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_trace_model, + ) - tt_lib.device.ReplayTrace(device, 0, tid, True) - - tt_output = tt_output_res.cpu().to_torch().to(torch.float) - - # # run again to measure end to end perf - # start_time = datetime.now() - # tt_output = tt_resnet50(image) - # end_time = datetime.now() - # diff = end_time - start_time - # logger.info("End to end time (microseconds))", diff.microseconds) - # throughput_fps = (float) (1000000 / diff.microseconds) - # logger.info("Throughput (fps)", throughput_fps) - - _, _, _, info = get_atol_rtol_pcc(torch_output, tt_output) - logger.info(info) - - valid_pcc = 1.0 - if batch_size >= 8: - valid_pcc = golden_pcc[batch_size][ - (model_config["MATH_FIDELITY"], model_config["WEIGHTS_DTYPE"], model_config["ACTIVATIONS_DTYPE"]) - ] - else: - if model_config["ACTIVATIONS_DTYPE"] == tt_lib.tensor.DataType.BFLOAT8_B: - if model_config["MATH_FIDELITY"] == tt_lib.tensor.MathFidelity.LoFi: - valid_pcc = 0.87 - else: - valid_pcc = 0.94 - else: - if model_config["MATH_FIDELITY"] == tt_lib.tensor.MathFidelity.LoFi: - valid_pcc = 0.93 - else: - valid_pcc = 0.982 - passing_pcc, _ = comp_pcc(torch_output, tt_output, pcc=valid_pcc) - assert passing_pcc - # assert passing # fails because of torch.allclose - # Done with the trace, can deallocate the buffers now. 
- tt_lib.device.ReleaseTrace(device, tid) device.enable_async(False) diff --git a/models/demos/resnet/tests/test_perf_accuracy_resnet.py b/models/demos/resnet/tests/test_perf_accuracy_resnet.py index 722000caea57..6c719ebbf5b9 100644 --- a/models/demos/resnet/tests/test_perf_accuracy_resnet.py +++ b/models/demos/resnet/tests/test_perf_accuracy_resnet.py @@ -84,6 +84,7 @@ def run_perf_resnet( tt_output = tt_output.cpu().to_torch().to(torch.float) profiler.end(first_key) del tt_output + return enable_persistent_kernel_cache() diff --git a/models/demos/resnet/tests/test_perf_resnet.py b/models/demos/resnet/tests/test_perf_resnet.py index d572f544a229..94a52dfbec9e 100644 --- a/models/demos/resnet/tests/test_perf_resnet.py +++ b/models/demos/resnet/tests/test_perf_resnet.py @@ -22,12 +22,142 @@ } +def run_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations): + profiler.start("compile") + _ = tt_resnet50(tt_inputs).cpu(blocking=True) + profiler.end("compile") + tt_lib.device.DumpDeviceProfiler(device) + + for iter in range(0, num_warmup_iterations): + _ = tt_resnet50(tt_inputs).cpu(blocking=True) + tt_lib.device.DumpDeviceProfiler(device) + + outputs = [] + profiler.start(f"run") + for iter in range(0, num_measurement_iterations): + outputs.append(tt_resnet50(tt_inputs).cpu(blocking=False)) + tt_lib.device.Synchronize(device) + profiler.end(f"run") + tt_lib.device.DumpDeviceProfiler(device) + + +def run_2cq_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations): + input_shape = tt_inputs.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_inputs.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_inputs.shape, tt_inputs.dtype, tt_inputs.layout, device, sharded_mem_config_DRAM + ) + op_event = tt_lib.device.CreateEvent() + write_event = tt_lib.device.CreateEvent() + # Initialize the op event so we can write + tt_lib.device.RecordEvent(device, 0, op_event) + + profiler.start("compile") + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) + profiler.end("compile") + tt_lib.device.DumpDeviceProfiler(device) + + for iter in range(0, num_warmup_iterations): + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) + tt_lib.device.DumpDeviceProfiler(device) + + outputs = [] + profiler.start(f"run") + for iter in range(0, num_measurement_iterations): + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + outputs.append(tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=False)) + tt_lib.device.Synchronize(device) + profiler.end(f"run") + tt_lib.device.DumpDeviceProfiler(device) + + +def run_trace_model(device, tt_inputs, tt_resnet50, 
num_warmup_iterations, num_measurement_iterations): + input_shape = tt_inputs.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_inputs.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_inputs.shape, tt_inputs.dtype, tt_inputs.layout, device, sharded_mem_config_DRAM + ) + # Compile + profiler.start("compile") + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) + tt_resnet50(tt_image_res).cpu(blocking=True) + profiler.end("compile") + tt_lib.device.DumpDeviceProfiler(device) + + # Capture + tid = tt_lib.device.BeginTraceCapture(device, 0, 1500000) + tt_output_res = tt_resnet50(tt_image_res) + tt_lib.device.EndTraceCapture(device, 0, tid) + tt_lib.device.DumpDeviceProfiler(device) + + for iter in range(0, num_warmup_iterations): + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) + tt_lib.device.ReplayTrace(device, 0, tid, False) + _ = tt_output_res.cpu(blocking=True) + tt_lib.device.DumpDeviceProfiler(device) + + outputs = [] + profiler.start(f"run") + for iter in range(0, num_measurement_iterations): + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) + tt_lib.device.ReplayTrace(device, 0, tid, False) + outputs.append(tt_output_res.cpu(blocking=False)) + tt_lib.device.Synchronize(device) + profiler.end(f"run") + tt_lib.device.DumpDeviceProfiler(device) + + def run_perf_resnet( batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, + model_version, ): disable_persistent_kernel_cache() if batch_size <= 2: @@ -67,6 +197,10 @@ def run_perf_resnet( model_config=model_config, sharded=sharded, ) + tt_lib.device.Synchronize(device) + + num_warmup_iterations = 5 + num_measurement_iterations = 15 with torch.no_grad(): profiler.start(cpu_key) @@ -74,69 +208,24 @@ def run_perf_resnet( profiler.end(cpu_key) tt_inputs = tt_resnet50.preprocessing(inputs) - input_shape = tt_inputs.get_legacy_shape() - shard_spec = tt_lib.tensor.ShardSpec( - tt_lib.tensor.CoreRangeSet( - { - tt_lib.tensor.CoreRange( - tt_lib.tensor.CoreCoord(0, 0), - tt_lib.tensor.CoreCoord(7, 0), - ) - } - ), - [ - divup(tt_inputs.volume() // input_shape[3], 8), - input_shape[3], - ], - tt_lib.tensor.ShardOrientation.ROW_MAJOR, - False, - ) - sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( - tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec - ) - tt_image_res = tt_lib.tensor.allocate_tensor_on_device( - tt_inputs.shape, tt_inputs.dtype, tt_inputs.layout, device, sharded_mem_config_DRAM - ) - op_event = tt_lib.device.CreateEvent() - write_event = tt_lib.device.CreateEvent() - # Initialize the op event so we can write - tt_lib.device.RecordEvent(device, 0, op_event) - warmup_end = 5 - for iter in range(0, warmup_end): - profiler.start(f"{iter}_key") - tt_lib.device.WaitForEvent(device, 1, op_event) - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) - tt_lib.device.RecordEvent(device, 1, write_event) - _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) - profiler.end(f"{iter}_key") - tt_lib.device.DumpDeviceProfiler(device) - - num_warm_iterations = 10 - warm_start = warmup_end - warm_end = 
warm_start + num_warm_iterations - - outputs = [] - profiler.start(f"run") - for iter in range(warm_start, warm_end): - tt_lib.device.WaitForEvent(device, 1, op_event) - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) - tt_lib.device.RecordEvent(device, 1, write_event) - outputs.append(tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=False)) - tt_lib.device.Synchronize(device) - profiler.end(f"run") - tt_lib.device.DumpDeviceProfiler(device) + if "resnet50_2cqs" in model_version: + run_2cq_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations) + elif "resnet50_trace" in model_version: + run_trace_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations) + elif "resnet50" in model_version: + run_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations) + else: + assert False, f"Model version to run {model_version} not found" - # enable_persistent_kernel_cache() - - first_iter_time = profiler.get(f"{0}_key") + first_iter_time = profiler.get(f"compile") # ensuring inference time fluctuations is not noise - inference_time_avg = profiler.get("run") / num_warm_iterations + inference_time_avg = profiler.get("run") / num_measurement_iterations cpu_time = profiler.get(cpu_key) compile_time = first_iter_time - inference_time_avg prep_perf_report( - model_name=f"resnet50_batch_size{batch_size}", + model_name=f"{model_version}_batch_size{batch_size}", batch_size=batch_size, inference_and_compile_time=first_iter_time, inference_time=inference_time_avg, @@ -146,20 +235,18 @@ def run_perf_resnet( inference_time_cpu=cpu_time, ) - logger.info(f"resnet50 {comments} inference time (avg): {inference_time_avg}") - logger.info(f"resnet50 compile time: {compile_time}") + logger.info(f"{model_name} {comments} inference time (avg): {inference_time_avg}") + logger.info(f"{model_name} compile time: {compile_time}") @skip_for_wormhole_b0(reason_str="Not tested on single WH") -@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) @pytest.mark.models_performance_bare_metal @pytest.mark.parametrize( "batch_size, expected_inference_time, expected_compile_time", ( - # (1, 0.001, 1), - # (2, 0.001, 1), - # (16, 0.007, 7), - (20, 0.007, 7), + (16, 0.007, 16), + (20, 0.007, 16), ), ) def test_perf_bare_metal( @@ -174,145 +261,39 @@ def test_perf_bare_metal( pytest.skip("Resnet is not supported on E75") run_perf_resnet( - batch_size, - expected_inference_time, - expected_compile_time, - hf_cat_image_sample_input, - device, + batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50" ) -def run_perf_resnet_trace( +@skip_for_wormhole_b0(reason_str="Not tested on single WH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, expected_inference_time, expected_compile_time", + ((20, 0.0055, 16),), +) +def test_perf_2cqs_bare_metal( + device, + use_program_cache, batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, - device, ): - disable_persistent_kernel_cache() - if batch_size <= 2: - pytest.skip("Batch size 1 and 2 are not supported with sharded data") - first_key = f"first_iter_batchsize{batch_size}" - second_key = f"second_iter_batchsize{batch_size}" - cpu_key = 
f"ref_key_batchsize{batch_size}" - model_name = "microsoft/resnet-50" - - image = hf_cat_image_sample_input - image_processor = AutoImageProcessor.from_pretrained(model_name) - inputs = image_processor(image, return_tensors="pt") - - inputs = inputs["pixel_values"] - comments = f"{list(inputs.shape)[-2]}x{list(inputs.shape)[-1]}_batchsize{batch_size}" - - inputs1 = inputs - for i in range(batch_size - 1): - inputs = torch.cat((inputs, inputs1), dim=0) - - torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) - torch_resnet50.eval() - - state_dict = torch_resnet50.state_dict() - sharded = False - if batch_size >= 8: - sharded = True - tt_resnet50 = ResNet( - Bottleneck, - [3, 4, 6, 3], - device=device, - state_dict=state_dict, - base_address="", - fold_batchnorm=True, - storage_in_dram=False, - batch_size=batch_size, - model_config=model_config, - sharded=sharded, - ) - - with torch.no_grad(): - profiler.start(cpu_key) - logits = torch_resnet50(inputs) - profiler.end(cpu_key) - - tt_inputs = tt_resnet50.preprocessing(inputs) - interleaved_mem_config_DRAM = tt_lib.tensor.MemoryConfig( - memory_layout=tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=tt_lib.tensor.BufferType.DRAM, - ) - tt_image_res = tt_inputs.to(device, interleaved_mem_config_DRAM) - # Compile - profiler.start(f"{0}_key") - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) - tt_resnet50(tt_image_res).cpu(blocking=True) - profiler.end(f"{0}_key") - tt_lib.device.DumpDeviceProfiler(device) - - # Capture - tid = tt_lib.device.BeginTraceCapture(device, 0, 1334880) - tt_output_res = tt_resnet50(tt_image_res) - tt_lib.device.EndTraceCapture(device, 0, tid) - tt_lib.device.DumpDeviceProfiler(device) - - warmup_end = 6 - for iter in range(1, warmup_end): - profiler.start(f"{iter}_key") - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) - tt_lib.device.ReplayTrace(device, 0, tid, False) - _ = tt_output_res.cpu(blocking=True) - profiler.end(f"{iter}_key") - tt_lib.device.DumpDeviceProfiler(device) - - num_warm_iterations = 15 - warm_start = warmup_end - warm_end = warm_start + num_warm_iterations - - outputs = [] - profiler.start(f"run") - for iter in range(warm_start, warm_end): - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) - tt_lib.device.ReplayTrace(device, 0, tid, False) - outputs.append(tt_output_res.cpu(blocking=False)) - tt_lib.device.Synchronize(device) - profiler.end(f"run") - tt_lib.device.DumpDeviceProfiler(device) - - # enable_persistent_kernel_cache() - - first_iter_time = profiler.get(f"{0}_key") - - # ensuring inference time fluctuations is not noise - inference_time_avg = profiler.get("run") / num_warm_iterations + if is_e75(device): + pytest.skip("Resnet is not supported on E75") - cpu_time = profiler.get(cpu_key) - compile_time = first_iter_time - inference_time_avg - prep_perf_report( - model_name=f"resnet50_trace_batch_size{batch_size}", - batch_size=batch_size, - inference_and_compile_time=first_iter_time, - inference_time=inference_time_avg, - expected_compile_time=expected_compile_time, - expected_inference_time=expected_inference_time, - comments=comments, - inference_time_cpu=cpu_time, + run_perf_resnet( + batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50_2cqs" ) - logger.info(f"resnet50 {comments} inference time (avg): {inference_time_avg}") - logger.info(f"resnet50 compile time: {compile_time}") - - tt_lib.device.ReleaseTrace(device, tid) - - assert inference_time_avg < expected_inference_time, f"resnet50 
{comments} inference is too slow" - assert compile_time < expected_compile_time, f"resnet50 {comments} compilation is too slow" - @skip_for_wormhole_b0(reason_str="Not tested on single WH") @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) @pytest.mark.models_performance_bare_metal @pytest.mark.parametrize( "batch_size, expected_inference_time, expected_compile_time", - ( - (16, 0.04, 25), - (20, 0.04, 25), - ), + ((20, 0.008, 16),), ) @pytest.mark.parametrize("enable_async", [True, False]) def test_perf_trace_bare_metal( @@ -327,11 +308,13 @@ def test_perf_trace_bare_metal( if is_e75(device): pytest.skip("Resnet is not supported on E75") device.enable_async(enable_async) - run_perf_resnet_trace( + mode = "async" if enable_async else "sync" + run_perf_resnet( batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, + f"resnet50_trace_{mode}", ) device.enable_async(False) diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 23cc2d0d0ba3..cd939e22cd56 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -17,7 +17,10 @@ run_perf_models_other() { env pytest models/demos/ttnn_falcon7b/tests -m $test_marker - env pytest models/demos/resnet/tests -m $test_marker + # Separate calls since we can't mix switching between number of cqs + env pytest models/demos/resnet/tests/test_perf_resnet.py::test_perf_bare_metal -m $test_marker + env pytest models/demos/resnet/tests/test_perf_resnet.py::test_perf_2cqs_bare_metal -m $test_marker + env pytest models/demos/resnet/tests/test_perf_resnet.py::test_perf_trace_bare_metal -m $test_marker env pytest tests/ttnn/integration_tests/whisper/test_performance.py -m $test_marker diff --git a/tests/scripts/single_card/nightly/run_gs_only.sh b/tests/scripts/single_card/nightly/run_gs_only.sh index 9973f35b7bda..f64956aea6b8 100755 --- a/tests/scripts/single_card/nightly/run_gs_only.sh +++ b/tests/scripts/single_card/nightly/run_gs_only.sh @@ -13,4 +13,6 @@ env pytest models/demos/metal_BERT_large_11/tests/test_demo.py env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_inference[LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0] +env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_2cqs_inference[LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0] + env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_trace_inference -k "LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0" diff --git a/tests/ttnn/integration_tests/bert/test_performance.py b/tests/ttnn/integration_tests/bert/test_performance.py index 034df32b53d3..e29b0a44329e 100644 --- a/tests/ttnn/integration_tests/bert/test_performance.py +++ b/tests/ttnn/integration_tests/bert/test_performance.py @@ -59,7 +59,7 @@ def get_expected_times(bert): return { ttnn_bert: (0.1, 0.1), ttnn_optimized_bert: (5.5, 0.07), - ttnn_optimized_sharded_bert: (5.2, 0.07), + ttnn_optimized_sharded_bert: (5.5, 0.07), }[bert] diff --git a/tests/ttnn/integration_tests/whisper/test_performance.py b/tests/ttnn/integration_tests/whisper/test_performance.py index b88669f43d9d..41c559c5ef04 100644 --- a/tests/ttnn/integration_tests/whisper/test_performance.py +++ b/tests/ttnn/integration_tests/whisper/test_performance.py @@ -17,7 +17,7 @@ def get_expected_times(functional_whisper): return { - ttnn_functional_whisper: (10.5, 4.16), + ttnn_functional_whisper: (11, 4.16), 
ttnn_optimized_functional_whisper: (1.2, 1.35), }[functional_whisper] From 626e6de69224deefca68896c06a712d8dbd33dca Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Mon, 3 Jun 2024 17:14:03 +0000 Subject: [PATCH 14/53] #0: Split 2cq tests into separate files to follow convention --- .../test_metal_resnet50_2cqs_performant.py | 42 +++++++++ .../tests/test_metal_resnet50_performant.py | 87 +++++++++++++++++++ models/demos/resnet/tests/test_perf_resnet.py | 27 +----- .../resnet/tests/test_perf_resnet_2cqs.py | 28 ++++++ tests/scripts/run_performance.sh | 5 +- .../single_card/nightly/run_gs_only.sh | 6 +- 6 files changed, 163 insertions(+), 32 deletions(-) create mode 100644 models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py create mode 100644 models/demos/resnet/tests/test_metal_resnet50_performant.py create mode 100644 models/demos/resnet/tests/test_perf_resnet_2cqs.py diff --git a/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py b/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py new file mode 100644 index 000000000000..6bb3147c6d32 --- /dev/null +++ b/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tt_lib + +from models.demos.resnet.tests.test_metal_resnet50 import run_resnet50_inference, run_2cq_model +from models.utility_functions import skip_for_wormhole_b0 + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def test_run_resnet50_2cqs_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_2cq_model, + ) diff --git a/models/demos/resnet/tests/test_metal_resnet50_performant.py b/models/demos/resnet/tests/test_metal_resnet50_performant.py new file mode 100644 index 000000000000..cbd266c568c9 --- /dev/null +++ b/models/demos/resnet/tests/test_metal_resnet50_performant.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tt_lib + +from models.demos.resnet.tests.test_metal_resnet50 import run_resnet50_inference, run_model, run_trace_model +from models.utility_functions import skip_for_wormhole_b0 + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def test_run_resnet50_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +@pytest.mark.parametrize("enable_async", [True, False]) +def test_run_resnet50_trace_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + enable_async, +): + device.enable_async(enable_async) + + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_trace_model, + ) + + device.enable_async(False) diff --git a/models/demos/resnet/tests/test_perf_resnet.py b/models/demos/resnet/tests/test_perf_resnet.py index 94a52dfbec9e..a93c82876c97 100644 --- a/models/demos/resnet/tests/test_perf_resnet.py +++ b/models/demos/resnet/tests/test_perf_resnet.py @@ -159,6 +159,8 @@ def run_perf_resnet( device, model_version, ): + if is_e75(device): + pytest.skip("Resnet is not supported on E75") disable_persistent_kernel_cache() if batch_size <= 2: pytest.skip("Batch size 1 and 2 are not supported with sharded data") @@ -265,29 +267,6 @@ def test_perf_bare_metal( ) -@skip_for_wormhole_b0(reason_str="Not tested on single WH") -@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2}], indirect=True) -@pytest.mark.models_performance_bare_metal -@pytest.mark.parametrize( - "batch_size, expected_inference_time, expected_compile_time", - ((20, 0.0055, 16),), -) -def test_perf_2cqs_bare_metal( - device, - use_program_cache, - batch_size, - expected_inference_time, - expected_compile_time, - hf_cat_image_sample_input, -): - if is_e75(device): - pytest.skip("Resnet is not supported on E75") - - run_perf_resnet( - batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50_2cqs" - ) - - @skip_for_wormhole_b0(reason_str="Not 
tested on single WH") @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) @pytest.mark.models_performance_bare_metal @@ -305,8 +284,6 @@ def test_perf_trace_bare_metal( hf_cat_image_sample_input, enable_async, ): - if is_e75(device): - pytest.skip("Resnet is not supported on E75") device.enable_async(enable_async) mode = "async" if enable_async else "sync" run_perf_resnet( diff --git a/models/demos/resnet/tests/test_perf_resnet_2cqs.py b/models/demos/resnet/tests/test_perf_resnet_2cqs.py new file mode 100644 index 000000000000..eddbc1bf4ed7 --- /dev/null +++ b/models/demos/resnet/tests/test_perf_resnet_2cqs.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from models.demos.resnet.tests.test_perf_resnet import run_perf_resnet +from models.utility_functions import skip_for_wormhole_b0 + + +@skip_for_wormhole_b0(reason_str="Not tested on single WH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, expected_inference_time, expected_compile_time", + ((20, 0.0055, 16),), +) +def test_perf_2cqs_bare_metal( + device, + use_program_cache, + batch_size, + expected_inference_time, + expected_compile_time, + hf_cat_image_sample_input, +): + run_perf_resnet( + batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50_2cqs" + ) diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index cd939e22cd56..e535e635d451 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -18,9 +18,8 @@ run_perf_models_other() { env pytest models/demos/ttnn_falcon7b/tests -m $test_marker # Separate calls since we can't mix switching between number of cqs - env pytest models/demos/resnet/tests/test_perf_resnet.py::test_perf_bare_metal -m $test_marker - env pytest models/demos/resnet/tests/test_perf_resnet.py::test_perf_2cqs_bare_metal -m $test_marker - env pytest models/demos/resnet/tests/test_perf_resnet.py::test_perf_trace_bare_metal -m $test_marker + env pytest models/demos/resnet/tests/test_perf_resnet.py -m $test_marker + env pytest models/demos/resnet/tests/test_perf_resnet_2cqs.py -m $test_marker env pytest tests/ttnn/integration_tests/whisper/test_performance.py -m $test_marker diff --git a/tests/scripts/single_card/nightly/run_gs_only.sh b/tests/scripts/single_card/nightly/run_gs_only.sh index f64956aea6b8..36ed969d4a04 100755 --- a/tests/scripts/single_card/nightly/run_gs_only.sh +++ b/tests/scripts/single_card/nightly/run_gs_only.sh @@ -11,8 +11,6 @@ echo "Running model nightly tests for GS only" env pytest models/demos/metal_BERT_large_11/tests/test_demo.py -env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_inference[LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0] +env pytest models/demos/resnet/tests/test_metal_resnet50_performant.py -env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_2cqs_inference[LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0] - -env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_trace_inference -k "LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0" +env pytest models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py From d771a746b1083bb0e8fe68300b315a22cd0848ee Mon Sep 17 00:00:00 
2001 From: Austin Ho Date: Tue, 4 Jun 2024 05:41:11 +0000 Subject: [PATCH 15/53] #0: Add NOC_XY_PCIE_ENCODING specifically for pcie cores since WH has an additional address offset --- .../kernels/pull_from_pcie.cpp | 2 +- .../command_queue/pcie_write_16b.cpp | 2 +- .../hw/inc/blackhole/noc/noc_parameters.h | 3 + tt_metal/hw/inc/dataflow_api.h | 2 +- .../hw/inc/grayskull/noc/noc_parameters.h | 2 + tt_metal/hw/inc/wormhole/noc/noc_parameters.h | 12 +- .../impl/dispatch/kernels/cq_dispatch.cpp | 2 +- .../impl/dispatch/kernels/cq_prefetch.cpp | 2 +- .../impl/dispatch/kernels/cq_prefetch.hpp | 674 ------------------ 9 files changed, 20 insertions(+), 681 deletions(-) delete mode 100644 tt_metal/impl/dispatch/kernels/cq_prefetch.hpp diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp index 9ae0f0adffba..9f94b540aafa 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp @@ -17,7 +17,7 @@ void kernel_main() { volatile tt_l1_ptr uint32_t* done_address = reinterpret_cast(L1_UNRESERVED_BASE); while (done_address[0] == 0) { - uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y), pcie_read_ptr); + uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y, NOC_INDEX), pcie_read_ptr); noc_async_read(host_src_addr, L1_UNRESERVED_BASE, read_sizeB); pcie_read_ptr += read_sizeB; if (pcie_read_ptr > pcie_base + pcie_sizeB) { diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp index 05c4a338ff59..ac8945a4d6d7 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp @@ -11,7 +11,7 @@ void kernel_main() { constexpr uint32_t base_pcie_dst_address = get_compile_time_arg_val(1); constexpr uint32_t num_16b_writes = get_compile_time_arg_val(2); - uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y)) << 32; + uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y, NOC_INDEX)) << 32; uint32_t l1_src_address = base_l1_src_address; uint32_t pcie_dst_address = base_pcie_dst_address; diff --git a/tt_metal/hw/inc/blackhole/noc/noc_parameters.h b/tt_metal/hw/inc/blackhole/noc/noc_parameters.h index 8b8e9ad14150..7f6529f9915e 100644 --- a/tt_metal/hw/inc/blackhole/noc/noc_parameters.h +++ b/tt_metal/hw/inc/blackhole/noc/noc_parameters.h @@ -14,6 +14,9 @@ #define NOC_XY_ENCODING(x, y) \ ((((uint64_t)(y)) << (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) | (((uint64_t)(x)) << NOC_ADDR_LOCAL_BITS)) +#define NOC_XY_PCIE_ENCODING(x, y, noc_index) \ + NOC_XY_ENCODING(x, y) + #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \ ((((uint64_t)(x_start)) << (NOC_ADDR_LOCAL_BITS + 2 * NOC_ADDR_NODE_ID_BITS)) | \ (((uint64_t)(y_start)) << (NOC_ADDR_LOCAL_BITS + 3 * NOC_ADDR_NODE_ID_BITS)) | \ diff --git a/tt_metal/hw/inc/dataflow_api.h b/tt_metal/hw/inc/dataflow_api.h index 91b1a26f8f39..12df89b03dec 100644 --- a/tt_metal/hw/inc/dataflow_api.h +++ b/tt_metal/hw/inc/dataflow_api.h @@ -476,7 +476,7 @@ uint64_t get_l1_noc_addr(const uint32_t id, const 
uint32_t page_size, const uint } uint64_t get_system_memory_noc_addr(const uint32_t id, const uint32_t page_size, const uint32_t base_addr, const uint32_t offset = 0) { - constexpr static uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y)) << 32; + uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(NOC_X(PCIE_NOC_X), NOC_Y(PCIE_NOC_Y), noc_index)) << 32; uint32_t addr = base_addr + page_size * id + offset; uint64_t noc_addr = pcie_core_noc_encoding | addr; return noc_addr; diff --git a/tt_metal/hw/inc/grayskull/noc/noc_parameters.h b/tt_metal/hw/inc/grayskull/noc/noc_parameters.h index 3fa07c452942..ed13f98ea8fd 100644 --- a/tt_metal/hw/inc/grayskull/noc/noc_parameters.h +++ b/tt_metal/hw/inc/grayskull/noc/noc_parameters.h @@ -12,6 +12,8 @@ // Address formats #define NOC_XY_ENCODING(x, y) ((((uint32_t)(y)) << (NOC_ADDR_NODE_ID_BITS)) | (((uint32_t)(x)))) +#define NOC_XY_PCIE_ENCODING(x, y, noc_index) NOC_XY_ENCODING(x, y) + #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \ ((x_start) << (2 * NOC_ADDR_NODE_ID_BITS)) | ((y_start) << (3 * NOC_ADDR_NODE_ID_BITS)) | (x_end) | \ ((y_end) << (NOC_ADDR_NODE_ID_BITS)) diff --git a/tt_metal/hw/inc/wormhole/noc/noc_parameters.h b/tt_metal/hw/inc/wormhole/noc/noc_parameters.h index 0a2256ffeebe..f6b361d3ff3f 100644 --- a/tt_metal/hw/inc/wormhole/noc/noc_parameters.h +++ b/tt_metal/hw/inc/wormhole/noc/noc_parameters.h @@ -9,13 +9,21 @@ #define PCIE_NOC_X 0 #define PCIE_NOC_Y 3 +#define PCIE_NOC1_X 9 +#define PCIE_NOC1_Y 8 + // Address formats #define NOC_XY_ENCODING(x, y) \ (((uint32_t)(y)) << ((NOC_ADDR_LOCAL_BITS % 32)+NOC_ADDR_NODE_ID_BITS)) | \ - (((uint32_t)(x)) << (NOC_ADDR_LOCAL_BITS % 32)) | ((x == PCIE_NOC_X and y == PCIE_NOC_Y) * 0x8) \ + (((uint32_t)(x)) << (NOC_ADDR_LOCAL_BITS % 32)) \ + +// Address formats +#define NOC_XY_PCIE_ENCODING(x, y, noc_index) \ + NOC_XY_ENCODING(x, y) | \ + ((noc_index ? 
(x == PCIE_NOC1_X and y == PCIE_NOC1_Y) : (x == PCIE_NOC_X and y == PCIE_NOC_Y)) * 0x8) \ #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \ - (((uint32_t)(x_start)) << ((NOC_ADDR_LOCAL_BITS % 32)+2*NOC_ADDR_NODE_ID_BITS)) | \ + (((uint32_t)(x_start)) << ((NOC_ADDR_LOCAL_BITS % 32)+2*NOC_ADDR_NODE_ID_BITS)) | \ (((uint32_t)(y_start)) << ((NOC_ADDR_LOCAL_BITS % 32)+3*NOC_ADDR_NODE_ID_BITS)) | \ (((uint32_t)(x_end)) << (NOC_ADDR_LOCAL_BITS % 32)) | \ (((uint32_t)(y_end)) << ((NOC_ADDR_LOCAL_BITS % 32)+NOC_ADDR_NODE_ID_BITS)) \ diff --git a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp index 75b525d0a91b..8002bd017049 100644 --- a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp @@ -43,7 +43,7 @@ constexpr uint32_t is_h_variant = get_compile_time_arg_val(16); constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y)); constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y)); constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y)); -constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_ENCODING(NOC_0_X(static_cast(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast(NOC_INDEX), noc_size_y, PCIE_NOC_Y))); +constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_PCIE_ENCODING(NOC_0_X(static_cast(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast(NOC_INDEX), noc_size_y, PCIE_NOC_Y), NOC_INDEX)); constexpr uint32_t dispatch_cb_page_size = 1 << dispatch_cb_log_page_size; constexpr uint32_t completion_queue_end_addr = completion_queue_base_addr + completion_queue_size; diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp index 0ee658ad1c2b..0124d992b2c4 100644 --- a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp @@ -52,7 +52,7 @@ constexpr uint32_t is_h_variant = get_compile_time_arg_val(22); constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y)); constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y)); constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y)); -constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_ENCODING(NOC_0_X(static_cast(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast(NOC_INDEX), noc_size_y, PCIE_NOC_Y))); +constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_PCIE_ENCODING(NOC_0_X(static_cast(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast(NOC_INDEX), noc_size_y, PCIE_NOC_Y), NOC_INDEX)); constexpr uint32_t downstream_cb_page_size = 1 << downstream_cb_log_page_size; constexpr uint32_t downstream_cb_end = downstream_cb_base + (1 << downstream_cb_log_page_size) * downstream_cb_pages; constexpr uint32_t prefetch_q_end = prefetch_q_base + prefetch_q_size; diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp deleted file mode 100644 index 036316ee43a0..000000000000 --- a/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp +++ /dev/null @@ -1,674 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
-// -// SPDX-License-Identifier: Apache-2.0 - -// Common prefetch code for use by _hd, _h, _d prefetch variants - -#include "dataflow_api.h" -#include "debug/dprint.h" -#include "tt_metal/impl/dispatch/kernels/cq_common.hpp" - -extern const uint32_t scratch_db_top[2]; - - -template -FORCE_INLINE -void write_downstream(uint32_t& data_ptr, - uint32_t& downstream_data_ptr, - uint32_t length) { - - uint32_t remaining = cb_end - downstream_data_ptr; - if (length > remaining) { - if (remaining > 0) { - noc_async_write(data_ptr, get_noc_addr_helper(downstream_noc_xy, downstream_data_ptr), remaining); - data_ptr += remaining; - length -= remaining; - } - downstream_data_ptr = cb_base; - } - - noc_async_write(data_ptr, get_noc_addr_helper(downstream_noc_xy, downstream_data_ptr), length); - downstream_data_ptr += length; -} - -template -FORCE_INLINE -void read_from_pcie(volatile tt_l1_ptr uint16_t *& prefetch_q_rd_ptr, - uint32_t& pending_read_size, - uint32_t& fence, - uint32_t& pcie_read_ptr, - uint32_t cmd_ptr, - uint32_t size) { - - // Wrap cmddat_q - if (fence + size + preamble_size > cmddat_q_base + cmddat_q_size) { - // only wrap if there are no commands ready, otherwise we'll leave some on the floor - // TODO: does this matter for perf? - if (cmd_ptr != fence) { - return; - } - fence = cmddat_q_base; - } - - // Wrap pcie/hugepage - if (pcie_read_ptr + size > pcie_base + pcie_size) { - pcie_read_ptr = pcie_base; - } - - uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(NOC_X(PCIE_NOC_X), NOC_Y(PCIE_NOC_Y)), pcie_read_ptr); - noc_async_read(host_src_addr, fence + preamble_size, size); - pending_read_size = size + preamble_size; - pcie_read_ptr += size; - - *prefetch_q_rd_ptr = 0; - - // Tell host we read - *(volatile tt_l1_ptr uint32_t *) prefetch_q_rd_ptr_addr = (uint32_t)prefetch_q_rd_ptr; - - prefetch_q_rd_ptr++; - - // Wrap prefetch_q - if ((uint32_t)prefetch_q_rd_ptr == prefetch_q_end) { - prefetch_q_rd_ptr = (volatile tt_l1_ptr uint16_t*)prefetch_q_base; - } -} - -// This routine can be called in 8 states based on the boolean values cmd_ready, prefetch_q_ready, read_pending: -// - !cmd_ready, !prefetch_q_ready, !read_pending: stall on prefetch_q, issue read, read barrier -// - !cmd_ready, !prefetch_q_ready, read pending: read barrier (and re-evaluate prefetch_q_ready) -// - !cmd_ready, prefetch_q_ready, !read_pending: issue read, read barrier (XXXX +issue read after?) -// - !cmd_ready, prefetch_q_ready, read_pending: read barrier, issue read -// - cmd_ready, !prefetch_q_ready, !read_pending: exit -// - cmd_ready, !prefetch_q_ready, read_pending: exit (no barrier yet) -// - cmd_ready, prefetch_q_ready, !read_pending: issue read -// - cmd_ready, prefetch_q_ready, read_pending: exit (don't add latency to the in flight request) -// -// With WH tagging of reads: -// open question: should fetcher loop on prefetch_q_ready issuing reads until !prefetch_q_ready -// - !cmd_ready, !prefetch_q_ready, !read_pending: stall on prefetch_q, issue read, read barrier -// - !cmd_ready, !prefetch_q_ready, read pending: read barrier on oldest tag -// - !cmd_ready, prefetch_q_ready, !read_pending: issue read, read barrier (XXXX +retry after?) 
-// - !cmd_ready, prefetch_q_ready, read_pending: issue read, read barrier on oldest tag -// - cmd_ready, !prefetch_q_ready, !read_pending: exit -// - cmd_ready, !prefetch_q_ready, read_pending: exit (no barrier yet) -// - cmd_ready, prefetch_q_ready, !read_pending: issue and tag read -// - cmd_ready, prefetch_q_ready, read_pending: issue and tag read -template -void fetch_q_get_cmds(uint32_t& fence, uint32_t& cmd_ptr, uint32_t& pcie_read_ptr) { - - static uint32_t pending_read_size = 0; - static volatile tt_l1_ptr uint16_t* prefetch_q_rd_ptr = (volatile tt_l1_ptr uint16_t*)prefetch_q_base; - - if (fence < cmd_ptr) { - DPRINT << "wrap cmd ptr1 " << fence << " " << cmd_ptr << ENDL(); - cmd_ptr = fence; - } - - bool cmd_ready = (cmd_ptr != fence); - uint32_t fetch_size = (uint32_t)*prefetch_q_rd_ptr << prefetch_q_log_minsize; - - if (fetch_size != 0 && pending_read_size == 0) { - DPRINT << "read1: " << (uint32_t)prefetch_q_rd_ptr << " " << " " << fence << " " << fetch_size << ENDL(); - read_from_pcie - (prefetch_q_rd_ptr, pending_read_size, fence, pcie_read_ptr, cmd_ptr, fetch_size); - } - if (!cmd_ready) { - if (pending_read_size != 0) { - DPRINT << "barrier" << ENDL(); - noc_async_read_barrier(); - - // wrap the cmddat_q - if (fence < cmd_ptr) { - cmd_ptr = fence; - } - - fence += pending_read_size; - pending_read_size = 0; - // After the stall, re-check the host - fetch_size = (uint32_t)*prefetch_q_rd_ptr << prefetch_q_log_minsize; - if (fetch_size != 0) { - read_from_pcie - (prefetch_q_rd_ptr, pending_read_size, fence, pcie_read_ptr, cmd_ptr, fetch_size); - } - } else { - // By here, prefetch_q_ready must be false - // Nothing to fetch, nothing pending, nothing available, stall on host - DEBUG_STATUS("HQW"); - DPRINT << "prefetcher stall" << ENDL(); - while ((fetch_size = *prefetch_q_rd_ptr) == 0); - DPRINT << "recurse" << ENDL(); - fetch_q_get_cmds(fence, cmd_ptr, pcie_read_ptr); - DEBUG_STATUS("HQD"); - } - } -} - -template -uint32_t process_debug_cmd(uint32_t cmd_ptr) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - uint32_t checksum = 0; - uint32_t data_start = (uint32_t)cmd + sizeof(CQPrefetchCmd); - uint32_t *data = (uint32_t *)data_start; - uint32_t size = cmd->debug.size; - - uint32_t front_size = (size <= cmddat_end - data_start) ? 
size : cmddat_end - data_start; - for (uint32_t i = 0; i < front_size / sizeof(uint32_t); i++) { - checksum += *data++; - } - uint32_t back_size = size - front_size; - if (back_size > 0) { - data = (uint32_t *)cmddat_base; - for (uint32_t i = 0; i < back_size / sizeof(uint32_t); i++) { - checksum += *data++; - } - } - - if (checksum != cmd->debug.checksum) { - DEBUG_STATUS("!CHK"); - ASSERT(0); - } - - return cmd->debug.stride; -} - -template -static uint32_t process_relay_inline_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - - uint32_t length = cmd->relay_inline.length; - uint32_t data_ptr = cmd_ptr + sizeof(CQPrefetchCmd); - - uint32_t npages = (length + cb_page_size - 1) >> cb_log_page_size; - - // Assume the downstream buffer is big relative to cmddat command size that we can - // grab what we need in one chunk - cb_acquire_pages(npages); - - uint32_t remaining = cmddat_end - data_ptr; - if (cmddat_wrap_enable && length > remaining) { - // wrap cmddat - write_downstream(data_ptr, dispatch_data_ptr, remaining); - length -= remaining; - data_ptr = cmddat_base; - } - - DPRINT << my_noc_xy << " " << dispatch_noc_xy << " " << cb_base << ENDL(); - write_downstream(data_ptr, dispatch_data_ptr, length); - - // Round to nearest page - dispatch_data_ptr += (cb_page_size - (dispatch_data_ptr & (cb_page_size - 1))) & (cb_page_size - 1); - - // XXXXX - painful syncing right now? move this into get_cmds - noc_async_writes_flushed(); - cb_release_pages(npages); - - return cmd->relay_inline.stride; -} - -// This version of inline sends inline data to the dispatcher but doesn't flush the page to the dispatcher -// This is used to assemble dispatcher commands when data comes out of band, eg, reading from DRAM -// That means this command is stateful, incorrect use will be...bad -// NOTE: this routine assumes we're sending a command header and that is LESS THAN A PAGE -template -static uint32_t process_relay_inline_noflush_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - - uint32_t length = sizeof(CQDispatchCmd); - uint32_t data_ptr = cmd_ptr + sizeof(CQPrefetchCmd); - - cb_acquire_pages(1); - if (dispatch_data_ptr == cb_end) { - dispatch_data_ptr = cb_base; - } - noc_async_write(data_ptr, get_noc_addr_helper(dispatch_noc_xy, dispatch_data_ptr), length); - dispatch_data_ptr += length; - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -static uint32_t write_pages_to_dispatcher(uint32_t& dispatch_data_ptr, - uint32_t& scratch_write_addr, - uint32_t& amt_to_write) { - - uint32_t page_residual_space = dispatch_cb_page_size - (dispatch_data_ptr & (dispatch_cb_page_size - 1)); - uint32_t npages = (amt_to_write - page_residual_space + dispatch_cb_page_size + extra_space - 1) / dispatch_cb_page_size; - - // Grabbing all pages at once is ok if scratch_size < 3 * dispatch_cb_block_size - if (!test_for_nonzero || npages != 0) { - cb_acquire_pages(npages); - } - - uint64_t noc_addr = get_noc_addr_helper(dispatch_noc_xy, dispatch_data_ptr); - if (dispatch_data_ptr == dispatch_cb_end) { - dispatch_data_ptr = dispatch_cb_base; - } else if (dispatch_data_ptr + amt_to_write > dispatch_cb_end) { // wrap - uint32_t last_chunk_size = dispatch_cb_end - dispatch_data_ptr; - noc_async_write(scratch_write_addr, noc_addr, last_chunk_size); - dispatch_data_ptr = dispatch_cb_base; - scratch_write_addr += last_chunk_size; - 
amt_to_write -= last_chunk_size; - noc_addr = get_noc_addr_helper(dispatch_noc_xy, dispatch_data_ptr); - } - - noc_async_write(scratch_write_addr, noc_addr, amt_to_write); - dispatch_data_ptr += amt_to_write; - - return npages; -} - -// This fn prefetches data from DRAM memory and writes data to the dispatch core. -// Reading from DRAM has the following characteristics: -// - latency is moderately high ~400 cycles on WH -// - DRAM bw is ~maximized when page size reaches 2K -// - for kernel dispatch, it is expected that page sizes will often be <2K -// - for buffer writing, page sizes will vary -// - writing to dispatcher works best with 4K pages (2K pages cover overhead, 4K gives perf cushion) -// - writing a 4K page takes ~32*4=128 cycles -// - writing 4 4K pages is 512 cycles, close to parity w/ the latency of DRAM -// - to hide the latency (~12% overhead), assume we need to read ~32 pages=128K, double buffered -// - in other words, we'll never achieve high efficiency and always be (somewhat) latency bound -// Algorithm does: -// - read a batch from DRAM -// - loop: read a batch from DRAM while sending to dispatcher -// - send a batch to dispatcher -// The size of the first read should be based on latency. With small page sizes -// bandwidth will be low and we'll be DRAM bound (send to dispatcher is ~free). -// With larger pages we'll get closer to a bandwidth match -// The dispatch buffer is a ring buffer. -template -uint32_t process_relay_paged_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - // This ensures that a previous cmd using the scratch buf has finished - noc_async_writes_flushed(); - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - uint32_t page_id = cmd->relay_paged.start_page; - uint32_t base_addr = cmd->relay_paged.base_addr; - uint32_t page_size = cmd->relay_paged.page_size; - uint32_t pages = cmd->relay_paged.pages; - uint32_t read_length = pages * page_size; - - InterleavedAddrGen addr_gen; - addr_gen.bank_base_address = base_addr; - addr_gen.page_size = page_size; - - // First step - read into DB0 - uint32_t scratch_read_addr = scratch_db_top[0]; - uint32_t amt_to_read = (scratch_db_half_size > read_length) ? read_length : scratch_db_half_size; - uint32_t amt_read = 0; - while (amt_to_read >= page_size) { - uint64_t noc_addr = addr_gen.get_noc_addr(page_id); // XXXX replace this w/ walking the banks to save mul on GS - noc_async_read(noc_addr, scratch_read_addr, page_size); - scratch_read_addr += page_size; - page_id++; - amt_to_read -= page_size; - amt_read += page_size; - } - noc_async_read_barrier(); - - // Second step - read into DB[x], write from DB[x], toggle x, iterate - // Writes are fast, reads are slow - uint32_t db_toggle = 0; - uint32_t scratch_write_addr; - read_length -= amt_read; - while (read_length != 0) { - // This ensures that writes from prior iteration are done - // TODO(pgk); we can do better on WH w/ tagging - noc_async_writes_flushed(); - - db_toggle ^= 1; - scratch_read_addr = scratch_db_top[db_toggle]; - scratch_write_addr = scratch_db_top[db_toggle ^ 1]; - - uint32_t amt_to_write = amt_read; - amt_to_read = (scratch_db_half_size > read_length) ? 
read_length : scratch_db_half_size; - amt_read = 0; - while (amt_to_read >= page_size) { - uint64_t noc_addr = addr_gen.get_noc_addr(page_id); // XXXX replace this w/ walking the banks to save mul on GS - noc_async_read(noc_addr, scratch_read_addr, page_size); - scratch_read_addr += page_size; - page_id++; - amt_to_read -= page_size; - amt_read += page_size; - } - - // Third step - write from DB - uint32_t npages = write_pages_to_dispatcher< - 0, - false, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - cb_release_pages(npages); - - read_length -= amt_read; - - // TODO(pgk); we can do better on WH w/ tagging - noc_async_read_barrier(); - } - - // Third step - write from DB - scratch_write_addr = scratch_db_top[db_toggle]; - uint32_t amt_to_write = amt_read; - uint32_t npages = write_pages_to_dispatcher< - CQ_DISPATCH_CMD_SIZE, - true, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - - uint32_t pad_to_page = dispatch_cb_page_size - (dispatch_data_ptr & (dispatch_cb_page_size - 1)); - dispatch_data_ptr += pad_to_page; - - // One page was acquired w/ the cmd in CMD_RELAY_INLINE_NOFLUSH - cb_release_pages(npages + 1); - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -uint32_t process_relay_linear_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - // This ensures that a previous cmd using the scratch buf has finished - noc_async_writes_flushed(); - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - uint32_t noc_xy_addr = cmd->relay_linear.noc_xy_addr; - uint32_t read_addr = cmd->relay_linear.addr; - uint32_t length = cmd->relay_linear.length; - uint32_t read_length = length; - - // First step - read into DB0 - uint32_t scratch_read_addr = scratch_db_top[0]; - uint32_t amt_to_read = (scratch_db_half_size > read_length) ? read_length : scratch_db_half_size; - uint64_t noc_addr = get_noc_addr_helper(noc_xy_addr, read_addr); - noc_async_read(noc_addr, scratch_read_addr, amt_to_read); - read_addr += amt_to_read; - noc_async_read_barrier(); - - // Second step - read into DB[x], write from DB[x], toggle x, iterate - // Writes are fast, reads are slow - uint32_t db_toggle = 0; - uint32_t scratch_write_addr; - read_length -= amt_to_read; - while (read_length != 0) { - // This ensures that writes from prior iteration are done - // TODO(pgk); we can do better on WH w/ tagging - noc_async_writes_flushed(); - - db_toggle ^= 1; - scratch_read_addr = scratch_db_top[db_toggle]; - scratch_write_addr = scratch_db_top[db_toggle ^ 1]; - - uint32_t amt_to_write = amt_to_read; - amt_to_read = (scratch_db_half_size > read_length) ? 
read_length : scratch_db_half_size; - noc_addr = get_noc_addr_helper(noc_xy_addr, read_addr); - noc_async_read(noc_addr, scratch_read_addr, amt_to_read); - read_addr += amt_to_read; - - // Third step - write from DB - uint32_t npages = write_pages_to_dispatcher< - 0, - false, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - - cb_release_pages(npages); - - read_length -= amt_to_read; - - // TODO(pgk); we can do better on WH w/ tagging - noc_async_read_barrier(); - } - - // Third step - write from DB - scratch_write_addr = scratch_db_top[db_toggle]; - uint32_t amt_to_write = amt_to_read; - uint32_t npages = write_pages_to_dispatcher< - CQ_DISPATCH_CMD_SIZE, - true, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - - uint32_t pad_to_page = dispatch_cb_page_size - (dispatch_data_ptr & (dispatch_cb_page_size - 1)); - dispatch_data_ptr += pad_to_page; - - // One page was acquired w/ the cmd in CMD_RELAY_INLINE_NOFLUSH - cb_release_pages(npages + 1); - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -uint32_t process_stall(uint32_t cmd_ptr) { - - static uint32_t count = 0; - - count++; - - DEBUG_STATUS("PSW"); - volatile tt_l1_ptr uint32_t* sem_addr = - reinterpret_cast(get_semaphore(dispatch_sync_sem_id)); - while (*sem_addr != count); - DEBUG_STATUS("PSD"); - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -bool process_cmd(uint32_t cmd_ptr, - uint32_t& downstream_data_ptr, - uint32_t& stride) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - bool done = false; - - switch (cmd->base.cmd_id) { - case CQ_PREFETCH_CMD_RELAY_LINEAR: - DPRINT << "relay linear: " << cmd_ptr << ENDL(); - stride = process_relay_linear_cmd< - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - downstream_cb_base, - downstream_cb_end, - downstream_cb_page_size, - scratch_db_half_size>(cmd_ptr, downstream_data_ptr); - break; - - case CQ_PREFETCH_CMD_RELAY_PAGED: - DPRINT << "relay dram page: " << cmd_ptr << ENDL(); - if (cmd->relay_paged.is_dram) { - stride = process_relay_paged_cmd< - true, - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - downstream_cb_base, - downstream_cb_end, - downstream_cb_page_size, - scratch_db_half_size>(cmd_ptr, downstream_data_ptr); - } else { - stride = process_relay_paged_cmd< - false, - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - downstream_cb_base, - downstream_cb_end, - downstream_cb_page_size, - scratch_db_half_size>(cmd_ptr, downstream_data_ptr); - } - break; - - case CQ_PREFETCH_CMD_RELAY_INLINE: - DPRINT << "inline" << ENDL(); - stride = process_relay_inline_cmd< - cmddat_wrap_enable, - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - cmddat_base, - cmddat_end, - downstream_cb_base, - downstream_cb_end, - downstream_cb_log_page_size, - downstream_cb_page_size>(cmd_ptr, downstream_data_ptr); - break; - - case CQ_PREFETCH_CMD_RELAY_INLINE_NOFLUSH: - DPRINT << "inline no flush" << ENDL(); - stride = process_relay_inline_noflush_cmd< - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_base, - downstream_cb_end>(cmd_ptr, downstream_data_ptr); - break; - - case CQ_PREFETCH_CMD_STALL: - DPRINT << "stall" << ENDL(); - 
stride = process_stall(cmd_ptr); - break; - - case CQ_PREFETCH_CMD_DEBUG: - DPRINT << "debug" << ENDL(); - stride = process_debug_cmd(cmd_ptr); - break; - - case CQ_PREFETCH_CMD_TERMINATE: - DPRINT << "terminating\n"; - done = true; - break; - - default: - DPRINT << "prefetch invalid command:" << (uint32_t)cmd->base.cmd_id << " " << cmd_ptr << " " << ENDL(); - DPRINT << HEX() << *(uint32_t*)cmd_ptr << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+1) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+2) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+3) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+4) << ENDL(); - DEBUG_STATUS("!CMD"); - ASSERT(0); - } - - return done; -} From aaecfba5ba1978681425cb3496956ca2c4f235a5 Mon Sep 17 00:00:00 2001 From: Michael Chiou Date: Mon, 3 Jun 2024 17:04:08 -0700 Subject: [PATCH 16/53] #9084: Rename dockerfile and added virtualenv installation Automatically installs python venv and all dependencies. Also sources it by default by adding to PATH --- ...ckerfile => ubuntu-20.04-amd64.Dockerfile} | 23 +++++++++++-------- scripts/docker/build_docker_image.sh | 2 +- 2 files changed, 14 insertions(+), 11 deletions(-) rename dockerfile/{ubuntu-20.04-x86.Dockerfile => ubuntu-20.04-amd64.Dockerfile} (56%) diff --git a/dockerfile/ubuntu-20.04-x86.Dockerfile b/dockerfile/ubuntu-20.04-amd64.Dockerfile similarity index 56% rename from dockerfile/ubuntu-20.04-x86.Dockerfile rename to dockerfile/ubuntu-20.04-amd64.Dockerfile index bdb5cb7d8697..a5ca82f1d762 100644 --- a/dockerfile/ubuntu-20.04-x86.Dockerfile +++ b/dockerfile/ubuntu-20.04-amd64.Dockerfile @@ -1,4 +1,4 @@ -# Second stage: the actual image +# TT-METAL UBUNTU 20.04 AMD64 DOCKERFILE FROM ubuntu:20.04 ARG DEBIAN_FRONTEND=noninteractive @@ -25,16 +25,19 @@ RUN /bin/bash /opt/tt_metal_infra/scripts/docker/install_test_deps.sh ${GTEST_VE COPY /scripts /opt/tt_metal_infra/scripts COPY build_metal.sh /scripts/build_metal.sh -# ENV TT_METAL_INFRA_DIR=/opt/tt_metal_infra -# ENV PYTHON_ENV_DIR=${TT_METAL_INFRA_DIR}/tt-metal/python_env -# RUN python3 -m venv $PYTHON_ENV_DIR +# Setup Env variables to setup Python Virtualenv - Install TT-Metal Python deps +ENV TT_METAL_INFRA_DIR=/opt/tt_metal_infra +ENV PYTHON_ENV_DIR=${TT_METAL_INFRA_DIR}/tt-metal/python_env +RUN python3 -m venv $PYTHON_ENV_DIR +ENV PATH="$PYTHON_ENV_DIR/bin:$PATH" -# COPY /docs/requirements-docs.txt ${TT_METAL_INFRA_DIR}/tt-metal/docs/. -# COPY /tt_metal/python_env/* ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/. -# ENV PATH="$PYTHON_ENV_DIR/bin:$PATH" -# RUN python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu \ -# && python3 -m pip install setuptools wheel +# Copy requirements from tt-metal folders with requirements.txt docs +COPY /docs/requirements-docs.txt ${TT_METAL_INFRA_DIR}/tt-metal/docs/. +COPY /tt_metal/python_env/* ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/. 
+RUN python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu \ + && python3 -m pip install setuptools wheel -# RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/requirements-dev.txt +RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/requirements-dev.txt +RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/docs/requirements-docs.txt CMD ["tail", "-f", "/dev/null"] diff --git a/scripts/docker/build_docker_image.sh b/scripts/docker/build_docker_image.sh index 82df50664e43..39c01283fbf5 100755 --- a/scripts/docker/build_docker_image.sh +++ b/scripts/docker/build_docker_image.sh @@ -5,5 +5,5 @@ TT_METAL_DOCKER_IMAGE_TAG=${1:-ubuntu-20.04-amd64:latest} TT_METAL_HOME=$(git rev-parse --show-toplevel) ( cd ${TT_METAL_HOME} || exit - docker build -f dockerfile/ubuntu-20.04-x86.Dockerfile -t ${TT_METAL_DOCKER_IMAGE_TAG} . + docker build -f dockerfile/ubuntu-20.04-amd64.Dockerfile -t ${TT_METAL_DOCKER_IMAGE_TAG} . ) \ No newline at end of file From 7941bba924d625440e8156f87db8444af975de46 Mon Sep 17 00:00:00 2001 From: David Ma Date: Fri, 31 May 2024 22:55:10 +0000 Subject: [PATCH 17/53] #0: Watcher interval to not include polling time This helps in cases where polling is slow (application uses links heavily), so watcher doesn't dominate the link. --- tt_metal/impl/debug/watcher_server.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tt_metal/impl/debug/watcher_server.cpp b/tt_metal/impl/debug/watcher_server.cpp index 82fd5f377da8..9b353e305947 100644 --- a/tt_metal/impl/debug/watcher_server.cpp +++ b/tt_metal/impl/debug/watcher_server.cpp @@ -785,17 +785,17 @@ static void watcher_loop(int sleep_usecs) { } log_info(LogLLRuntime, "Watcher server initialized, disabled features: {}", disabled_features); - double last_elapsed_time = watcher::get_elapsed_secs(); while (true) { - // Delay an amount such that we wait a minimum of the set sleep_usecs between polls. - while ((watcher::get_elapsed_secs() - last_elapsed_time) < ((double)sleep_usecs) / 1000000.) { + // Delay the amount of time specified by the user. Don't include watcher polling time to avoid the case where + // watcher dominates the communication links due to heavy traffic. + double last_elapsed_time = watcher::get_elapsed_secs(); + while ((watcher::get_elapsed_secs() - last_elapsed_time) < ((double) sleep_usecs) / 1000000.) { // Odds are this thread will be killed during the usleep, the kill signal is // watcher::enabled = false from the main thread. if (!watcher::enabled) break; usleep(1); } - last_elapsed_time = watcher::get_elapsed_secs(); { const std::lock_guard lock(watch_mutex); From a16f1a4178ed6067fadaec0c4c3be68a399a02d0 Mon Sep 17 00:00:00 2001 From: asaigal Date: Mon, 3 Jun 2024 23:44:16 +0000 Subject: [PATCH 18/53] #0: Revert "#8264: Worker thread optimizations:" This reverts commit 6b57cca73971550b0066f3236ebc3b496f09615c. 
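For reference, the polling-interval pattern introduced by the watcher change above (PATCH 17/53) can be sketched as a standalone loop: the clock baseline is re-read at the top of each iteration, so the configured delay counts only idle time between polls and a slow poll cannot shrink it. This is a minimal sketch assuming std::chrono and a hypothetical do_poll() placeholder rather than the watcher's own timer and polling helpers; it is illustrative only and not taken from the patch.

#include <atomic>
#include <chrono>
#include <thread>

// Minimal sketch: wait a fixed interval measured from after the previous poll finished,
// so slow polling work does not eat into the idle window between polls.
void watcher_style_loop(int sleep_usecs, std::atomic<bool>& enabled) {
    while (enabled) {
        // Baseline is taken every iteration, i.e. after the previous poll completed.
        auto wait_start = std::chrono::steady_clock::now();
        while (std::chrono::steady_clock::now() - wait_start < std::chrono::microseconds(sleep_usecs)) {
            if (!enabled) return;  // allow prompt shutdown mid-wait
            std::this_thread::sleep_for(std::chrono::microseconds(1));
        }
        // do_poll();  // hypothetical placeholder for the potentially slow polling work
    }
}

Compared with keeping a single baseline outside the loop, this guarantees a minimum gap between polls rather than an exact poll period, which is the behaviour the watcher commit message above describes.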
--- CMakeLists.txt | 6 +- .../tensors/test_async_tensor_apis.cpp | 215 ++++--- tt_eager/tensor/tensor.cpp | 129 +++-- tt_eager/tensor/tensor.hpp | 55 +- tt_eager/tensor/tensor_impl.hpp | 5 +- tt_eager/tensor/tensor_utils.cpp | 532 ++++++++---------- tt_eager/tensor/types.hpp | 41 +- tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp | 8 +- .../eltwise_binary/eltwise_binary_op.cpp | 8 +- .../eltwise_unary/eltwise_unary_op.cpp | 6 +- tt_eager/tt_dnn/op_library/run_operation.cpp | 355 +++++------- .../tt_dnn/op_library/softmax/softmax_op.cpp | 8 +- .../transformer_tms/transformer_tms.cpp | 24 +- .../op_library/transpose/transpose_op.cpp | 4 +- tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp | 10 +- tt_metal/CMakeLists.txt | 2 +- tt_metal/detail/tt_metal.hpp | 12 - tt_metal/impl/device/device.cpp | 4 +- tt_metal/impl/device/device.hpp | 4 +- tt_metal/impl/dispatch/command_queue.cpp | 23 +- tt_metal/impl/dispatch/work_executor.hpp | 16 +- tt_metal/tt_metal.cpp | 105 +--- ttnn/cpp/ttnn/op_library/binary/binary_op.cpp | 8 +- 23 files changed, 672 insertions(+), 908 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bd35a6d78d6..b85f073c3f19 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,10 +34,6 @@ CHECK_COMPILERS() find_package(Boost REQUIRED COMPONENTS thread filesystem system regex) find_package(GTest REQUIRED) find_package (Python3 COMPONENTS Interpreter Development) -find_library(NUMA_LIBRARY NAMES numa) -if (NOT NUMA_LIBRARY) - message(FATAL_ERROR "NUMA library not found") -endif() ############################################################################################################################ # Setting build type flags @@ -88,7 +84,7 @@ set(CMAKE_INSTALL_DATAROOTDIR "${CMAKE_BINARY_DIR}/tmp/share") ############################################################################################################################ add_library(metal_common_libs INTERFACE) target_link_libraries(metal_common_libs INTERFACE - dl z pthread atomic stdc++ numa # system libraries + dl z pthread atomic stdc++ # system libraries Boost::thread Boost::filesystem Boost::system Boost::regex hwloc # hwloc has no cmake support, find_package won't find it ) diff --git a/tests/tt_eager/tensors/test_async_tensor_apis.cpp b/tests/tt_eager/tensors/test_async_tensor_apis.cpp index 3c7d689e57fb..3f3c8b430106 100644 --- a/tests/tt_eager/tensors/test_async_tensor_apis.cpp +++ b/tests/tt_eager/tensors/test_async_tensor_apis.cpp @@ -33,21 +33,19 @@ TEST_F(CommonFixture, TestTensorOwnershipSanity) { auto func = [device, host_tensor, readback_tensor]() mutable { // Ensure that both the lambda and global scope have ownership to this tensor EXPECT_EQ(host_tensor.tensor_attributes.use_count(), 2); - std::visit( - [](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 1); - } - }, - storage.buffer); - } - }, - host_tensor.get_storage()); + std::visit([](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 1); + } + }, + storage.buffer); + } + }, host_tensor.get_storage()); // Send tensor to device, read it back and copy it to empty tensor initialized by main thread Tensor reshaped_tensor = host_tensor.reshape(1, 1, 32, 128); auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device); @@ -56,45 
+54,41 @@ TEST_F(CommonFixture, TestTensorOwnershipSanity) { readback_tensor.set_shape(thread_local_tensor.get_shape()); readback_tensor.set_dtype(thread_local_tensor.get_dtype()); readback_tensor.set_layout(thread_local_tensor.get_layout()); - readback_tensor.tensor_attributes->metadata_populated = true; - readback_tensor.tensor_attributes->num_workers_completed++; + readback_tensor.set_populated(); // Ensure that the readback buffer is owned inside and outside the lambda - std::visit( - [](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 2); - } - }, - storage.buffer); - } - }, - readback_tensor.get_storage()); - }; - - func(); - std::visit( - [](auto&& storage) { + std::visit([](auto&& storage) { using T = std::decay_t; if constexpr (std::is_same_v) { std::visit( [](auto&& buf) { using buf_type = std::decay_t; if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 1); - for (int i = 0; i < 128 * 32; i++) { - EXPECT_EQ(buf[i], i); - } + EXPECT_EQ(buf.use_count(), 2); } }, - storage.buffer); + storage.buffer); } - }, - readback_tensor.get_storage()); + }, readback_tensor.get_storage()); + }; + + func(); + std::visit([](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 1); + for (int i = 0; i < 128 * 32; i++) { + EXPECT_EQ(buf[i], i); + } + } + }, + storage.buffer); + } + }, + readback_tensor.get_storage()); EXPECT_EQ(readback_tensor.get_dtype(), DataType::FLOAT32); EXPECT_EQ(readback_tensor.get_layout(), Layout::ROW_MAJOR); EXPECT_EQ(readback_tensor.get_shape(), ttnn::Shape(Shape({1, 1, 32, 128}))); @@ -132,7 +126,8 @@ TEST_F(CommonFixture, TestAsyncEltwiseBinary) { input_c_addr = std::get(input_tensor_c.get_storage()).buffer->address(); output_1_addr = std::get(output_tensor_device.get_storage()).buffer->address(); output_2_addr = std::get(output_tensor_device_2.get_storage()).buffer->address(); - } else { + } + else { EXPECT_EQ(std::get(input_tensor_a.get_storage()).buffer->address(), input_a_addr); EXPECT_EQ(std::get(input_tensor_b.get_storage()).buffer->address(), input_b_addr); EXPECT_EQ(std::get(input_tensor_c.get_storage()).buffer->address(), input_c_addr); @@ -145,8 +140,7 @@ TEST_F(CommonFixture, TestAsyncEltwiseBinary) { output_tensor_device.deallocate(); output_tensor_device_2.deallocate(); // Verify output data - auto& buf = - std::get>(std::get(output_tensor_host.get_storage()).buffer); + auto& buf = std::get>(std::get(output_tensor_host.get_storage()).buffer); EXPECT_EQ(buf.use_count(), 1); for (int j = 0; j < 1024 * 1024; j++) { EXPECT_EQ(bfloat16(buf[j]), bfloat16(static_cast(i - 2 * i * i))); @@ -165,27 +159,21 @@ TEST_F(CommonFixture, TestAsyncRefCountManager) { for (int i = 0; i < 5; i++) { // Run for multiple loops to ensure deterministic behaviour with device addresses // Initialize 2 tensors on device - Tensor tensor1 = - tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); - Tensor tensor2 = - tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + Tensor tensor1 = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + Tensor tensor2 = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); uint32_t 
tensor2_device_buf_addr = tensor2.device_buffer()->address(); - // Assign tensor1 to tensor2 and ensure that ref counts are appropriately updated with the buffer for tensor2 - // deallocated + // Assign tensor1 to tensor2 and ensure that ref counts are appropriately updated with the buffer for tensor2 deallocated tensor2 = tensor1; EXPECT_EQ(tensor2.tensor_attributes->main_thread_ref_count, 2); EXPECT_EQ(tensor1.tensor_attributes->main_thread_ref_count, 2); - // To check if tensor2 is deallocated, create a third tensor on device and ensure that its address matches the - // prev addr for tensor2 - Tensor tensor3 = - tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + // To check if tensor2 is deallocated, create a third tensor on device and ensure that its address matches the prev addr for tensor2 + Tensor tensor3 = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); EXPECT_EQ(tensor3.device_buffer()->address(), tensor2_device_buf_addr); EXPECT_EQ(tensor1.device_buffer()->address(), tensor2.device_buffer()->address()); } log_info(LogTest, "Testing Device tensor self-assignment through function"); for (int i = 0; i < 5; i++) { - Tensor device_tensor = - tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + Tensor device_tensor = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); uint32_t device_tensor_address = device_tensor.device_buffer()->address(); // This step will copy the tensor to a temp rval and std::move it back to the caller's instance of device_tensor // Ensure ref count and address remain unchanged @@ -196,16 +184,14 @@ TEST_F(CommonFixture, TestAsyncRefCountManager) { log_info(LogTest, "Testing Device tensor move assignment"); for (int i = 0; i < 5; i++) { - Tensor tensor1 = - tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + Tensor tensor1 = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); Tensor tensor2 = std::move(tensor1); EXPECT_EQ(tensor2.tensor_attributes->main_thread_ref_count, 1); EXPECT_EQ(tensor1.tensor_attributes, nullptr); } log_info(LogTest, "Testing Device tensor self-assignment"); - Tensor tensor_to_self_assign = - tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(0), DataType::BFLOAT16).to(device); + Tensor tensor_to_self_assign = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(0), DataType::BFLOAT16).to(device); uint32_t tensor_to_self_assign_address = tensor_to_self_assign.device_buffer()->address(); tensor_to_self_assign = tensor_to_self_assign; EXPECT_EQ(tensor_to_self_assign.tensor_attributes->main_thread_ref_count, 1); @@ -233,6 +219,7 @@ TEST_F(CommonFixture, TestAsyncRefCountManager) { // Tensor output_tensor_device = mul(add(input_tensor_a, input_tensor_b), input_tensor_c); // Tensor output_tensor_device_2 = neg(sub(output_tensor_device, input_tensor_c)); + // EXPECT_EQ(output_tensor_device.get_shape(), ttnn::Shape(Shape({1, 1, 1023, 1023}))); // EXPECT_EQ(output_tensor_device.get_dtype(), DataType::BFLOAT16); @@ -247,50 +234,45 @@ TEST_F(CommonFixture, TestAsyncRefCountManager) { // device->set_worker_mode(WorkExecutorMode::SYNCHRONOUS); // } + TEST_F(CommonFixture, TestTensorAsyncDataMovement) { // Test 2 data paths here (resembles async mode): - // 1. Main -> Worker: Create a tensor in the main thread. Ensure that it is accessible in the worker thread even - // after its destroyed + // 1. 
Main -> Worker: Create a tensor in the main thread. Ensure that it is accessible in the worker thread even after its destroyed // by the main thread. This resembles host -> device data movement - // 2. Worker -> Main: Create an empty tensor in the mainb thread. Populate it in the worker thread. Ensure that the - // tensor is correctly + // 2. Worker -> Main: Create an empty tensor in the mainb thread. Populate it in the worker thread. Ensure that the tensor is correctly // populated in the main thread once the worker is done. Device* device = this->devices_[0]; uint32_t tensor_start = 0; uint32_t num_tiles = 128; uint32_t tensor_stop = TILE_HEIGHT * TILE_WIDTH * num_tiles; - Tensor readback_tensor({}, 1); - ; + Tensor readback_tensor({}, 1);; std::thread worker; { // host_tensor only lives in this scope Tensor host_tensor = tt::numpy::arange(tensor_start, tensor_stop, 1); log_info(LogTest, "Spawning worker thread"); - worker = std::thread([tensor_stop, host_tensor, readback_tensor, device]() mutable { + worker = std::thread([tensor_stop, host_tensor, readback_tensor, device] () mutable { // Sleep for 3 seconds to ensure that main thread deallocates host_tensor std::this_thread::sleep_for(std::chrono::milliseconds(3000)); log_info(LogTest, "Worker started"); // Main thread should have deallocated host_tensor by this point EXPECT_EQ(host_tensor.tensor_attributes.use_count(), 1); // Ensure that the buffer inside host_buffer is owned by a single tensor_attr object - // This buffer will not go out of scope until the last object owning it is destroyed (i.e. until the thread - // is done) - std::visit( - [](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 1); - } - }, - storage.buffer); - } - }, - host_tensor.get_storage()); + // This buffer will not go out of scope until the last object owning it is destroyed (i.e. 
until the thread is done) + std::visit([](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 1); + } + }, + storage.buffer); + } + }, host_tensor.get_storage()); Tensor reshaped_tensor = host_tensor.reshape(1, 1, 32, tensor_stop / 32); auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device); @@ -300,25 +282,22 @@ TEST_F(CommonFixture, TestTensorAsyncDataMovement) { readback_tensor.set_shape(thread_local_tensor.get_shape()); readback_tensor.set_dtype(thread_local_tensor.get_dtype()); readback_tensor.set_layout(thread_local_tensor.get_layout()); - readback_tensor.tensor_attributes->metadata_populated = true; - readback_tensor.tensor_attributes->num_workers_completed++; + readback_tensor.set_populated(); // Ensure that this buffer is currently owned by both the thread_local and read_back tensors // This is because we explictly pass in the buffer to a new tensor_attr object - std::visit( - [](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 2); - } - }, - storage.buffer); - } - }, - readback_tensor.get_storage()); + std::visit([](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 2); + } + }, + storage.buffer); + } + }, readback_tensor.get_storage()); log_info(LogTest, "Worker Done"); }); // Call deallocate on the tensor in the main thread to ensure that this call is safe @@ -329,22 +308,22 @@ TEST_F(CommonFixture, TestTensorAsyncDataMovement) { worker.join(); log_info(LogTest, "Verifying populated tensor in main thread"); std::visit( - [tensor_start, tensor_stop](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [tensor_start, tensor_stop](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 1); - for (int i = tensor_start; i < tensor_stop; i++) { - EXPECT_EQ(buf[i], i); + [tensor_start, tensor_stop](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [tensor_start, tensor_stop](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 1); + for (int i = tensor_start; i < tensor_stop; i++) { + EXPECT_EQ(buf[i], i); + } } - } - }, + }, storage.buffer); - } - }, + } + }, readback_tensor.get_storage()); EXPECT_EQ(readback_tensor.get_dtype(), DataType::FLOAT32); EXPECT_EQ(readback_tensor.get_layout(), Layout::ROW_MAJOR); diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index 9f28d1035671..c59e12608b51 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -35,7 +35,7 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L [&](auto&& storage) { using StorageType = std::decay_t; if constexpr (std::is_same_v) { - this->tensor_attributes->num_shards_to_be_populated = 1; + this->tensor_attributes->tensor_populated = {true}; } else if constexpr (std::is_same_v) { TT_ASSERT(storage.buffer->device() != nullptr); workers = {storage.buffer->device()}; @@ -48,9 +48,9 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L if (not 
this->workers.at(0)->in_main_thread()) { this->tensor_attributes->main_thread_tensor = false; } - this->tensor_attributes->num_shards_to_be_populated = 1; + this->tensor_attributes->tensor_populated = {true}; } else if constexpr (std::is_same_v) { - this->tensor_attributes->num_shards_to_be_populated = 1; + this->tensor_attributes->tensor_populated = {true}; } else if constexpr (std::is_same_v) { workers.reserve(storage.num_buffers()); for (int i = 0; i < storage.ordered_device_ids.size(); i++) { @@ -68,16 +68,14 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L if (not this->workers.at(0)->in_main_thread()) { this->tensor_attributes->main_thread_tensor = false; } - this->tensor_attributes->num_shards_to_be_populated = storage.num_buffers(); + this->tensor_attributes->tensor_populated = std::vector(storage.num_buffers(), true); } else if constexpr (std::is_same_v) { - this->tensor_attributes->num_shards_to_be_populated = storage.num_buffers(); + this->tensor_attributes->tensor_populated = std::vector(storage.num_buffers(), true); } else { raise_unsupported_storage(); } }, storage); - this->tensor_attributes->num_workers_completed = this->tensor_attributes->num_shards_to_be_populated; - this->tensor_attributes->metadata_populated = true; } Tensor::Tensor(const Storage storage, const Shape shape, DataType dtype, Layout layout) : @@ -241,6 +239,45 @@ void Tensor::perform_cleanup_for_async_mode() { } } +// Main Thread - Wait for all workers in this tensor to populate the entire tensor +void Tensor::wait_for_tensor_data_populated() const { + ZoneScoped; + // Stall until all the workers for this tensor + // have populated the full tensor + for (int i = 0; i < this->tensor_attributes->tensor_populated.size(); i++) { + while (true) { + std::scoped_lock lock(this->tensor_attributes->populated_mutex); + if (this->tensor_attributes->tensor_populated.at(i)) + break; + } + } +} + +// Main Thread - Wait for the first worker in this tensor to populate the global metadata fields +void Tensor::wait_for_tensor_metadata_populated() const { + ZoneScoped; + // First worker is responsible for updating all metadata fields + // Stall until this worker is done + while (true) { + std::scoped_lock lock(this->tensor_attributes->populated_mutex); + if (this->tensor_attributes->tensor_populated.at(0)) + break; + }; +} + +// Worker Thread - Set populated flag to true, once worker has completed it's task for this tensor +void Tensor::set_populated(Device* worker) { + // If worker is not specified, set entry for all workers to true + std::scoped_lock lock(this->tensor_attributes->populated_mutex); + if (not worker) { + for (int i = 0; i < this->tensor_attributes->tensor_populated.size(); i++) { + this->tensor_attributes->tensor_populated.at(i) = true; + } + } else { + this->tensor_attributes->tensor_populated.at(worker->id()) = true; + } +} + void Tensor::deepcopy(const Tensor& other) { ZoneScoped; // Wait until the tensor being copied is populated @@ -251,8 +288,7 @@ void Tensor::deepcopy(const Tensor& other) { this->set_dtype(other.get_dtype()); this->set_layout(other.get_layout()); // Set metadata populated flag for getters - this->tensor_attributes->metadata_populated = true; - this->tensor_attributes->num_workers_completed++; + this->set_populated(); } void Tensor::populate_buffers_and_metadata(const Tensor& other) { @@ -268,17 +304,17 @@ void Tensor::populate_buffers_and_metadata(const Tensor& other) { using StorageType = std::decay_t; if constexpr (std::is_same_v or std::is_same_v) 
{ std::get(this->tensor_attributes->storage).insert_buffer(storage.get_buffer()); + this->tensor_attributes->tensor_populated = {true}; } else if constexpr ( std::is_same_v or std::is_same_v) { std::get(this->tensor_attributes->storage).buffers = storage.buffers; std::get(this->tensor_attributes->storage).shapes = storage.shapes; + this->tensor_attributes->tensor_populated = std::vector(storage.buffers.size(), true); } }, other.get_storage()); // Non blocking storage query, since this is done for tensors that get created inside the // worker thread - this->tensor_attributes->metadata_populated = true; - this->tensor_attributes->num_workers_completed++; } std::vector Tensor::get_workers(bool blocking) const { @@ -448,20 +484,21 @@ Tensor Tensor::to(const std::vector& workers, const MemoryConfig& mem_c uint32_t num_workers = workers_to_use.size(); for (int worker_index = 0; worker_index < workers_to_use.size(); ++worker_index) { auto& worker = workers_to_use[worker_index]; - worker->push_work( - [worker, *this, device_tensor, mem_config, num_workers, worker_index] () mutable { - auto shard = get_shard_for_device(*this, worker, worker_index); - if (shard.storage_type() == StorageType::OWNED) { - shard = tensor_impl::to_device_wrapper(shard, worker, mem_config, std::nullopt); - } - insert_buffer_and_shape_for_device(worker, shard, device_tensor, worker_index); - uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; - if (not num_workers_completed) { - device_tensor.set_shape(this->get_shape()); - device_tensor.set_dtype(this->get_dtype()); - device_tensor.set_layout(this->get_layout()); - device_tensor.tensor_attributes->metadata_populated = true; - } + worker->push_work([worker, *this, device_tensor, mem_config, num_workers, worker_index]() mutable { + auto shard = get_shard_for_device(*this, worker, worker_index); + if (shard.storage_type() == StorageType::OWNED) { + shard = tensor_impl::to_device_wrapper(shard, worker, mem_config, std::nullopt); + } + insert_buffer_and_shape_for_device(worker, shard, device_tensor, worker_index); + if (not worker->id()) { + device_tensor.set_shape(this->get_shape()); + device_tensor.set_dtype(this->get_dtype()); + device_tensor.set_layout(this->get_layout()); + } + if (num_workers > 1) + device_tensor.set_populated(worker); + else + device_tensor.set_populated(); }); } device_tensor.tensor_attributes->update_main_thread_ref_count(workers.at(0), device_tensor_ref_count); @@ -491,18 +528,22 @@ Tensor Tensor::cpu(bool blocking) const { auto shard = get_shard_for_device(*this, target_device); shard = tensor_impl::to_host_wrapper(shard, blocking); insert_buffer_and_shape_for_device(target_device, shard, host_tensor, worker_index); - uint32_t num_workers_completed = (host_tensor.tensor_attributes->num_workers_completed)++; - if (not num_workers_completed) { + if (not target_device->id() or workers.size() == 1) { host_tensor.set_shape(this->get_shape()); host_tensor.set_dtype(this->get_dtype()); host_tensor.set_layout(this->get_layout()); - host_tensor.tensor_attributes->metadata_populated = true; + } + if (workers.size() == 1) { + host_tensor.set_populated(); + } else { + host_tensor.set_populated(target_device); } }); } - if (blocking) { - detail::SynchronizeWorkerThreads(workers); + for (auto target_device : workers) { + target_device->synchronize(); + } } // Update main_thread_ref_count for tensor after pushing to queue. 
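/*
 * Editor's note (illustrative sketch, not part of the patch): the hunks above
 * reintroduce per-shard "populated" flags that worker threads set and the main
 * thread polls (wait_for_tensor_data_populated / set_populated). The minimal,
 * self-contained example below shows that synchronization pattern with a mutex
 * and std::thread. All names here are hypothetical stand-ins, not tt-metal
 * APIs; a condition_variable would avoid the busy-wait but the patch itself
 * uses a polling loop, so the sketch mirrors that.
 */
#include <mutex>
#include <thread>
#include <vector>

struct PopulatedFlags {
    std::mutex mtx;
    std::vector<bool> populated;

    explicit PopulatedFlags(std::size_t num_shards) : populated(num_shards, false) {}

    // Worker thread: mark one shard as populated once its buffer is ready.
    void set_populated(std::size_t shard) {
        std::scoped_lock lock(mtx);
        populated.at(shard) = true;
    }

    // Main thread: spin until every shard is populated (busy-wait, as in the patch).
    void wait_all_populated() {
        for (std::size_t i = 0; i < populated.size(); ++i) {
            while (true) {
                std::scoped_lock lock(mtx);
                if (populated.at(i)) break;
            }
        }
    }
};

// Usage: two workers populate their shards while the main thread waits.
inline void populated_flags_example() {
    PopulatedFlags flags(2);
    std::thread w0([&] { flags.set_populated(0); });
    std::thread w1([&] { flags.set_populated(1); });
    flags.wait_all_populated();  // returns only after both workers have run
    w0.join();
    w1.join();
}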
this->tensor_attributes->update_main_thread_ref_count(workers.at(0), original_tensor_ref_count); @@ -570,13 +611,12 @@ Tensor Tensor::to(Layout target_layout, DeviceMesh* device_mesh) const { auto shard = get_shard_for_device(*this, worker, worker_index); shard = tensor_impl::to_layout_wrapper(shard, target_layout); insert_buffer_and_shape_for_device(worker, shard, tensor_modified_layout, worker_index); - uint32_t num_workers_completed = (tensor_modified_layout.tensor_attributes->num_workers_completed)++; - if (not num_workers_completed) { + if (not(worker->id())) { tensor_modified_layout.set_shape(this->get_shape()); tensor_modified_layout.set_dtype(this->get_dtype()); tensor_modified_layout.set_layout(target_layout); - tensor_modified_layout.tensor_attributes->metadata_populated = true; - }; + } + tensor_modified_layout.set_populated(worker); }); } return tensor_modified_layout; @@ -945,18 +985,15 @@ Tensor allocate_tensor_on_device( for (int worker_index = 0; worker_index < num_workers; ++worker_index) { auto& worker = workers[worker_index]; - worker->push_work( - [shape, data_type, layout, worker, memory_config, device_tensor, worker_index] () mutable { - auto local_tensor = create_device_tensor(shape.value(), data_type, layout, worker, memory_config); - insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); - - uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; - if (not num_workers_completed) { - device_tensor.set_shape(ttnn::Shape(shape)); - device_tensor.set_dtype(data_type); - device_tensor.set_layout(layout); - device_tensor.tensor_attributes->metadata_populated = true; - } + worker->push_work([shape, data_type, layout, worker, memory_config, device_tensor, worker_index]() mutable { + auto local_tensor = create_device_tensor(shape.value(), data_type, layout, worker, memory_config); + insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); + if (not worker->id()) { + device_tensor.set_shape(ttnn::Shape(shape)); + device_tensor.set_dtype(data_type); + device_tensor.set_layout(layout); + } + device_tensor.set_populated(worker); }); } device_tensor.tensor_attributes->update_main_thread_ref_count(workers.at(0), device_tensor_ref_count); diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index e60d7a77ef4c..d29c0730942a 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -32,12 +32,10 @@ struct Tensor { DataType dtype; Layout layout; std::mutex populated_mutex; - uint32_t num_shards_to_be_populated = 0; + std::vector tensor_populated = {}; uint32_t main_thread_ref_count = 0; std::atomic num_sibling_workers_sharing_tensor = 0; std::atomic main_thread_tensor = true; - std::atomic metadata_populated = false; - std::atomic num_workers_completed = 0; bool deallocated = false; // Set to true if device side storage was deallocated bool dynamic_storage = false; // Storage type can change, depending on op behaviour bool track_ref_count = false; @@ -157,7 +155,7 @@ struct Tensor { std::get(this->tensor_attributes->storage).ordered_device_ids), [](const Device *worker) { return worker->id(); }); } - this->tensor_attributes->num_shards_to_be_populated = workers.size(); + this->tensor_attributes->tensor_populated = std::vector(workers.size(), false); } else if (num_buffers) { if (num_buffers == 1) { this->tensor_attributes->storage = OwnedStorage(); @@ -169,7 +167,7 @@ struct Tensor { std::get(this->tensor_attributes->storage).shapes = 
std::vector(num_buffers, this->tensor_attributes->shape.value()); } - this->tensor_attributes->num_shards_to_be_populated = num_buffers; + this->tensor_attributes->tensor_populated = std::vector(num_buffers, false); } } @@ -288,26 +286,19 @@ struct Tensor { const ttnn::Shape &get_shape() const; const DataType &get_dtype() const; const Layout &get_layout() const; - - // ====================================================================================== - // Non-Blocking Getters. Query attributes directly, without waiting for worker completion - // ====================================================================================== - inline const Storage &storage() const { return this->tensor_attributes->storage; }; - inline const Shape &legacy_shape() const { return this->tensor_attributes->shape.value(); }; - inline const ttnn::Shape &shape() const { return this->tensor_attributes->shape; }; - inline const DataType &dtype() const { return this->tensor_attributes->dtype; }; - inline const Layout &layout() const { return this->tensor_attributes->layout; }; - // ====================================================================================== // Setters // ====================================================================================== - inline void set_storage(const Storage &storage) { this->tensor_attributes->storage = storage; } - inline void set_shape(const ttnn::Shape &shape) { this->tensor_attributes->shape = shape; } - inline void set_dtype(const DataType &dtype) { this->tensor_attributes->dtype = dtype; } - inline void set_layout(const Layout &layout) { this->tensor_attributes->layout = layout; } + void set_storage(const Storage &storage) { this->tensor_attributes->storage = storage; } + void set_shape(const ttnn::Shape &shape) { this->tensor_attributes->shape = shape; } + void set_dtype(const DataType &dtype) { this->tensor_attributes->dtype = dtype; } + void set_layout(const Layout &layout) { this->tensor_attributes->layout = layout; } + void set_populated(Device *worker = nullptr); // ====================================================================================== // Extra Helper Functions // ====================================================================================== + void wait_for_tensor_data_populated() const; + void wait_for_tensor_metadata_populated() const; StorageType storage_type() const; const Shape strides() const; uint32_t volume() const; @@ -364,31 +355,13 @@ struct Tensor { static constexpr auto attribute_names = std::make_tuple("storage", "shape", "dtype", "layout"); const auto attribute_values() const { return std::make_tuple( - std::cref(this->tensor_attributes->storage), - std::cref(this->tensor_attributes->shape), - std::cref(this->tensor_attributes->dtype), - std::cref(this->tensor_attributes->layout)); + std::cref(this->get_storage()), + std::cref(this->get_shape()), + std::cref(this->get_dtype()), + std::cref(this->get_layout())); } std::vector host_page_ordering(); - - // Main Thread - Wait for all workers in this tensor to populate the entire tensor - inline void wait_for_tensor_data_populated() const { - ZoneScoped; - // Stall until all the workers for this tensor - // have populated the full tensor - while (this->tensor_attributes->num_workers_completed < this->tensor_attributes->num_shards_to_be_populated) { - } - } - - // Main Thread - Wait for the first worker in this tensor to populate the global metadata fields - inline void wait_for_tensor_metadata_populated() const { - ZoneScoped; - // First worker is responsible for 
updating all metadata fields - // Stall until this worker is done - while (not this->tensor_attributes->metadata_populated) { - } - } }; Tensor create_device_tensor( diff --git a/tt_eager/tensor/tensor_impl.hpp b/tt_eager/tensor/tensor_impl.hpp index 2bf7bbdbcb53..a16047e02b01 100644 --- a/tt_eager/tensor/tensor_impl.hpp +++ b/tt_eager/tensor/tensor_impl.hpp @@ -392,6 +392,7 @@ inline Tensor to_host(const Tensor& tensor, bool blocking = true) { host_tensor.set_dtype(tensor.get_dtype()); host_tensor.set_layout(tensor.get_layout()); insert_buffer_and_shape_for_device(device, shard, host_tensor, device_index); + host_tensor.set_populated(device); } return host_tensor; } else { @@ -941,7 +942,7 @@ inline std::string to_string(const Tensor& tensor, std::optional origi } if (is_tensor_on_device(tensor)) { - return to_string(tensor.cpu()); + return to_string(to_host(tensor)); } return std::visit( @@ -984,7 +985,7 @@ inline std::string to_string(const Tensor& tensor, std::optional origi TT_THROW("Cannot print a device tensor!"); } else if constexpr (std::is_same_v) { auto devices = get_devices(tensor); - auto host_tensor = tensor.cpu(); + auto host_tensor = to_host(tensor); auto device_index = 0; std::stringstream ss; apply(host_tensor, [&](const Tensor& device_tensor) { diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp index d85efa6c9f88..c9d96d91cd6c 100644 --- a/tt_eager/tensor/tensor_utils.cpp +++ b/tt_eager/tensor/tensor_utils.cpp @@ -11,214 +11,189 @@ namespace tt { namespace tt_metal { -template -Tensor to_weight_special_padding_tile_layout( - const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { - auto w_shape = conv_weight_tensor.get_legacy_shape(); - auto compute = [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { - uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; - uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; - auto weight_matrix_cols = w_shape[0]; - // width padding - if (weight_matrix_cols % in1_block_w_datums != 0) { - weight_matrix_cols = - (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * in1_block_w_datums; - } - // height padding - assert(in1_block_h_datums >= w_shape[1] * w_shape[3]); - uint32_t block_height_padding = in1_block_h_datums - (w_shape[1] * w_shape[3]); - auto weight_matrix_rows = ((w_shape[1] * w_shape[3]) + block_height_padding) * w_shape[2]; - Shape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; - auto output_buffer = owned_buffer::create(compute_volume(output_shape)); - for (auto r = 0; r < w_shape[2]; r++) { - for (auto s = 0; s < w_shape[3]; s++) { - for (auto c = 0; c < w_shape[1]; c++) { - for (auto k = 0; k < w_shape[0]; k++) { - auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + - r * ((w_shape[3] * w_shape[1]) + block_height_padding) * weight_matrix_cols; - auto idx = - k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; - output_buffer[matrix_idx] = input_buffer[idx]; + + template + Tensor to_weight_special_padding_tile_layout(const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { + auto w_shape = conv_weight_tensor.get_legacy_shape(); + auto compute = + [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { + uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; + uint32_t in1_block_w_datums = 
in1_block_w * constants::TILE_WIDTH; + auto weight_matrix_cols = w_shape[0]; + // width padding + if (weight_matrix_cols % in1_block_w_datums != 0) { + weight_matrix_cols = (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * + in1_block_w_datums; + } + // height padding + assert(in1_block_h_datums >= w_shape[1] * w_shape[3]); + uint32_t block_height_padding = in1_block_h_datums - (w_shape[1] * w_shape[3]); + auto weight_matrix_rows = ((w_shape[1] * w_shape[3]) + block_height_padding) * w_shape[2]; + Shape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = owned_buffer::create(compute_volume(output_shape)); + for (auto r = 0; r < w_shape[2]; r++) { + for (auto s = 0; s < w_shape[3]; s++) { + for (auto c = 0; c < w_shape[1]; c++) { + for (auto k = 0; k < w_shape[0]; k++) { + auto matrix_idx = + k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + + r * ((w_shape[3] * w_shape[1]) + block_height_padding) * weight_matrix_cols; + auto idx = k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + + r * w_shape[3] + s; + output_buffer[matrix_idx] = input_buffer[idx]; + } + } } } - } - } - if constexpr (std::is_same::value) { - if (output_dtype == DataType::BFLOAT8_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = - pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - } - if (output_dtype == DataType::BFLOAT4_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = - pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + if constexpr (std::is_same::value) { + if (output_dtype == DataType::BFLOAT8_B) { + auto output_float_data = output_buffer.get(); + auto output_packed_data = pack_fp32_vec_as_bfp8_tiles( + output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + auto rm_tensor = Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + output_shape, + output_dtype, + Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + } + if (output_dtype == DataType::BFLOAT4_B) { + auto output_float_data = output_buffer.get(); + auto output_packed_data = pack_fp32_vec_as_bfp4_tiles( + output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + auto rm_tensor = Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + output_shape, + output_dtype, + Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + } + } else { + TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); + } auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); + std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); return rm_tensor.to(Layout::TILE); + }; + return std::visit( + [&compute](auto&& storage) -> Tensor { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return 
compute(owned_buffer::get_as(storage.buffer)); + } else if constexpr (std::is_same_v) { + return compute(borrowed_buffer::get_as(storage.buffer)); + } else { + TT_THROW("Unsupported storage type"); + } + }, + conv_weight_tensor.get_storage()); + } + + + template + Tensor to_weight_tile_layout(const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { + auto w_shape = conv_weight_tensor.get_legacy_shape(); + auto compute = + [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { + auto weight_matrix_cols = w_shape[0]; + // width padding + uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; + if(weight_matrix_cols%in1_block_w_datums != 0) { + weight_matrix_cols = (uint32_t) std::ceil( (double) weight_matrix_cols / (double) in1_block_w_datums ) * in1_block_w_datums; } - } else { - TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); - } - auto rm_tensor = - Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - }; - return std::visit( - [&compute](auto&& storage) -> Tensor { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return compute(owned_buffer::get_as(storage.buffer)); - } else if constexpr (std::is_same_v) { - return compute(borrowed_buffer::get_as(storage.buffer)); - } else { - TT_THROW("Unsupported storage type"); + // height padding + auto weight_matrix_rows = w_shape[1]*w_shape[2]*w_shape[3]; + uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; + if (weight_matrix_rows % in1_block_h_datums != 0) { + weight_matrix_rows = (uint32_t) std::ceil( (double) weight_matrix_rows / (double) in1_block_h_datums ) * in1_block_h_datums; } - }, - conv_weight_tensor.get_storage()); -} - -template -Tensor to_weight_tile_layout( - const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { - auto w_shape = conv_weight_tensor.get_legacy_shape(); - auto compute = [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { - auto weight_matrix_cols = w_shape[0]; - // width padding - uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; - if (weight_matrix_cols % in1_block_w_datums != 0) { - weight_matrix_cols = - (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * in1_block_w_datums; - } - // height padding - auto weight_matrix_rows = w_shape[1] * w_shape[2] * w_shape[3]; - uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; - if (weight_matrix_rows % in1_block_h_datums != 0) { - weight_matrix_rows = - (uint32_t)std::ceil((double)weight_matrix_rows / (double)in1_block_h_datums) * in1_block_h_datums; - } - Shape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; - auto output_buffer = owned_buffer::create(compute_volume(output_shape)); - for (auto r = 0; r < w_shape[2]; r++) { - for (auto s = 0; s < w_shape[3]; s++) { - for (auto c = 0; c < w_shape[1]; c++) { - for (auto k = 0; k < w_shape[0]; k++) { - auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + - r * w_shape[3] * w_shape[1] * weight_matrix_cols; - auto idx = - k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; - output_buffer[matrix_idx] = input_buffer[idx]; + Shape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = 
owned_buffer::create(compute_volume(output_shape)); + for(auto r = 0; r < w_shape[2]; r++) { + for(auto s = 0; s < w_shape[3]; s++) { + for(auto c = 0; c < w_shape[1]; c++) { + for(auto k = 0; k < w_shape[0]; k++) { + auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + r * w_shape[3] * w_shape[1] * weight_matrix_cols; + auto idx = k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; + output_buffer[matrix_idx] = input_buffer[idx]; + } } } } - } - if constexpr (std::is_same::value) { - if (output_dtype == DataType::BFLOAT8_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = - pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - } - if (output_dtype == DataType::BFLOAT4_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = - pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - } - } else { - TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); - } - auto rm_tensor = - Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - }; - return std::visit( - [&compute](auto&& storage) -> Tensor { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return compute(owned_buffer::get_as(storage.buffer)); - } else if constexpr (std::is_same_v) { - return compute(borrowed_buffer::get_as(storage.buffer)); + if constexpr (std::is_same::value) { + if (output_dtype == DataType::BFLOAT8_B) { + auto output_float_data = output_buffer.get(); + auto output_packed_data = pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + } + if (output_dtype == DataType::BFLOAT4_B) { + auto output_float_data = output_buffer.get(); + auto output_packed_data = pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + } } else { - TT_THROW("Unsupported storage type"); + TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); } - }, - conv_weight_tensor.get_storage()); -} + auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + }; + return std::visit( + [&compute](auto&& storage) -> Tensor { + using StorageType = std::decay_t; + if constexpr 
(std::is_same_v) { + return compute(owned_buffer::get_as(storage.buffer)); + } else if constexpr (std::is_same_v) { + return compute(borrowed_buffer::get_as(storage.buffer)); + } else { + TT_THROW("Unsupported storage type"); + } + }, + conv_weight_tensor.get_storage()); + } -// Converts convolution weights to tilized 2d matrix layout. -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_tiled_layout( - Tensor conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional output_dtype) { - TT_ASSERT( - conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && - "Convolution weights should be in row major layout for conversion to tilized layout."); - const static std::map< - DataType, - std::function> - to_w_tile_layout_map = { + // Converts convolution weights to tilized 2d matrix layout. + // Returns a new tensor with layout=Tile + Tensor convert_conv_weight_tensor_to_tiled_layout(Tensor conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional output_dtype) { + TT_ASSERT(conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && "Convolution weights should be in row major layout for conversion to tilized layout."); + const static std::map> to_w_tile_layout_map = { {DataType::BFLOAT16, &to_weight_tile_layout}, {DataType::FLOAT32, &to_weight_tile_layout}, {DataType::UINT32, &to_weight_tile_layout}, }; - if (output_dtype.has_value()) { - if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { - TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32); - } else { - TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype()); + if (output_dtype.has_value()) { + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32); + } else { + TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype()); + } } + return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype())); } - return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())( - conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype())); -} -// Converts convolution weights to tilized 2d matrix layout. -// Returns a new tensor with layout=Tile -Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( - Tensor conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional output_dtype) { - TT_ASSERT( - conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && - "Convolution weights should be in row major layout for conversion to tilized layout."); - const static std::map< - DataType, - std::function> - to_w_tile_layout_map = { + // Converts convolution weights to tilized 2d matrix layout. 
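/*
 * Editor's note (illustrative sketch, not part of the patch): the helpers above
 * dispatch on the tensor's storage variant with std::visit plus `if constexpr`
 * on std::decay_t of the alternative. The angle-bracketed template arguments
 * were stripped from this patch by text extraction; the self-contained example
 * below shows the intended shape of that dispatch using hypothetical
 * Owned/Borrowed storage types, not the real tt-metal classes.
 */
#include <type_traits>
#include <variant>
#include <vector>

struct OwnedStorageSketch { std::vector<float> buffer; };
struct BorrowedStorageSketch { const std::vector<float>* buffer; };
using StorageSketch = std::variant<OwnedStorageSketch, BorrowedStorageSketch>;

// Sum the buffer elements regardless of which storage alternative is active.
inline float sum_storage(const StorageSketch& storage) {
    return std::visit(
        [](auto&& s) -> float {
            using T = std::decay_t<decltype(s)>;
            float total = 0.f;
            if constexpr (std::is_same_v<T, OwnedStorageSketch>) {
                for (float v : s.buffer) total += v;      // owned: iterate directly
            } else if constexpr (std::is_same_v<T, BorrowedStorageSketch>) {
                for (float v : *s.buffer) total += v;     // borrowed: dereference first
            }
            return total;
        },
        storage);
}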
+ // Returns a new tensor with layout=Tile + Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout(Tensor conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional output_dtype) { + TT_ASSERT(conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && "Convolution weights should be in row major layout for conversion to tilized layout."); + const static std::map> to_w_tile_layout_map = { {DataType::BFLOAT16, &to_weight_special_padding_tile_layout}, {DataType::FLOAT32, &to_weight_special_padding_tile_layout}, - {DataType::UINT32, &to_weight_special_padding_tile_layout}}; - if (output_dtype.has_value()) { - if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { - TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32); - } else { - TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype()); + {DataType::UINT32, &to_weight_special_padding_tile_layout} + }; + if (output_dtype.has_value()) { + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32); + } else { + TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype()); + } } + return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype())); } - return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())( - conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype())); -} /* Helper function to aid in converting grouped weight tensor to ungrouped weight tensor with padded zero channels @@ -348,39 +323,44 @@ const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volu switch (neg_idx) { case 0: - TT_ASSERT(old_volume % C * H * W == 0); - N = old_volume / (C * H * W); + TT_ASSERT(old_volume % C*H*W == 0); + N = old_volume/(C*H*W); break; case 1: - TT_ASSERT(old_volume % N * H * W == 0); - C = old_volume / (N * H * W); + TT_ASSERT(old_volume % N*H*W == 0); + C = old_volume/(N*H*W); break; case 2: - TT_ASSERT(old_volume % N * C * W == 0); - H = old_volume / (N * C * W); + TT_ASSERT(old_volume % N*C*W == 0); + H = old_volume/(N*C*W); break; case 3: - TT_ASSERT(old_volume % N * C * H == 0); - W = old_volume / (N * C * H); + TT_ASSERT(old_volume % N*C*H == 0); + W = old_volume/(N*C*H); break; - case -1: // In case where there is no negative value in ns - TT_ASSERT(N * C * H * W == old_volume); + case -1: // In case where there is no negative value in ns + TT_ASSERT(N*C*H*W == old_volume); break; - default: TT_ASSERT(false && "Unexpected neg_idx in reshape!"); + default: + TT_ASSERT(false && "Unexpected neg_idx in reshape!"); } return {(uint32_t)N, (uint32_t)C, (uint32_t)H, (uint32_t)W}; } -bool is_arch_gs(const tt::ARCH& arch) { return arch == tt::ARCH::GRAYSKULL; } + bool is_arch_gs(const tt::ARCH& arch) { + return arch == tt::ARCH::GRAYSKULL; + } -bool is_arch_whb0(const tt::ARCH& arch) { return arch == tt::ARCH::WORMHOLE_B0; } + bool is_arch_whb0(const tt::ARCH& arch) { + return arch == tt::ARCH::WORMHOLE_B0; + } -bool is_cpu_tensor(const Tensor& tensor) { - return tensor.storage_type() == StorageType::OWNED || tensor.storage_type() == StorageType::BORROWED; -} + bool is_cpu_tensor(const Tensor& tensor) { + return tensor.storage_type() == StorageType::OWNED || tensor.storage_type() == StorageType::BORROWED; + } -bool is_device_tensor(const Tensor& tensor) { return tensor.storage_type() == 
StorageType::DEVICE; } + bool is_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::DEVICE; } Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) { const auto& tensor_storage = std::get(multi_device_tensor.get_storage()); @@ -389,7 +369,8 @@ Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, multi_device_tensor.get_legacy_shape(), multi_device_tensor.get_dtype(), - multi_device_tensor.get_layout()}; + multi_device_tensor.get_layout() + }; } TT_THROW("Device not found in multi-device tensor"); } @@ -399,10 +380,10 @@ Tensor get_device_tensor(const Tensor& multi_device_tensor, const Device* device } bool is_multi_device_tensor(const Tensor& tensor) { - return tensor.storage_type() == StorageType::MULTI_DEVICE or - tensor.storage_type() == StorageType::MULTI_DEVICE_HOST; + return tensor.storage_type() == StorageType::MULTI_DEVICE or tensor.storage_type() == StorageType::MULTI_DEVICE_HOST; } + std::vector get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor) { std::vector tensors; if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE) { @@ -414,7 +395,8 @@ std::vector get_tensors_from_multi_device_storage(const Tensor& multi_de DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, tensor_storage.shapes.at(device_id), multi_device_tensor.get_dtype(), - multi_device_tensor.get_layout()}; + multi_device_tensor.get_layout() + }; } return tensors; } else if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { @@ -424,9 +406,11 @@ std::vector get_tensors_from_multi_device_storage(const Tensor& multi_de OwnedStorage{tensor_storage.get_buffer(i)}, tensor_storage.shapes[i], multi_device_tensor.get_dtype(), - multi_device_tensor.get_layout()}); + multi_device_tensor.get_layout() + }); } - } else { + } + else { TT_FATAL(false, "get_tensors_from_multi_device_storage only support multi device tensors"); } return tensors; @@ -436,15 +420,15 @@ DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& if (tensor.storage_type() == StorageType::MULTI_DEVICE) { const auto& tensor_storage = std::get(tensor.get_storage()); return tensor_storage.strategy; - } else if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { + } + else if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { const auto& tensor_storage = std::get(tensor.get_storage()); return tensor_storage.strategy; } TT_THROW("Tensor is not a multi-device tensor"); } -Tensor create_multi_device_tensor( - const std::vector& tensors, StorageType storage_type, const DistributedTensorConfig& strategy) { +Tensor create_multi_device_tensor(const std::vector& tensors, StorageType storage_type, const DistributedTensorConfig& strategy) { if (tensors.empty()) { TT_THROW("Cannot create multi-device tensor with empty tensor list"); } @@ -464,7 +448,8 @@ Tensor create_multi_device_tensor( MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, shapes}, tensors.at(0).get_legacy_shape(), tensors.at(0).get_dtype(), - tensors.at(0).get_layout()}; + tensors.at(0).get_layout() + }; } else if (storage_type == StorageType::MULTI_DEVICE_HOST) { std::vector owned_buffers; std::vector shapes; @@ -476,7 +461,8 @@ Tensor create_multi_device_tensor( MultiDeviceHostStorage{strategy, owned_buffers, shapes}, tensors.at(0).get_legacy_shape(), tensors.at(0).get_dtype(), - tensors.at(0).get_layout()}; + 
tensors.at(0).get_layout() + }; } else { TT_THROW("Invalid storage type for multi-device tensor"); } @@ -485,11 +471,9 @@ Tensor create_multi_device_tensor( Tensor transform(const Tensor& tensor, std::function transform_func) { auto input_tensors = get_tensors_from_multi_device_storage(tensor); std::vector output_tensors(input_tensors.size()); - std::transform(input_tensors.begin(), input_tensors.end(), output_tensors.begin(), [&](const auto& device_tensor) { - return transform_func(device_tensor); - }); - return create_multi_device_tensor( - output_tensors, tensor.storage_type(), get_distributed_tensor_config_from_tensor(tensor)); + std::transform(input_tensors.begin(), input_tensors.end(), output_tensors.begin(), + [&](const auto& device_tensor) { return transform_func(device_tensor); }); + return create_multi_device_tensor(output_tensors, tensor.storage_type(), get_distributed_tensor_config_from_tensor(tensor)); } void apply(const Tensor& tensor, std::function callable) { @@ -499,6 +483,7 @@ void apply(const Tensor& tensor, std::function callable) { } } + std::vector get_devices(const Tensor& tensor) { std::vector devices; if (tensor.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) { @@ -520,10 +505,7 @@ uint32_t num_buffers_in_tensor(const Tensor& tensor) { } else if (std::holds_alternative(tensor.get_storage())) { auto host_storage = std::get(tensor.get_storage()); return host_storage.num_buffers(); - } else if ( - std::holds_alternative(tensor.get_storage()) || - std::holds_alternative(tensor.get_storage()) || - std::holds_alternative(tensor.get_storage())) { + } else if (std::holds_alternative(tensor.get_storage()) || std::holds_alternative(tensor.get_storage()) || std::holds_alternative(tensor.get_storage())) { return 1; } else { TT_FATAL(false, "num_buffers_in_tensor only supports multi-device or device tensors"); @@ -533,64 +515,45 @@ uint32_t num_buffers_in_tensor(const Tensor& tensor) { Tensor get_shard_for_device(const Tensor& tensor, Device* target_device, std::optional buffer_index) { ZoneScopedN("GetShardForDevice"); Tensor shard = Tensor(); - auto& storage = tensor.tensor_attributes->storage; - std::visit( - [target_device, buffer_index, &tensor, &shard](auto&& s) { - using T = std::decay_t; - // Stalling reads for tensor data-type and layout are needed here - // since some worker might have raced ahead to these lookups, while - // another worker is populating this metadata. 
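/*
 * Editor's note (illustrative sketch, not part of the patch): get_shard_for_device
 * and insert_buffer_and_shape_for_device below, like the MultiDeviceStorage hunks
 * later in this patch, read and write per-device entries of a shared map under a
 * single mutex. The simplified example shows that access pattern with a plain
 * unordered_map keyed by device id; the types and names are hypothetical, not
 * tt-metal classes.
 */
#include <cstdint>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

using DeviceBufferSketch = std::shared_ptr<std::vector<std::uint8_t>>;

struct PerDeviceBuffers {
    mutable std::mutex mtx;
    std::unordered_map<std::uint32_t, DeviceBufferSketch> buffers;

    // Worker thread: publish the shard buffer it produced for its device.
    void insert_buffer_for_device(std::uint32_t device_id, DeviceBufferSketch buffer) {
        std::lock_guard<std::mutex> lock(mtx);
        buffers[device_id] = std::move(buffer);
    }

    // Any thread: fetch the shard for one device, failing loudly if it is missing.
    DeviceBufferSketch get_buffer_for_device(std::uint32_t device_id) const {
        std::lock_guard<std::mutex> lock(mtx);
        auto it = buffers.find(device_id);
        if (it == buffers.end()) {
            throw std::runtime_error("No buffer recorded for device " + std::to_string(device_id));
        }
        return it->second;
    }
};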
- if constexpr (std::is_same_v) { - shard = Tensor{ - DeviceStorage{s.get_buffer_for_device(target_device)}, - s.get_tensor_shape_for_device(target_device), - tensor.get_dtype(), - tensor.get_layout()}; - } else if constexpr (std::is_same_v) { - shard = Tensor{ - OwnedStorage{s.get_buffer(buffer_index.value())}, - s.get_tensor_shape(buffer_index.value()), - tensor.get_dtype(), - tensor.get_layout()}; - } else if constexpr ( - std::is_same_v || std::is_same_v || - std::is_same_v) { - shard = tensor; - } else { - TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors"); - } - }, - storage); + auto& storage = tensor.get_storage(); + std::visit([target_device, buffer_index, &tensor, &shard] (auto&& s) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + auto shard_shape = s.get_tensor_shape_for_device(target_device); + auto shard_buffer = s.get_buffer_for_device(target_device); + shard = Tensor{DeviceStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()}; + } else if constexpr (std::is_same_v) { + auto shard_shape = s.get_tensor_shape(buffer_index.value()); + auto shard_buffer = s.get_buffer(buffer_index.value()); + shard = Tensor{OwnedStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()}; + } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { + shard = tensor; + } else { + TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors"); + } + }, storage); return shard; } -void insert_buffer_and_shape_for_device( - Device* target_device, const Tensor& shard, Tensor& tensor_to_modify, std::optional buffer_index) { +void insert_buffer_and_shape_for_device(Device* target_device, const Tensor& shard, Tensor& tensor_to_modify, std::optional buffer_index) { ZoneScopedN("InsertBufferAndShapeForDevice"); - std::visit( - [target_device, &shard, &tensor_to_modify, buffer_index](auto&& s) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - s.insert_buffer_and_shape_for_device( - buffer_index.value(), - std::get(shard.tensor_attributes->storage).get_buffer(), - shard.tensor_attributes->shape.value()); - } else if constexpr (std::is_same_v) { - s.insert_buffer_and_shape_for_device( - target_device, - std::get(shard.tensor_attributes->storage).get_buffer(), - shard.tensor_attributes->shape.value()); - } else if constexpr (std::is_same_v) { - s.insert_buffer(std::get(shard.tensor_attributes->storage).get_buffer()); - } else if constexpr (std::is_same_v) { - s.insert_buffer(std::get(shard.tensor_attributes->storage).get_buffer()); - } else { - TT_FATAL(false, "Unsupported storage in insert_buffer_and_shape_for_device"); - } - }, - tensor_to_modify.tensor_attributes->storage); + std::visit([target_device, &shard, &tensor_to_modify, buffer_index] (auto&& s) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + s.insert_buffer_and_shape_for_device(buffer_index.value(), std::get(shard.get_storage()).get_buffer(), shard.get_legacy_shape()); + } else if constexpr (std::is_same_v) { + s.insert_buffer_and_shape_for_device(target_device, std::get(shard.get_storage()).get_buffer(), shard.get_legacy_shape()); + } else if constexpr (std::is_same_v) { + s.insert_buffer(std::get(shard.get_storage()).get_buffer()); + } else if constexpr (std::is_same_v) { + s.insert_buffer(std::get(shard.get_storage()).get_buffer()); + } else { + TT_FATAL(false, "Unsupported storage in insert_buffer_and_shape_for_device"); + } + }, tensor_to_modify.tensor_attributes->storage); } + Tensor 
copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor) { // When using async mode, tensors with borrowed storage cannot be passed to workers. // They need to be copied to owned storage before being passed to the worker. @@ -598,26 +561,23 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor) // Tensor has workers (on device) or runtime mode is synchronous or tensor has multiple buffers. // No need to check for borrowed storage. if (worker->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or - tensor.tensor_attributes->num_shards_to_be_populated > 1) - return tensor; + tensor.get_workers().size() or + tensor.tensor_attributes->tensor_populated.size() > 1) return tensor; if (tensor.storage_type() == StorageType::BORROWED) { ZoneScopedN("CopyBorrowedStorage"); auto borrowed_buffer = std::get(tensor.get_storage()).buffer; Tensor owned_tensor; - std::visit( - [&owned_tensor, &tensor](auto&& buffer) { - using BorrowedStorageType = std::vector>; - auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end())); - owned_tensor = - Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout()); - }, - borrowed_buffer); + std::visit([&owned_tensor, &tensor] (auto&& buffer) { + using BorrowedStorageType = std::vector>; + auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end())); + owned_tensor = Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout()); + }, borrowed_buffer); return owned_tensor; } return tensor; } -} // namespace tt_metal +} -} // namespace tt +} diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index c60ca89118c5..9c71b6f0d777 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -455,8 +455,7 @@ struct MultiDeviceHostStorage { std::vector ordered_device_ids; std::unordered_map buffers; std::unordered_map shapes; - mutable std::mutex buffer_mtx; - mutable std::mutex shape_mtx; + mutable std::mutex mtx; MultiDeviceStorage() = default; MultiDeviceStorage( @@ -466,14 +465,14 @@ struct MultiDeviceHostStorage { std::unordered_map shapes_) : strategy(strategy_), ordered_device_ids(ordered_device_ids_), buffers(buffers_), shapes(shapes_) {} MultiDeviceStorage(MultiDeviceStorage &&other) { - std::scoped_lock buf_lock(buffer_mtx, shape_mtx); + std::lock_guard lock(mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; shapes = other.shapes; } MultiDeviceStorage(const MultiDeviceStorage &other) { - std::scoped_lock buf_lock(buffer_mtx, shape_mtx); + std::lock_guard lock(other.mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; @@ -481,7 +480,7 @@ struct MultiDeviceHostStorage { } MultiDeviceStorage &operator=(const MultiDeviceStorage &other) { - std::scoped_lock buf_lock(buffer_mtx, shape_mtx); + std::lock_guard lock(other.mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; @@ -490,7 +489,7 @@ struct MultiDeviceHostStorage { } MultiDeviceStorage &operator=( MultiDeviceStorage &&other) { - std::scoped_lock buf_lock(buffer_mtx, shape_mtx); + std::lock_guard lock(mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; @@ -502,8 +501,8 @@ struct MultiDeviceHostStorage { return this->ordered_device_ids == other.ordered_device_ids and this->strategy == other.strategy and this->buffers == 
other.buffers and this->shapes == other.shapes; } - inline const MemoryConfig memory_config() const { - std::lock_guard lock(buffer_mtx); + const MemoryConfig memory_config() const { + std::lock_guard lock(mtx); if (this->buffers.at(0).get() == nullptr) { TT_THROW("MemoryConfig can only be obtained if the buffer is not null"); } @@ -523,54 +522,50 @@ struct MultiDeviceHostStorage { // Helper Functions - Getters and setters to get/modify storage attributes. These are needed to // preinitialize empty tensor handles and use/populate them in the worker threads. - - inline void insert_buffer_and_shape_for_device(Device* device, const DeviceBuffer buffer, const Shape shape) { + void insert_buffer_and_shape_for_device(Device* device, const DeviceBuffer buffer, const Shape shape) { TT_ASSERT(device == buffer->device(), "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); - { - std::lock_guard lock(buffer_mtx); - buffers.insert({device->id(), buffer}); - } - std::lock_guard lock(shape_mtx); + std::lock_guard lock(mtx); + buffers.insert({device->id(), buffer}); shapes.insert({device->id(), shape}); } inline DeviceBuffer get_buffer_for_device(Device* device) const { - std::lock_guard lock(buffer_mtx); + std::lock_guard lock(mtx); TT_ASSERT(buffers.find(device->id()) != buffers.end(), "Buffer not found for device " + std::to_string(device->id())); TT_ASSERT(buffers.at(device->id())->device() == device, "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); return buffers.at(device->id()); } inline DeviceBuffer& get_buffer_for_device(Device* device) { - std::lock_guard lock(buffer_mtx); + std::lock_guard lock(mtx); TT_ASSERT(buffers.find(device->id()) != buffers.end(), "Buffer not found for device " + std::to_string(device->id())); TT_ASSERT(buffers.at(device->id())->device() == device, "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); return buffers.at(device->id()); } inline DeviceBuffer get_buffer_for_device_id(uint32_t device_id) const { - std::lock_guard lock(buffer_mtx); + std::lock_guard lock(mtx); return buffers.at(device_id); } inline Shape get_tensor_shape_for_device(Device* device) const { - std::lock_guard lock(shape_mtx); + std::lock_guard lock(mtx); TT_ASSERT(shapes.find(device->id()) != shapes.end(), "Shape not found for device " + std::to_string(device->id())); return shapes.at(device->id()); } - inline uint32_t num_buffers() const { - std::lock_guard lock(buffer_mtx); + uint32_t num_buffers() const { + std::lock_guard lock(mtx); return buffers.size(); } inline bool has_buffer_for_device(Device* device) const { - std::lock_guard lock(buffer_mtx); + std::lock_guard lock(mtx); return buffers.find(device->id()) != buffers.end(); } inline bool has_buffer_for_device_id(uint32_t device_id) const { - std::lock_guard lock(buffer_mtx); + std::lock_guard lock(mtx); return buffers.find(device_id) != buffers.end(); } }; diff --git a/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp b/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp index cb6db5e822d7..9ecc86c31052 100644 --- a/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp +++ b/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp @@ -166,10 +166,10 @@ const operation::Hash EltwiseBinaryBroadcast::compute_program_hash( return operation::hash_operation( *this, parallelization_strategy, - std::get(input_tensors.at(0).storage()).memory_config(), - input_tensors.at(0).dtype(), - std::get(input_tensors.at(1).storage()).memory_config(), - 
input_tensors.at(1).dtype(), + input_tensors.at(0).memory_config(), + input_tensors.at(0).get_dtype(), + input_tensors.at(1).memory_config(), + input_tensors.at(1).get_dtype(), bcast_scalar, this->in_place); } diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp index ea091ce92695..6fdc8edfa8de 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp @@ -267,10 +267,10 @@ const operation::Hash EltwiseBinary::compute_program_hash(const std::vectorop_type, parallelization_strategy, - input_tensor_a.dtype(), - std::get(input_tensor_a.storage()).memory_config(), - input_tensor_b.dtype(), - std::get(input_tensor_b.storage()).memory_config(), + input_tensor_a.get_dtype(), + input_tensor_a.memory_config(), + input_tensor_b.get_dtype(), + input_tensor_b.memory_config(), this->output_dtype, this->output_mem_config, this->in_place); diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index d958fc0c1f0b..65b89afee03c 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -380,13 +380,13 @@ UnaryOpParallelizationStrategy EltwiseUnary::get_parallelization_strategy( const operation::Hash EltwiseUnary::compute_program_hash(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); - const auto& input_shape = input_tensor.legacy_shape(); + const auto& input_shape = input_tensor.get_legacy_shape(); operation::Hash hash = tt::stl::hash::hash_objects_with_default_seed( typeid(*this).hash_code(), compute_volume(input_shape), - input_tensor.dtype(), - std::get(input_tensor.storage()).memory_config(), + input_tensor.get_dtype(), + input_tensor.memory_config(), this->output_mem_config); for (const auto& unary_with_param_op : this->op_chain) { diff --git a/tt_eager/tt_dnn/op_library/run_operation.cpp b/tt_eager/tt_dnn/op_library/run_operation.cpp index 4d53c4f4ebce..788cc30adf6f 100644 --- a/tt_eager/tt_dnn/op_library/run_operation.cpp +++ b/tt_eager/tt_dnn/op_library/run_operation.cpp @@ -14,29 +14,26 @@ #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" -#include "tt_metal/tt_stl/reflection.hpp" #include "tt_numpy/functions.hpp" +#include "tt_metal/tt_stl/reflection.hpp" namespace tt::tt_metal::operation { namespace detail { inline bool any_tensor_on_multi_device(const Tensors& tensors) { - return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& tensor) { - return tensor.storage_type() == StorageType::MULTI_DEVICE; - }); + return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& tensor) { return tensor.storage_type() == StorageType::MULTI_DEVICE; }); } Device* get_device(const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors) { for (auto& input_tensor : input_tensors) { - if (std::holds_alternative(input_tensor.tensor_attributes->storage)) { - return input_tensor.workers.at(0); + if (input_tensor.storage_type() == StorageType::DEVICE) { + return input_tensor.device(); } } for (auto& optional_input_tensor : optional_input_tensors) { - if (optional_input_tensor.has_value() and - std::holds_alternative(optional_input_tensor.value().tensor_attributes->storage)) { - return 
optional_input_tensor.value().workers.at(0); + if (optional_input_tensor.has_value() and optional_input_tensor.value().storage_type() == StorageType::DEVICE) { + return optional_input_tensor.value().device(); } } auto device = AutoFormat::GetDefaultDevice(); @@ -46,19 +43,18 @@ Device* get_device(const Tensors& input_tensors, const OptionalConstTensors& opt void validate_op_launch(Device* worker) { if (worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS) { - TT_FATAL( - not worker->in_main_thread(), - "launch_op or launch_with_autoformat must be used when running in async mode."); + TT_FATAL(not worker->in_main_thread(), "launch_op or launch_with_autoformat must be used when running in async mode."); } } -template +template void override_addresses( const OverrideAddressesCallback& override_addresses_callback, - const Program& program, + const Program &program, const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, - const OutputTensors& output_tensors) { + const OutputTensors& output_tensors +) { std::vector input_buffers; for (auto& tensor : input_tensors) { input_buffers.push_back(tensor.buffer()); @@ -70,10 +66,11 @@ void override_addresses( std::vector output_buffers; for (auto& tensor : output_tensors) { - if constexpr (std::is_same_v) { + if constexpr(std::is_same_v){ auto buffer = tensor.has_value() ? tensor.value().buffer() : nullptr; output_buffers.push_back(buffer); - } else { + } + else{ output_buffers.push_back(tensor.buffer()); } } @@ -83,18 +80,19 @@ void override_addresses( template void override_addresses( const OverrideAddressesCallback& override_addresses_callback, - const Program& program, + const Program &program, const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, const Tensors& output_tensors); template void override_addresses( const OverrideAddressesCallback& override_addresses_callback, - const Program& program, + const Program &program, const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, const OptionalTensors& output_tensors); + template constexpr auto decorate_host_operation(const Function& function) { return [function](const Operation& operation, Args&&... 
args) { @@ -116,7 +114,7 @@ constexpr auto decorate_device_operation(const Function& function) { }; } -template +template OutputTensors run_host_operation(const HostOperation& operation, const Tensors& input_tensors) { ZoneScopedN("TT_DNN_HOST_OP"); uint32_t op_id = assign_id(); @@ -130,12 +128,11 @@ OutputTensors run_host_operation(const HostOperation& operation, } template Tensors run_host_operation(const HostOperation& operation, const Tensors& input_tensors); -template OptionalTensors run_host_operation( - const HostOperation& operation, const Tensors& input_tensors); +template OptionalTensors run_host_operation(const HostOperation& operation, const Tensors& input_tensors); inline const auto USE_FAST_DISPATCH = std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr; -template +template OutputTensors run_device_operation( std::reference_wrapper queue, const DeviceOperation& operation, @@ -174,12 +171,10 @@ OutputTensors run_device_operation( } if (not cache_hit) { - program_ptr = std::make_shared>( - operation.create_program(input_tensors, optional_input_tensors, output_tensors)); + program_ptr = std::make_shared>(operation.create_program(input_tensors, optional_input_tensors, output_tensors)); program_cache.insert(program_hash, program_ptr.value()); } - auto& program_with_callbacks = - *(reinterpret_cast*>(program_ptr.value().get())); + auto& program_with_callbacks = *(reinterpret_cast*>(program_ptr.value().get())); TT_ASSERT(program_with_callbacks.supports_program_cache()); if (cache_hit) { @@ -188,11 +183,7 @@ OutputTensors run_device_operation( auto override_addresses_callback = program_with_callbacks.override_addresses_callback.value(); // Deprecated override_addresses( - override_addresses_callback, - program_with_callbacks.program, - input_tensors, - optional_input_tensors, - output_tensors); + override_addresses_callback, program_with_callbacks.program, input_tensors, optional_input_tensors, output_tensors); } if (program_with_callbacks.override_runtime_arguments_callback.has_value()) { @@ -231,20 +222,18 @@ OutputTensors run_device_operation( [&operation, &input_tensors, &optional_input_tensors, &output_tensors, queue](auto&& program) { auto device = detail::get_device(input_tensors, optional_input_tensors); using T = std::decay_t; - if constexpr ( - std::is_same_v> || std::is_same_v>) { + if constexpr (std::is_same_v> || std::is_same_v> ) { if (USE_FAST_DISPATCH) { - // Program will temporarily own the input buffers. This is required, since with Async command - // queues, the input tensor can preemptively be deallocted on device, unless program maintains - // explicit ownership. This invocation of the program will give up ownership once its enqueued. - for (const auto& input_tensor : input_tensors) { + // Program will temporarily own the input buffers. This is required, since with Async command queues, the input + // tensor can preemptively be deallocted on device, unless program maintains explicit ownership. + // This invocation of the program will give up ownership once its enqueued. 
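The ownership comment above captures a general lifetime rule for asynchronous dispatch: anything the enqueued program reads must stay alive until the dispatch path is done with it, even if the caller "deallocates" early. A minimal standalone sketch of that shared-ownership idea, using placeholder Buffer/Program types rather than the real tt-metal classes:

#include <memory>
#include <vector>

struct Buffer { /* placeholder for a device allocation handle */ };

struct Program {
    // Buffers captured here stay alive at least as long as the program object.
    std::vector<std::shared_ptr<Buffer>> retained_buffers;
};

// Give the program shared ownership of an input buffer so an asynchronous
// deallocation on the caller's side cannot free it before enqueue completes.
void assign_buffer_to_program(const std::shared_ptr<Buffer>& buffer, Program& program) {
    program.retained_buffers.push_back(buffer);
}

int main() {
    auto input = std::make_shared<Buffer>();
    Program program;
    assign_buffer_to_program(input, program);
    input.reset();  // caller releases its reference early...
    // ...but program.retained_buffers still holds the allocation until the
    // program itself goes away, i.e. after it has been enqueued and run.
    return 0;
}

In the hunk above that role is played by AssignGlobalBufferToProgram, and per the comment the program gives up this ownership once it has been enqueued.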
+ for (const auto& input_tensor: input_tensors) { if (input_tensor.storage_type() == StorageType::DEVICE) { AssignGlobalBufferToProgram(input_tensor.device_buffer(), program); } } for (auto& optional_input_tensor : optional_input_tensors) { - if (optional_input_tensor.has_value() and - optional_input_tensor.value().storage_type() == StorageType::DEVICE) { + if (optional_input_tensor.has_value() and optional_input_tensor.value().storage_type() == StorageType::DEVICE) { AssignGlobalBufferToProgram(optional_input_tensor.value().device_buffer(), program); } } @@ -256,20 +245,10 @@ OutputTensors run_device_operation( }, program); - TracyOpTTNNDevice( - op_id, - program_hash, - program_cache.is_enabled(), - device_id, - operation, - program, - input_tensors, - optional_input_tensors, - output_tensors); + TracyOpTTNNDevice(op_id, program_hash, program_cache.is_enabled(), device_id, operation, program, input_tensors, optional_input_tensors, output_tensors); return output_tensors; } - template Tensors run_device_operation( std::reference_wrapper queue, const DeviceOperation& operation, @@ -284,16 +263,17 @@ template OptionalTensors run_device_operation( const OptionalConstTensors& optional_input_tensors, const OptionalTensors& optional_output_tensors); + } // namespace detail -template +template OutputTensors run(const HostOperation& operation, const Tensors& input_tensors) { return detail::decorate_host_operation(detail::run_host_operation)(operation, input_tensors); } template Tensors run(const HostOperation& operation, const Tensors& input_tensors); template OptionalTensors run(const HostOperation& operation, const Tensors& input_tensors); -template +template OutputTensors run( const DeviceOperation& operation, const Tensors& input_tensors, @@ -303,16 +283,15 @@ OutputTensors run( auto device = detail::get_device(input_tensors, optional_input_tensors); #ifdef DEBUG operation.validate(input_tensors, optional_input_tensors, optional_output_tensors); - detail::validate_op_launch(device); #endif + detail::validate_op_launch(device); return detail::decorate_device_operation(detail::run_device_operation)( std::ref(device->command_queue(cq_id)), operation, input_tensors, optional_input_tensors, optional_output_tensors); -} - + } template Tensors run( const DeviceOperation& operation, const Tensors& input_tensors, @@ -327,7 +306,7 @@ template OptionalTensors run( const OptionalTensors& optional_output_tensors, uint8_t cq_id); -template +template OutputTensors run_without_autoformat( const DeviceOperation& operation, const Tensors& input_tensors, @@ -349,8 +328,7 @@ OutputTensors run_without_autoformat( optional_input_tensors_on_dev.reserve(optional_input_tensors.size()); for (auto& optional_input_tensor : optional_input_tensors) { if (optional_input_tensor.has_value() and optional_input_tensor.value().storage_type() != StorageType::DEVICE) { - optional_input_tensors_on_dev.push_back( - AutoFormat::move_tensor_to_device(optional_input_tensor.value(), device)); + optional_input_tensors_on_dev.push_back(AutoFormat::move_tensor_to_device(optional_input_tensor.value(), device)); } else { optional_input_tensors_on_dev.push_back(optional_input_tensor); } @@ -370,7 +348,7 @@ template OptionalTensors run_without_autoformat( const OptionalConstTensors& optional_input_tensors, uint8_t cq_id); -template +template OutputTensors run_without_autoformat( const DeviceOperation& operation, const Tensors& input_tensors, @@ -393,8 +371,7 @@ OutputTensors run_without_autoformat( 
optional_input_tensors_on_dev.reserve(optional_input_tensors.size()); for (auto& optional_input_tensor : optional_input_tensors) { if (optional_input_tensor.has_value() and optional_input_tensor.value().storage_type() != StorageType::DEVICE) { - optional_input_tensors_on_dev.push_back( - AutoFormat::move_tensor_to_device(optional_input_tensor.value(), device)); + optional_input_tensors_on_dev.push_back(AutoFormat::move_tensor_to_device(optional_input_tensor.value(), device)); } else { optional_input_tensors_on_dev.push_back(optional_input_tensor); } @@ -425,6 +402,9 @@ Tensors run_with_autoformat( const bool pad_c, uint8_t cq_id) { ZoneScoped; + if (detail::any_tensor_on_multi_device(input_tensors)) { + return run(operation, input_tensors, optional_input_tensors); + } Device* device = detail::get_device(input_tensors, optional_input_tensors); detail::validate_op_launch(device); auto output_shapes = operation.compute_output_shapes(input_tensors); @@ -435,8 +415,7 @@ Tensors run_with_autoformat( auto padded_input_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape(), pad_c); auto pad_input = not AutoFormat::check_input_tensor_format(input_tensor, padded_input_shape); if (pad_input) { - formatted_input_tensors.push_back( - AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, pad_value, Layout::TILE)); + formatted_input_tensors.push_back(AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, pad_value, Layout::TILE)); } else { formatted_input_tensors.push_back(input_tensor); } @@ -450,8 +429,7 @@ Tensors run_with_autoformat( auto padded_input_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape(), pad_c); auto pad_input = not AutoFormat::check_input_tensor_format(input_tensor, padded_input_shape); if (pad_input) { - formatted_optional_input_tensors.push_back( - AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, pad_value, Layout::TILE)); + formatted_optional_input_tensors.push_back(AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, pad_value, Layout::TILE)); } else { formatted_optional_input_tensors.push_back(input_tensor); } @@ -482,6 +460,9 @@ Tensors run_with_autoformat( const std::vector>& optional_input_formatting, uint8_t cq_id) { ZoneScoped; + if (detail::any_tensor_on_multi_device(input_tensors)) { + return run(operation, input_tensors, optional_input_tensors); + } Device* device = detail::get_device(input_tensors, optional_input_tensors); detail::validate_op_launch(device); auto output_shapes = operation.compute_output_shapes(input_tensors); @@ -492,12 +473,7 @@ Tensors run_with_autoformat( Tensors formatted_input_tensors; formatted_input_tensors.reserve(input_tensors.size()); for (uint32_t i = 0; i < input_tensors.size(); ++i) { - formatted_input_tensors.push_back(AutoFormat::format_input_tensor( - input_tensors[i], - device, - input_formatting[i].pad_shape, - input_formatting[i].pad_value, - input_formatting[i].target_layout)); + formatted_input_tensors.push_back(AutoFormat::format_input_tensor(input_tensors[i], device, input_formatting[i].pad_shape, input_formatting[i].pad_value, input_formatting[i].target_layout)); } OptionalConstTensors formatted_optional_input_tensors; @@ -507,12 +483,7 @@ Tensors run_with_autoformat( auto& input_tensor = optional_input_tensors[i].value(); TT_ASSERT(optional_input_formatting[i].has_value()); auto& input_formatting = optional_input_formatting[i].value(); - formatted_optional_input_tensors.push_back(AutoFormat::format_input_tensor( 
- input_tensor, - device, - input_formatting.pad_shape, - input_formatting.pad_value, - input_formatting.target_layout)); + formatted_optional_input_tensors.push_back(AutoFormat::format_input_tensor(input_tensor, device, input_formatting.pad_shape, input_formatting.pad_value, input_formatting.target_layout)); } else { formatted_optional_input_tensors.push_back(optional_input_tensors[i]); } @@ -527,8 +498,7 @@ Tensors run_with_autoformat( formatted_optional_input_tensors.clear(); for (auto i = 0; i < output_tensors.size(); ++i) { - output_tensors[i] = - AutoFormat::format_output_tensor(output_tensors[i], output_shapes[i], device, output_layouts[i]); + output_tensors[i] = AutoFormat::format_output_tensor(output_tensors[i], output_shapes[i], device, output_layouts[i]); } return output_tensors; @@ -539,7 +509,8 @@ void launch_with_autoformat( const Tensors input_tensors, Tensors& output_tensors, const OptionalConstTensors optional_input_tensors, - const OptionalTensors optional_output_tensors) { + const OptionalTensors optional_output_tensors +) { // Mark each output tensor as having dynamic storage (can be on host or device, depending // on autoformat behaviour). Multi device tensors do not support dynamic storage. for (auto& output_tensor : output_tensors) { @@ -554,33 +525,28 @@ void launch_op( Tensors& output_tensors, const OptionalConstTensors optional_input_tensors, const OptionalTensors optional_output_tensors, - bool enable_autoformat_device) { + bool enable_autoformat_device +) { // Send host side op compile and run to the worker queue // Assert to ensure that worker threads are specified. ZoneScopedN("LaunchOp"); auto& workers = output_tensors.at(0).workers; std::size_t workers_size = workers.size(); - if (not enable_autoformat_device and workers.empty() or not workers.at(0)->in_main_thread()) { - // Run in main thread or immediately in worker thread + if (not enable_autoformat_device and workers.empty()) { + // Run on the host output_tensors = op_func(input_tensors, optional_input_tensors, optional_output_tensors); return; } for (auto& output_tensor : output_tensors) { - TT_FATAL( - output_tensor.workers.size(), - "Worker threads must be specified for outputs populated by launch_op. This API can only be used for " - "creating output tensors on device."); - TT_FATAL( - output_tensor.workers == workers, - "Worker threads must be consistent across all outputs populated by launch_op."); + TT_FATAL(output_tensor.workers.size(), "Worker threads must be specified for outputs populated by launch_op. This API can only be used for creating output tensors on device."); + TT_FATAL(output_tensor.workers == workers, "Worker threads must be consistent across all outputs populated by launch_op."); } validate_worker_modes(workers); // Record ref counts for all tensors before pushing to worker queue. 
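launch_op brackets the asynchronous push with a ref-count snapshot: counts are recorded in the main thread before any tensor copies are captured for the worker, and restored after the push so the temporary captures do not permanently inflate the main-thread count. A minimal sketch of that record/update pattern, using a hypothetical RefCounted type rather than the real tensor_attributes interface:

#include <atomic>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for per-tensor ref counting; not the tt-metal API.
struct RefCounted {
    std::atomic<int> main_thread_ref_count{1};
    int  record_main_thread_ref_count() const { return main_thread_ref_count.load(); }
    void update_main_thread_ref_count(int recorded) { main_thread_ref_count.store(recorded); }
};

int main() {
    std::vector<RefCounted> inputs(2);

    // 1. Snapshot the main-thread ref counts before anything is copied.
    std::vector<int> recorded(inputs.size());
    for (std::size_t i = 0; i < inputs.size(); i++) recorded[i] = inputs[i].record_main_thread_ref_count();

    // 2. Stand-in for the by-value captures made when building the worker lambda,
    //    each of which bumps the count.
    for (auto& in : inputs) in.main_thread_ref_count++;
    std::function<void()> work = [] { /* run the op against per-device shards */ };
    work();  // in the real flow this is pushed to the worker queue instead

    // 3. Restore the recorded counts so the temporary captures do not leak references.
    for (std::size_t i = 0; i < inputs.size(); i++) inputs[i].update_main_thread_ref_count(recorded[i]);
    return 0;
}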
std::vector input_tensor_ref_count = std::vector(input_tensors.size()); std::vector optional_input_tensor_ref_count = std::vector(optional_input_tensors.size()); std::vector output_tensor_ref_count = std::vector(output_tensors.size()); - std::vector optional_output_tensor_ref_count = std::vector(optional_output_tensors.size()); - ; + std::vector optional_output_tensor_ref_count = std::vector(optional_output_tensors.size());; std::vector async_safe_input_tensors = std::vector(input_tensors.size()); std::vector> async_safe_optional_input_tensors = {}; @@ -594,11 +560,10 @@ void launch_op( } for (int i = 0; i < optional_input_tensors.size(); i++) { if (optional_input_tensors[i].has_value()) { - async_safe_optional_input_tensors.push_back( - copy_borrowed_tensor_in_async_mode(workers.at(0), optional_input_tensors[i].value())); - optional_input_tensor_ref_count[i] = - async_safe_optional_input_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); - } else { + async_safe_optional_input_tensors.push_back(copy_borrowed_tensor_in_async_mode(workers.at(0), optional_input_tensors[i].value())); + optional_input_tensor_ref_count[i] = async_safe_optional_input_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); + } + else { async_safe_optional_input_tensors.push_back(std::nullopt); optional_input_tensor_ref_count[i] = 0; } @@ -608,9 +573,9 @@ void launch_op( } for (int i = 0; i < optional_output_tensors.size(); i++) { if (optional_output_tensors[i].has_value()) { - optional_output_tensor_ref_count[i] = - optional_output_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); - } else { + optional_output_tensor_ref_count[i] = optional_output_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); + } + else { optional_output_tensor_ref_count[i] = 0; } } @@ -621,18 +586,14 @@ void launch_op( if (workers_size == 1) { // Single worker per tensor and. for (int i = 0; i < async_safe_input_tensors.size(); i++) { - if (async_safe_input_tensors.at(i).get_workers().size() and - async_safe_input_tensors.at(i).get_workers().at(0) != workers.at(0)) { - // This input has a worker assigned that doesn't match the worker of the output being created (its - // shared). + if (async_safe_input_tensors.at(i).get_workers().size() and async_safe_input_tensors.at(i).get_workers().at(0) != workers.at(0)) { + // This input has a worker assigned that doesn't match the worker of the output being created (its shared). 
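The branch here handles inputs whose worker differs from the worker producing the output: such inputs are marked as shared by bumping a sibling-worker counter, so their owning worker will not deallocate them mid-op, and the counter is decremented again in the post-processing step. A condensed sketch of that bookkeeping, with a hypothetical TensorAttrs stand-in for the real attributes object:

#include <atomic>
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical: a tensor that remembers which worker owns it and how many
// sibling workers are currently sharing it.
struct TensorAttrs {
    int owning_worker = 0;
    std::atomic<int> num_sibling_workers_sharing_tensor{0};
};

int main() {
    const int output_worker = 0;
    std::vector<TensorAttrs> inputs(3);
    inputs[1].owning_worker = 1;  // this input lives on a different worker

    // Mark cross-worker inputs as shared so the producing worker keeps them alive.
    std::unordered_set<std::size_t> cross_worker_idx;
    for (std::size_t i = 0; i < inputs.size(); i++) {
        if (inputs[i].owning_worker != output_worker) {
            inputs[i].num_sibling_workers_sharing_tensor++;
            cross_worker_idx.insert(i);
        }
    }

    // ... the op runs on the output's worker ...

    // Release the shared claim once the op has consumed the inputs.
    for (std::size_t i : cross_worker_idx) inputs[i].num_sibling_workers_sharing_tensor--;
    return 0;
}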
async_safe_input_tensors.at(i).tensor_attributes->num_sibling_workers_sharing_tensor++; cross_worker_input_tensor_idx.insert(i); } } for (int i = 0; i < async_safe_optional_input_tensors.size(); i++) { - if (async_safe_optional_input_tensors.at(i).has_value() and - async_safe_optional_input_tensors.at(i).value().get_workers().size() and - async_safe_optional_input_tensors.at(i).value().get_workers().at(0) != workers.at(0)) { + if (async_safe_optional_input_tensors.at(i).has_value() and async_safe_optional_input_tensors.at(i).value().get_workers().size() and async_safe_optional_input_tensors.at(i).value().get_workers().at(0) != workers.at(0)) { async_safe_optional_input_tensors.at(i).value().tensor_attributes->num_sibling_workers_sharing_tensor++; cross_worker_optional_input_tensor_idx.insert(i); } @@ -641,98 +602,89 @@ void launch_op( { ZoneScopedN("PushOpToWorkers"); - auto work_lambda = std::make_shared>( - [workers_size, - op_func, - optional_output_tensors, - async_safe_optional_input_tensors, - inputs = async_safe_input_tensors, - outputs = output_tensors, - shared_input_idx = cross_worker_input_tensor_idx, - shared_optional_input_idx = cross_worker_optional_input_tensor_idx](Device* target_device) mutable { - std::vector input_shards = std::vector(inputs.size(), Tensor()); - std::vector> optional_input_shards = {}; - std::vector> optional_output_shards = {}; - // Initialize all optional_outputs to std::nullopt - optional_output_shards.resize(optional_output_tensors.size()); - - { - ZoneScopedN("CreateShards"); - for (int i = 0; i < input_shards.size(); i++) { - input_shards[i] = get_shard_for_device(inputs[i], target_device); - } + auto work_lambda = std::make_shared>([workers_size, op_func, optional_output_tensors, async_safe_optional_input_tensors, inputs = async_safe_input_tensors, outputs = output_tensors, shared_input_idx = cross_worker_input_tensor_idx, shared_optional_input_idx = cross_worker_optional_input_tensor_idx] (Device* target_device) mutable { + std::vector input_shards = std::vector(inputs.size(), Tensor()); + std::vector> optional_input_shards = {}; + std::vector> optional_output_shards = {}; + // Initialize all optional_outputs to std::nullopt + optional_output_shards.resize(optional_output_tensors.size()); + + { + ZoneScopedN("CreateShards"); + for (int i = 0; i < input_shards.size(); i++) { + input_shards[i] = get_shard_for_device(inputs[i], target_device); + } - for (auto& input : async_safe_optional_input_tensors) { - if (input.has_value()) { - optional_input_shards.push_back(get_shard_for_device(input.value(), target_device)); - } else { - optional_input_shards.push_back(std::nullopt); - } + for (auto& input : async_safe_optional_input_tensors) { + if (input.has_value()) { + optional_input_shards.push_back(get_shard_for_device(input.value(), target_device)); + } + else { + optional_input_shards.push_back(std::nullopt); } + } - for (std::size_t optional_output_idx = 0; optional_output_idx < optional_output_tensors.size(); - optional_output_idx++) { - if (optional_output_tensors[optional_output_idx].has_value()) { - optional_output_shards[optional_output_idx] = get_shard_for_device( - optional_output_tensors[optional_output_idx].value(), target_device); - } + for (std::size_t optional_output_idx = 0; optional_output_idx < optional_output_tensors.size(); optional_output_idx++) { + if (optional_output_tensors[optional_output_idx].has_value()) { + optional_output_shards[optional_output_idx] = 
get_shard_for_device(optional_output_tensors[optional_output_idx].value(), target_device); } } + } - auto local_tensors = op_func(input_shards, optional_input_shards, optional_output_shards); + auto local_tensors = op_func(input_shards, optional_input_shards, optional_output_shards); - { - ZoneScopedN("OpPostProcess"); - // Release shared ownership of tensors belonging to other workers. - // If the workers for this tensor are stalled to deallocate - for (auto& shared_input : shared_input_idx) { - inputs.at(shared_input).tensor_attributes->num_sibling_workers_sharing_tensor--; - } + { + ZoneScopedN("OpPostProcess"); + // Release shared ownership of tensors belonging to other workers. + // If the workers for this tensor are stalled to deallocate + for (auto& shared_input : shared_input_idx) { + inputs.at(shared_input).tensor_attributes->num_sibling_workers_sharing_tensor--; + } - for (auto& shared_optional_input : shared_optional_input_idx) { - async_safe_optional_input_tensors.at(shared_optional_input) - .value() - .tensor_attributes->num_sibling_workers_sharing_tensor--; - } + for (auto& shared_optional_input : shared_optional_input_idx) { + async_safe_optional_input_tensors.at(shared_optional_input).value().tensor_attributes->num_sibling_workers_sharing_tensor--; + } - for (int i = 0; i < local_tensors.size(); i++) { - if (std::holds_alternative(local_tensors.at(i).tensor_attributes->storage)) { - TT_ASSERT( - outputs.at(i).tensor_attributes->dynamic_storage, - "launch_with_autoformat must be used if output tensor for op can be placed on host."); - // Make this a host side tensor - Set storage = Owned and clear workers - outputs.at(i).tensor_attributes->storage = OwnedStorage(); - outputs.at(i).workers = {}; - } else { - outputs.at(i).tensor_attributes->dynamic_storage = false; - } - insert_buffer_and_shape_for_device(target_device, local_tensors.at(i), outputs.at(i)); - int num_workers_completed = (outputs.at(i).tensor_attributes->num_workers_completed)++; - if (not num_workers_completed) { - outputs.at(i).tensor_attributes->shape = local_tensors.at(i).tensor_attributes->shape; - outputs.at(i).tensor_attributes->dtype = local_tensors.at(i).tensor_attributes->dtype; - outputs.at(i).tensor_attributes->layout = local_tensors.at(i).tensor_attributes->layout; - outputs.at(i).tensor_attributes->metadata_populated = true; - } + for (int i = 0; i < local_tensors.size(); i++) { + if (local_tensors.at(i).storage_type() == StorageType::OWNED) { + TT_ASSERT(outputs.at(i).tensor_attributes->dynamic_storage, "launch_with_autoformat must be used if output tensor for op can be placed on host."); + // Make this a host side tensor - Set storage = Owned and clear workers + outputs.at(i).tensor_attributes->storage = OwnedStorage(); + outputs.at(i).workers = {}; + } + else { + outputs.at(i).tensor_attributes->dynamic_storage = false; + } + insert_buffer_and_shape_for_device(target_device, local_tensors.at(i), outputs.at(i)); + if (not target_device->id() or workers_size == 1) { + outputs.at(i).set_shape(local_tensors.at(i).get_shape()); + outputs.at(i).set_dtype(local_tensors.at(i).get_dtype()); + outputs.at(i).set_layout(local_tensors.at(i).get_layout()); + } + if (workers_size == 1) { + outputs.at(i).set_populated(); + } + else { + outputs.at(i).set_populated(target_device); } } - }); + } + }); for (auto target_device : workers) { - target_device->push_work(std::make_shared>( - [target_device, work_lambda]() mutable { (*work_lambda)(target_device); })); + 
target_device->push_work(std::make_shared>([target_device, work_lambda] () mutable { + (*work_lambda)(target_device); + })); } } // Update ref counts of all tensors after push was performed (done only in main thread). for (int i = 0; i < async_safe_input_tensors.size(); i++) { - async_safe_input_tensors[i].tensor_attributes->update_main_thread_ref_count( - workers.at(0), input_tensor_ref_count[i]); + async_safe_input_tensors[i].tensor_attributes->update_main_thread_ref_count(workers.at(0), input_tensor_ref_count[i]); } for (int i = 0; i < async_safe_optional_input_tensors.size(); i++) { if (async_safe_optional_input_tensors[i].has_value()) { - async_safe_optional_input_tensors[i].value().tensor_attributes->update_main_thread_ref_count( - workers.at(0), optional_input_tensor_ref_count[i]); + async_safe_optional_input_tensors[i].value().tensor_attributes->update_main_thread_ref_count(workers.at(0), optional_input_tensor_ref_count[i]); } } for (int i = 0; i < output_tensors.size(); i++) { @@ -740,53 +692,37 @@ void launch_op( } for (int i = 0; i < optional_output_tensors.size(); i++) { if (optional_output_tensors[i].has_value()) { - optional_output_tensors[i].value().tensor_attributes->update_main_thread_ref_count( - workers.at(0), optional_output_tensor_ref_count[i]); + optional_output_tensors[i].value().tensor_attributes->update_main_thread_ref_count(workers.at(0), optional_output_tensor_ref_count[i]); } } } -void validate_workers_and_storage( - const std::vector& inputs, - const std::vector>& optional_inputs, - const std::vector& workers) { +void validate_workers_and_storage(const std::vector& inputs, const std::vector>& optional_inputs, const std::vector& workers) { bool single_device_storage = false; bool multi_device_storage = false; - // Verify that storage types are consistent - cannot mix single and multi-device storage. For multi-device tensors, - // ensure that workers are specified, since they cannot be inferred. This means that - // launch_op/launch_with_autoformat cannot be called with MultiDeviceHostStorage. - for (const auto& input : inputs) { - if (std::holds_alternative(input.tensor_attributes->storage) or - std::holds_alternative(input.tensor_attributes->storage)) { + // Verify that storage types are consistent - cannot mix single and multi-device storage. For multi-device tensors, ensure that workers are specified, since they cannot be inferred. + // This means that launch_op/launch_with_autoformat cannot be called with MultiDeviceHostStorage. 
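validate_workers_and_storage classifies each input by which alternative its storage variant holds and rejects any mix of single-device and multi-device tensors. A self-contained sketch of that check, with simplified empty storage types standing in for the real ones (the real Storage variant has more alternatives, including borrowed and host storage):

#include <stdexcept>
#include <variant>
#include <vector>

struct OwnedStorage {}; struct DeviceStorage {};
struct MultiDeviceStorage {}; struct MultiDeviceHostStorage {};
using Storage = std::variant<OwnedStorage, DeviceStorage, MultiDeviceStorage, MultiDeviceHostStorage>;

void validate_storage_mix(const std::vector<Storage>& inputs) {
    bool single_device = false, multi_device = false;
    for (const auto& s : inputs) {
        single_device |= std::holds_alternative<OwnedStorage>(s) || std::holds_alternative<DeviceStorage>(s);
        multi_device  |= std::holds_alternative<MultiDeviceStorage>(s) || std::holds_alternative<MultiDeviceHostStorage>(s);
    }
    if (single_device && multi_device) {
        throw std::runtime_error("Cannot mix single and multi-device tensors in one launch");
    }
}

int main() {
    validate_storage_mix({DeviceStorage{}, OwnedStorage{}});                // accepted
    // validate_storage_mix({DeviceStorage{}, MultiDeviceStorage{}});      // would throw
    return 0;
}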
+ for (const auto& input: inputs) { + if (std::holds_alternative(input.tensor_attributes->storage) or std::holds_alternative(input.tensor_attributes->storage)) { single_device_storage |= true; - } else if ( - std::holds_alternative(input.tensor_attributes->storage) or - std::holds_alternative(input.tensor_attributes->storage)) { + } else if (std::holds_alternative(input.tensor_attributes->storage) or std::holds_alternative(input.tensor_attributes->storage)) { multi_device_storage |= true; } } for (auto& input : optional_inputs) { if (input.has_value()) { - if (std::holds_alternative(input.value().tensor_attributes->storage) or - std::holds_alternative(input.value().tensor_attributes->storage)) { + if (std::holds_alternative(input.value().tensor_attributes->storage) or std::holds_alternative(input.value().tensor_attributes->storage)) { single_device_storage |= true; - } else if ( - std::holds_alternative(input.value().tensor_attributes->storage) or - std::holds_alternative(input.value().tensor_attributes->storage)) { + } else if (std::holds_alternative(input.value().tensor_attributes->storage) or std::holds_alternative(input.value().tensor_attributes->storage)) { multi_device_storage |= true; } } } - TT_FATAL( - not(single_device_storage and multi_device_storage), - "Cannot mix single and multi-device tensors when calling launch op!"); + TT_FATAL(not (single_device_storage and multi_device_storage), "Cannot mix single and multi-device tensors when calling launch op!"); if (multi_device_storage) { - TT_FATAL( - workers.size(), - "Workers must be specified when calling launch_op with with multi-device tensors. Workers cannot be " - "inferred in this case."); + TT_FATAL(workers.size(), "Workers must be specified when calling launch_op with with multi-device tensors. Workers cannot be inferred in this case."); } } @@ -824,13 +760,10 @@ std::vector get_workers_for_op_output( // Workers not specified - inputs are on host and not multi-device. // Use the default device from autoformat. if (not workers_for_op.size()) { - TT_FATAL( - AutoFormat::GetDefaultDevice(), - "Default device must be specified using AutoFormat::SetDefaultDevice, if workers are not specified for " - "inputs to op."); + TT_FATAL(AutoFormat::GetDefaultDevice(), "Default device must be specified using AutoFormat::SetDefaultDevice, if workers are not specified for inputs to op."); workers_for_op = {AutoFormat::GetDefaultDevice()}; } } return workers_for_op; } -} // namespace tt::tt_metal::operation +} diff --git a/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp b/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp index c46675bcc7f3..d21e511e99b8 100644 --- a/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp +++ b/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp @@ -162,11 +162,11 @@ const operation::Hash Softmax::compute_program_hash( const std::vector &input_tensors, const std::vector>& optional_input_tensors) const { return operation::hash_operation( - std::get(input_tensors.at(0).storage()).memory_config(), - input_tensors.at(0).dtype(), - optional_input_tensors.at(0).has_value() ? std::optional{std::get(optional_input_tensors.at(0).value().storage()).memory_config()} + input_tensors.at(0).memory_config(), + input_tensors.at(0).get_dtype(), + optional_input_tensors.at(0).has_value() ? std::optional{optional_input_tensors.at(0).value().memory_config()} : std::nullopt, - optional_input_tensors.at(0).has_value() ? std::optional{optional_input_tensors.at(0).value().dtype()} + optional_input_tensors.at(0).has_value() ? 
std::optional{optional_input_tensors.at(0).value().get_dtype()} : std::nullopt, this->output_mem_config); } diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp index 0af4c11bf4b2..da1fa273b773 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp @@ -292,10 +292,10 @@ const operation::Hash AttnMatmul::compute_program_hash(const std::vector this->transpose_hw, this->output_mem_config, this->output_dtype, - std::get(input_tensors.at(0).storage()).memory_config(), - input_tensors.at(0).dtype(), - std::get(input_tensors.at(1).storage()).memory_config(), - input_tensors.at(1).dtype()); + input_tensors.at(0).memory_config(), + input_tensors.at(0).get_dtype(), + input_tensors.at(1).memory_config(), + input_tensors.at(1).get_dtype()); } void GroupAttnMatmul::validate(const std::vector& input_tensors) const { @@ -502,14 +502,14 @@ const operation::Hash GroupAttnMatmul::compute_program_hash(const std::vectoroutput_mem_config.buffer_type, this->output_dtype, this->row_major, - std::get(input_tensor_a.storage()).memory_config().memory_layout, - std::get(input_tensor_a.storage()).memory_config().buffer_type, - input_tensor_a.dtype(), - std::get(input_tensor_b.storage()).buffer->device()->id(), - std::get(input_tensor_b.storage()).memory_config().memory_layout, - std::get(input_tensor_b.storage()).memory_config().buffer_type, - input_tensor_b.dtype(), - std::get(input_tensor_b.storage()).buffer->device()->id()); + input_tensor_a.memory_config().memory_layout, + input_tensor_a.memory_config().buffer_type, + input_tensor_a.get_dtype(), + input_tensor_a.device()->id(), + input_tensor_b.memory_config().memory_layout, + input_tensor_b.memory_config().buffer_type, + input_tensor_b.get_dtype(), + input_tensor_b.device()->id()); } // SSM eltwise mul diff --git a/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp b/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp index 1d3a6be8798a..2a06d74f1f0a 100644 --- a/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp +++ b/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp @@ -156,9 +156,9 @@ tt::stl::reflection::Attributes Transpose::attributes() const { const operation::Hash Transpose::compute_program_hash( const std::vector &input_tensors) const { auto input_tensor = input_tensors.at(0); - auto input_mem_config = std::get(input_tensor.storage()).memory_config(); + auto input_mem_config = input_tensor.memory_config(); auto output_mem_config = this->output_mem_config; - auto dtype = input_tensor.dtype(); + auto dtype = input_tensor.get_dtype(); return operation::hash_operation( input_mem_config, output_mem_config, dtype, this->dim, get_parallelization_strategy(input_tensors)); } diff --git a/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp b/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp index b8f437d21387..b2482bffa2a6 100644 --- a/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp @@ -147,19 +147,19 @@ tt::stl::reflection::Attributes Unpad::attributes() const { const operation::Hash Unpad::compute_program_hash(const std::vector &input_tensors) const { auto input_tensor = input_tensors.at(0); - auto input_mem_config = std::get(input_tensor.storage()).memory_config(); + auto input_mem_config = input_tensor.memory_config(); auto output_mem_config = this->output_mem_config; - auto dtype = input_tensor.dtype(); - auto num_dims = 
input_tensor.shape().rank(); + auto dtype = input_tensor.get_dtype(); + auto num_dims = input_tensor.get_legacy_shape().rank(); std::string rm_width = "TILE"; if (input_tensor.get_layout() == Layout::ROW_MAJOR) { - rm_width = fmt::format("{}", input_tensor.legacy_shape()[3]); + rm_width = fmt::format("{}", input_tensor.get_legacy_shape()[3]); } auto str = operation::hash_operation( num_dims, - input_tensor.layout(), + input_tensor.get_layout(), input_mem_config.memory_layout, input_mem_config.buffer_type, output_mem_config.memory_layout, diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt index 7345da4c3360..235f4f7b0921 100644 --- a/tt_metal/CMakeLists.txt +++ b/tt_metal/CMakeLists.txt @@ -18,7 +18,7 @@ set(TT_METAL_OBJECTS add_library(tt_metal ${TT_METAL_OBJECTS}) if(BUILD_SHARED_LIBS) - target_link_libraries(tt_metal PUBLIC device metal_common_libs) + target_link_libraries(tt_metal PUBLIC device) add_dependencies(tt_metal umd_device) else() target_link_libraries(tt_metal PUBLIC ${UMD_STATIC_LIB} metal_common_libs) diff --git a/tt_metal/detail/tt_metal.hpp b/tt_metal/detail/tt_metal.hpp index bcc80005d875..507a58a3aa29 100644 --- a/tt_metal/detail/tt_metal.hpp +++ b/tt_metal/detail/tt_metal.hpp @@ -493,17 +493,5 @@ namespace tt::tt_metal{ specified_core_spec ); } - - inline void SynchronizeWorkerThreads(const std::vector& workers) { - // Push empty work to threads and ensure its been picked up - static auto empty_work = std::make_shared>([](){}); - for (auto target_device : workers) { - target_device->work_executor.push_work(empty_work); - } - // Block until work has been picked up, to flush the queue - for (auto target_device : workers) { - while(not target_device->work_executor.worker_queue.empty()); - } - } } } diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 6e9892c130c4..4d36a99e41d3 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -69,8 +69,8 @@ bool ActiveDevices::is_device_active(chip_id_t id) { } Device::Device( - chip_id_t device_id, const uint8_t num_hw_cqs, size_t l1_small_size, const std::vector &l1_bank_remap, bool minimal, uint32_t worker_core) : - id_(device_id), num_hw_cqs_(num_hw_cqs), worker_thread_core(worker_core), work_executor(worker_core, device_id) { + chip_id_t device_id, const uint8_t num_hw_cqs, size_t l1_small_size, const std::vector &l1_bank_remap, bool minimal) : + id_(device_id), num_hw_cqs_(num_hw_cqs), work_executor(device_id) { ZoneScoped; TT_ASSERT(num_hw_cqs > 0 and num_hw_cqs < 3, "num_hw_cqs can be between 1 and 2"); this->build_key_ = tt::Cluster::instance().get_harvesting_mask(device_id); diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index 12df80a6bee1..ade5235ae9f3 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -77,8 +77,7 @@ class Device { const uint8_t num_hw_cqs, std::size_t l1_small_size, const std::vector &l1_bank_remap = {}, - bool minimal = false, - uint32_t worker_core = 0); + bool minimal = false); ~Device(); @@ -278,7 +277,6 @@ class Device { // Work Executor for this device - can asynchronously process host side work for // all tasks scheduled on this device WorkExecutor work_executor; - uint32_t worker_thread_core; std::unique_ptr sysmem_manager_; uint8_t num_hw_cqs_; diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 8b5ca124ab49..e0325cdddf30 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ 
b/tt_metal/impl/dispatch/command_queue.cpp @@ -1240,7 +1240,7 @@ HWCommandQueue::HWCommandQueue(Device* device, uint32_t id, NOC noc_index) : std::thread completion_queue_thread = std::thread(&HWCommandQueue::read_completion_queue, this); this->completion_queue_thread = std::move(completion_queue_thread); // Set the affinity of the completion queue reader. - set_device_thread_affinity(this->completion_queue_thread, device->worker_thread_core); + set_device_thread_affinity(this->completion_queue_thread, device->id()); this->expected_num_workers_completed = 0; } @@ -1934,29 +1934,24 @@ void HWCommandQueue::read_completion_queue() { }); } if (this->num_entries_in_completion_q > this->num_completed_completion_q_reads) { - ZoneScopedN("CompletionQueueReader"); uint32_t num_events_to_read = this->num_entries_in_completion_q - this->num_completed_completion_q_reads; for (uint32_t i = 0; i < num_events_to_read; i++) { - ZoneScopedN("CompletionQueuePopulated"); - std::variant read_descriptor = *(this->issued_completion_q_reads.pop()); - { - ZoneScopedN("CompletionQueueWait"); - this->manager.completion_queue_wait_front(this->id, this->exit_condition); // CQ DISPATCHER IS NOT HANDSHAKING WITH HOST RN - } + std::variant read_descriptor = + *(this->issued_completion_q_reads.pop()); + + this->manager.completion_queue_wait_front( + this->id, this->exit_condition); // CQ DISPATCHER IS NOT HANDSHAKING WITH HOST RN + if (this->exit_condition) { // Early exit return; } std::visit( - [&](auto&& read_descriptor) - { + [&](auto&& read_descriptor) { using T = std::decay_t; if constexpr (std::is_same_v) { - ZoneScopedN("CompletionQueueReadData"); this->copy_into_user_space(read_descriptor, mmio_device_id, channel); - } - else if constexpr (std::is_same_v) { - ZoneScopedN("CompletionQueueReadEvent"); + } else if constexpr (std::is_same_v) { uint32_t read_ptr = this->manager.get_completion_queue_read_ptr(this->id); thread_local static std::vector dispatch_cmd_and_event( (sizeof(CQDispatchCmd) + dispatch_constants::EVENT_PADDED_SIZE) / sizeof(uint32_t)); diff --git a/tt_metal/impl/dispatch/work_executor.hpp b/tt_metal/impl/dispatch/work_executor.hpp index a164f3a8795d..323f5e7f7e29 100644 --- a/tt_metal/impl/dispatch/work_executor.hpp +++ b/tt_metal/impl/dispatch/work_executor.hpp @@ -44,11 +44,12 @@ enum class WorkerState { IDLE = 2, }; -inline void set_device_thread_affinity(std::thread& thread_, int cpu_core_for_worker) { +inline void set_device_thread_affinity(std::thread& thread_, int managed_device_id) { // Bind a device worker/reader thread to a CPU core, determined using round-robin. 
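The replacement shown in this hunk derives the worker thread's CPU core from the managed device id, wrapping around the number of online cores. A compilable Linux/glibc-only condensation of that round-robin binding; the helper name is a placeholder, and the non-fatal handling mirrors the log_warning in the hunk:

#include <pthread.h>
#include <sched.h>
#include <unistd.h>
#include <cstdio>
#include <thread>

void bind_thread_round_robin(std::thread& t, int device_id) {
    static const int num_online_cores = static_cast<int>(sysconf(_SC_NPROCESSORS_ONLN));
    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    CPU_SET(device_id % num_online_cores, &cpuset);  // device N -> core N mod online cores
    if (pthread_setaffinity_np(t.native_handle(), sizeof(cpu_set_t), &cpuset) != 0) {
        std::puts("affinity binding failed; thread stays unpinned");  // warn and continue
    }
}

int main() {
    std::thread worker([] { /* worker loop */ });
    bind_thread_round_robin(worker, /*device_id=*/3);
    worker.join();
    return 0;
}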
+ static int num_online_cores = sysconf(_SC_NPROCESSORS_ONLN); cpu_set_t cpuset; CPU_ZERO(&cpuset); - CPU_SET(cpu_core_for_worker, &cpuset); + CPU_SET(managed_device_id % num_online_cores, &cpuset); int rc = pthread_setaffinity_np(thread_.native_handle(), sizeof(cpu_set_t), &cpuset); if (rc) { log_warning( @@ -79,7 +80,7 @@ class WorkExecutor { public: LockFreeQueue> worker_queue; - WorkExecutor(int cpu_core, int device_id) : cpu_core_for_worker(cpu_core), managed_device_id(device_id) { + WorkExecutor(int device_id) : managed_device_id(device_id) { set_process_priority(0); if (this->work_executor_mode == WorkExecutorMode::ASYNCHRONOUS) { this->set_worker_queue_mode(this->worker_queue_mode); @@ -88,16 +89,14 @@ class WorkExecutor { } WorkExecutor(WorkExecutor&& other) { - worker_state = std::move(other.worker_state); - cpu_core_for_worker = std::move(other.managed_device_id); - managed_device_id = std::move(other.managed_device_id); + worker_state = other.worker_state; + managed_device_id = other.managed_device_id; } WorkExecutor& operator=(WorkExecutor &&other) { if (this != &other) { worker_state = std::move(other.worker_state); managed_device_id = std::move(other.managed_device_id); - cpu_core_for_worker = std::move(other.cpu_core_for_worker); } return *this; } @@ -219,7 +218,6 @@ class WorkExecutor { private: std::thread worker_thread; WorkerState worker_state = WorkerState::IDLE; - int cpu_core_for_worker = 0; int managed_device_id = 0; std::condition_variable cv; std::mutex cv_mutex; @@ -230,7 +228,7 @@ class WorkExecutor { this->worker_thread = std::thread(&WorkExecutor::run_worker, this); this->worker_queue.worker_thread_id = std::hash{}(this->worker_thread.get_id()); // Bind a worker tied to a device to a specific CPU core in round robin fashion. Thread affinity == Better Perf. - set_device_thread_affinity(this->worker_thread, this->cpu_core_for_worker); + set_device_thread_affinity(this->worker_thread, this->managed_device_id); } inline void stop_worker() { diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 4ce64b5b07a3..2038c3b4baea 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -4,7 +4,6 @@ #include "tt_metal/detail/tt_metal.hpp" -#include #include #include #include @@ -172,78 +171,6 @@ std::vector devices; } // namespace device_pool -namespace device_cpu_allocator { -std::unordered_map> get_cpu_cores_per_numa_node(std::unordered_set &free_cores) { - std::unordered_map> cpu_cores_per_numa_node = {}; - if (numa_available() != -1) { - // Host has NUMA enabled. Group CPU IDs by the NUMA nodes they belong to. - for (int cpu = 0; cpu < numa_num_configured_cpus(); ++cpu) { - int node = numa_node_of_cpu(cpu); - if (cpu_cores_per_numa_node.find(node) == cpu_cores_per_numa_node.end()) { - cpu_cores_per_numa_node.insert({node, {}}); - } - free_cores.insert(cpu); - cpu_cores_per_numa_node.at(node).push_back(cpu); - } - } else { - // Host does not have NUMA. Place all CPU Ids under a single node (0). - log_warning(tt::LogMetal, "Host does not use NUMA. 
May see reduced performance."); - for (int cpu = 0; cpu < sysconf(_SC_NPROCESSORS_ONLN); ++cpu) { - free_cores.insert(cpu); - } - } - return cpu_cores_per_numa_node; -} - -int get_cpu_core_for_device_worker_thread( - int mmio_controlled_device_id, - const std::unordered_map> &cpu_cores_per_numa_node, - std::unordered_set &free_cores) { - int core_assigned_to_device = 0; - if (numa_available() != -1) { - // Get NUMA node that the current device is mapped to through UMD - int numa_node_for_device = tt::Cluster::instance().get_numa_node_for_device(mmio_controlled_device_id); - if (cpu_cores_per_numa_node.find(numa_node_for_device) != cpu_cores_per_numa_node.end()) { - // NUMA node reported by UMD exists on host. Choose a core on this numa-node using round robin policy - int num_cores_in_numa_node = cpu_cores_per_numa_node.at(numa_node_for_device).size(); - core_assigned_to_device = - cpu_cores_per_numa_node.at(numa_node_for_device).at(mmio_controlled_device_id % num_cores_in_numa_node); - } else { - // NUMA node reported by UMD does not exist on host. Use round-robin binding policy for this worker thread. - log_warning( - tt::LogMetal, - "NUMA node {} for device {} does not exist on host.", - numa_node_for_device, - mmio_controlled_device_id); - core_assigned_to_device = mmio_controlled_device_id % sysconf(_SC_NPROCESSORS_ONLN); - } - } else { - // System does not use NUMA. Use-round robin binding strategy. - core_assigned_to_device = mmio_controlled_device_id % sysconf(_SC_NPROCESSORS_ONLN); - } - free_cores.erase(core_assigned_to_device); - return core_assigned_to_device; -} - -void bind_current_thread_to_free_cores(const std::unordered_set &free_cores) { - cpu_set_t cpuset; - pthread_t current_thread = pthread_self(); - CPU_ZERO(&cpuset); - - for (const auto &free_core : free_cores) { - CPU_SET(free_core, &cpuset); - } - int rc = pthread_setaffinity_np(current_thread, sizeof(cpu_set_t), &cpuset); - if (rc) { - log_warning( - tt::LogMetal, - "Unable to bind main thread to free CPU cores. May see performance degradation. Error Code: {}", - rc); - } -} - -} // namespace device_cpu_allocator - namespace detail { std::map CreateDevices( @@ -253,32 +180,20 @@ std::map CreateDevices( const std::vector &l1_bank_remap) { ZoneScoped; std::map active_devices; // TODO: pass this to CloseDevices - // Construct NUMA Node to CPU core map - std::unordered_set free_cores = {}; - auto cpu_cores_per_numa_node = device_cpu_allocator::get_cpu_cores_per_numa_node(free_cores); - for (const auto &device_id : device_ids) { const auto &mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id); if (active_devices.find(mmio_device_id) == active_devices.end()) { for (const auto &mmio_controlled_device_id : tt::Cluster::instance().get_devices_controlled_by_mmio_device(mmio_device_id)) { - int core_assigned_to_device = device_cpu_allocator::get_cpu_core_for_device_worker_thread( - mmio_controlled_device_id, cpu_cores_per_numa_node, free_cores); - Device *dev = new Device( - mmio_controlled_device_id, - num_hw_cqs, - l1_small_size, - l1_bank_remap, - false, - core_assigned_to_device); + // if (mmio_controlled_device_id != mmio_device_id) { + // continue; + // } + Device *dev = new Device(mmio_controlled_device_id, num_hw_cqs, l1_small_size, l1_bank_remap); active_devices.insert({mmio_controlled_device_id, dev}); detail::InitDeviceProfiler(dev); } } } - // Bind main thread to cores not being used by workers. 
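For reference, the device_cpu_allocator helpers being removed here grouped host CPUs by NUMA node via libnuma and picked each device worker's core round-robin within the device's node, falling back to plain round-robin when NUMA is unavailable or the node is unknown. A minimal sketch of that selection logic; pick_core_for_device is a hypothetical condensation, and the device's NUMA node (obtained from the cluster layer in the real code) is passed in as a plain parameter:

// Requires libnuma (link with -lnuma).
#include <numa.h>
#include <unistd.h>
#include <cstddef>
#include <cstdio>
#include <unordered_map>
#include <vector>

int pick_core_for_device(int device_id, int numa_node_for_device) {
    if (numa_available() == -1) {
        return device_id % static_cast<int>(sysconf(_SC_NPROCESSORS_ONLN));  // host has no NUMA
    }
    // Group CPU ids by the NUMA node they belong to.
    std::unordered_map<int, std::vector<int>> cpus_per_node;
    for (int cpu = 0; cpu < numa_num_configured_cpus(); ++cpu) {
        cpus_per_node[numa_node_of_cpu(cpu)].push_back(cpu);
    }
    auto it = cpus_per_node.find(numa_node_for_device);
    if (it == cpus_per_node.end() || it->second.empty()) {
        return device_id % static_cast<int>(sysconf(_SC_NPROCESSORS_ONLN));  // unknown node
    }
    return it->second[static_cast<std::size_t>(device_id) % it->second.size()];  // round-robin within the node
}

int main() {
    std::printf("device 0 -> core %d\n", pick_core_for_device(0, /*numa_node=*/0));
    return 0;
}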
- device_cpu_allocator::bind_current_thread_to_free_cores(free_cores); - // TODO: need to only enable routing for used mmio chips tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true); return active_devices; @@ -751,10 +666,12 @@ void CompileProgram(Device *device, Program &program) { } void AllocateBuffer(Buffer *buffer, bool bottom_up) { + detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); EnqueueAllocateBuffer(buffer->device()->command_queue(), buffer, bottom_up, false); } void DeallocateBuffer(Buffer *buffer) { + detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); EnqueueDeallocateBuffer( buffer->device()->command_queue(), *(buffer->device()->allocator_), @@ -764,6 +681,7 @@ void DeallocateBuffer(Buffer *buffer) { } void GetBufferAddress(const Buffer *buffer, uint32_t *address_on_host) { + detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); EnqueueGetBufferAddr(buffer->device()->command_queue(), address_on_host, buffer, false); } @@ -802,14 +720,7 @@ Device *CreateDevice( const size_t l1_small_size, const std::vector &l1_bank_remap) { ZoneScoped; - // Construct NUMA Node to CPU core map - std::unordered_set free_cores = {}; - auto cpu_cores_per_numa_node = device_cpu_allocator::get_cpu_cores_per_numa_node(free_cores); - int core_assigned_to_device = - device_cpu_allocator::get_cpu_core_for_device_worker_thread(device_id, cpu_cores_per_numa_node, free_cores); - Device *dev = new Device(device_id, num_hw_cqs, l1_small_size, l1_bank_remap, false, core_assigned_to_device); - // Bind main thread to cores not being used by workers. - device_cpu_allocator::bind_current_thread_to_free_cores(free_cores); + Device *dev = new Device(device_id, num_hw_cqs, l1_small_size, l1_bank_remap); tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true); detail::InitDeviceProfiler(dev); return dev; diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp index 243b6ef4808a..5569bd65ab4e 100644 --- a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp +++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp @@ -296,10 +296,10 @@ const operation::Hash Binary::compute_program_hash(const std::vector& in typeid(*this).hash_code(), this->program_config, program_type, - input_tensor_a.dtype(), - std::get(input_tensor_a.storage()).memory_config(), - input_tensor_b.dtype(), - std::get(input_tensor_b.storage()).memory_config()); + input_tensor_a.get_dtype(), + input_tensor_a.memory_config(), + input_tensor_b.get_dtype(), + input_tensor_b.memory_config()); return hash; } From ef16db472dd78378603e2b7049026cf2415e89e1 Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Tue, 4 Jun 2024 18:15:28 +0000 Subject: [PATCH 19/53] #5389: disabled failing moreh tests --- .../python_api_testing/unit_testing/misc/test_moreh_getitem.py | 1 + .../python_api_testing/unit_testing/misc/test_moreh_nll_loss.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py index 989c0430d544..73e567134b7f 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py @@ -20,6 +20,7 @@ def to_output_4d_shape(shape, index_dims, index_size): return output_4d_shape +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") 
@pytest.mark.parametrize( "shape_index_dim", ( diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py index 7bd8b21160e7..c278e3dfcb85 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py @@ -207,6 +207,7 @@ def test_moreh_nll_loss_callback(shape, reduction, none_weight, device, use_prog assert passing +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape", [ From d3e3dc21fcfa613239c3e1b78d89e1c1f1f51174 Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Tue, 4 Jun 2024 18:48:06 +0000 Subject: [PATCH 20/53] #5389: disabled failing moreh tests --- .../python_api_testing/unit_testing/misc/test_moreh_getitem.py | 1 + .../python_api_testing/unit_testing/misc/test_moreh_nll_loss.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py index 73e567134b7f..426e379194c9 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py @@ -80,6 +80,7 @@ def test_getitem_RAW_MJOR_one_index(shape_index_dim, dtype, index_size, device): assert passing +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dims", ( diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py index c278e3dfcb85..af6d27c8e717 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py @@ -291,6 +291,7 @@ def test_moreh_nll_loss_backward( assert passing +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape", [ From 25213e1821fc504e7f599e4e580220d3bb71b2eb Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Tue, 4 Jun 2024 18:48:06 +0000 Subject: [PATCH 21/53] #5389: disabled failing moreh tests --- .../unit_testing/misc/test_moreh_getitem.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py index 426e379194c9..345dc51fe2bb 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py @@ -139,6 +139,7 @@ def test_getitem_RAW_MAJOR_two_indices(shape_index_dims, dtype, index_size, devi assert passing +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dims", (((10, 15, 7, 80), (0, 1, 2)),), @@ -192,6 +193,7 @@ def test_getitem_RAW_MAJOR_three_indices(shape_index_dims, dtype, index_size, de assert passing +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dim", ( @@ -284,6 +286,7 @@ def test_getitem_tilized_one_index(shape_index_dim, dtype, index_size, row_major assert passing +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") 
@pytest.mark.parametrize( "shape_index_dims", ( @@ -369,6 +372,7 @@ def test_getitem_tilized_two_indices(shape_index_dims, dtype, index_size, row_ma assert passing +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dims", ( @@ -451,6 +455,7 @@ def test_getitem_tilized_three_indices(shape_index_dims, dtype, index_size, row_ assert passing +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dims", (((10, 15, 7, 80), (0, 1, 2, 3)),), From baef03c8a0fff6e10e463c40f9e44e2fdc3d7e0c Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Tue, 4 Jun 2024 19:05:37 +0000 Subject: [PATCH 22/53] #0: Update Resnet perf numbers --- README.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index d97293bb9de0..bcee552db2bb 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ | Model | Batch | End-to-end throughput [1] | Device throughput [2] | Target | |---------------------------------------------------------- |---------------------|------------------------------|-----------------------------|-------------------------------------| -| [ResNet-50](./models/demos/resnet) (fps) | 20 | 2,850 | 7,200 | 10,000 | +| [ResNet-50](./models/demos/resnet) (fps) | 20 | 4,400 | 7,700 | 10,000 | | [BERT-Large](./models/demos/bert) (sen/s) | 12 | 362 | 406 | 410 | | [Falcon7B-decode](./models/demos/ttnn_falcon7b) (t/s) | 32 | 135 | 135 | 140 | | [ViT](./models/demos/grayskull/vit) (fps) | 8 | 480 | 1570 | 2000 | @@ -42,13 +42,13 @@ > > All model demos in this table function on both N150 and N300 Wormhole cards, unless otherwise stated. -| Model | Gen. Token [3] | Batch | End-to-end throughput [1] | Device throughput [2] | Target | -|-------------------------------------------------------------|--------------------|----------------------|------------------------------|-----------------------------|----------------| -| [Falcon7B-decode](./models/demos/wormhole/falcon7b) | 129th | 32 | 11.6 t/s/u - 371 t/s | 15.4 t/s/u - 493 t/s | 21 t/s/u | -| [Mistral-7B-decode](./models/demos/wormhole/mistral7b) | 33rd | 32 | 10.9 t/s/u - 349 t/s | 13.3 t/s/u - 426 t/s | 21 t/s/u | -| [Mamba-2.8B-decode](./models/demos/mamba) | any | 32 | 9.2 t/s/u - 295 t/s | 13.1 t/s/u - 419 t/s | 22 t/s/u | -| [BERT-Large](./models/demos/metal_BERT_large_11/) (sen/s) [4] | any | 8 | 270 | 340 | 400 | -| [Stable Diffusion 1.4](./models/demos/wormhole/stable_diffusion) 512x512 (sec/img) | | 1 | 8s | 5s | | +| Model | Gen. Token [3] | Batch | End-to-end throughput [1] | Device throughput [2] | Target | +|--------------------------------------------------------------------------------------|--------------------|----------------------|------------------------------|-----------------------------|----------------| +| [Falcon7B-decode](./models/demos/wormhole/falcon7b) | 129th | 32 | 11.6 t/s/u - 371 t/s | 15.4 t/s/u - 493 t/s | 21 | +| [Mistral-7B-decode](./models/demos/wormhole/mistral7b) | 33rd | 32 | 10.9 t/s/u - 349 t/s | 13.3 t/s/u - 426 t/s | 21 | +| [Mamba-2.8B-decode](./models/demos/mamba) | any | 32 | 9.2 t/s/u - 295 t/s | 13.1 t/s/u - 419 t/s | 22 | +| [BERT-Large](./models/demos/metal_BERT_large_11/) (sen/s) [4] | | 8 | 270 | 340 | 400 | +| [Stable Diffusion 1.4](./models/demos/wormhole/stable_diffusion) 512x512 (sec/img) | | 1 | 8 | 5 | | [1] - Observed from the host. Includes dispatch overhead and kernel execution time. 
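The relay_linear fix below (#7907) comes down to how many downstream pages a transfer is charged: folding an extra slack term into a ceiling division makes a transfer that ends exactly on a page boundary round up one page too far, so one extra page gets released. The following is a minimal standalone sketch of that rounding behaviour, with a hypothetical 4 KB page size and simplified helpers that are illustrative only — they are not the kernel's write_pages_to_dispatcher and ignore the residual-space and template-parameter details.

#include <cassert>
#include <cstdint>

constexpr uint32_t kPageSize = 4096;  // hypothetical downstream CB page size

// Buggy pattern: an unconditional slack term is folded into the rounding, so a
// transfer that is an exact multiple of the page size is charged an extra page.
static uint32_t pages_charged_with_slack(uint32_t bytes, uint32_t slack) {
    return (bytes + slack + kPageSize - 1) / kPageSize;
}

// Fixed pattern: plain ceiling division; exact multiples stay exact.
static uint32_t pages_charged(uint32_t bytes) {
    return (bytes + kPageSize - 1) / kPageSize;
}

int main() {
    uint32_t boundary = 2 * kPageSize;                    // ends exactly on a page boundary
    assert(pages_charged(boundary) == 2);                 // correct count
    assert(pages_charged_with_slack(boundary, 16) == 3);  // one page over-counted
    uint32_t mid_page = 2 * kPageSize - 100;              // ends mid-page
    assert(pages_charged(mid_page) == pages_charged_with_slack(mid_page, 16));  // both agree away from the boundary
    return 0;
}

The point of the sketch is only the rounding: ceiling division of the spill-over bytes sizes the transfer exactly, and any fixed slack belongs outside the division rather than inside it.
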
From 27bc4ba8b904032a255aa6c4a9f7637663cde83b Mon Sep 17 00:00:00 2001 From: Paul Keller Date: Tue, 21 May 2024 17:49:38 +0000 Subject: [PATCH 23/53] #7907: Fix prefetcher bug in relay_linear In relay_linear cmd, landing on exactly a page boundary could cause an extra page to be released Bug is also in dram_paged cmds, however, padding to dram alignment skirted the issue Found while working on splitting dispatcher for streams --- tt_metal/impl/dispatch/kernels/cq_prefetch.cpp | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp index 0124d992b2c4..6c6a6c5d8d6d 100644 --- a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp @@ -342,16 +342,21 @@ static uint32_t process_relay_inline_noflush_cmd(uint32_t cmd_ptr, return CQ_PREFETCH_CMD_BARE_MIN_SIZE; } -template static uint32_t write_pages_to_dispatcher(uint32_t& downstream_data_ptr, uint32_t& scratch_write_addr, uint32_t& amt_to_write) { uint32_t page_residual_space = downstream_cb_page_size - (downstream_data_ptr & (downstream_cb_page_size - 1)); - uint32_t npages = (amt_to_write - page_residual_space + downstream_cb_page_size + extra_space - 1) / downstream_cb_page_size; + uint32_t npages = (amt_to_write - page_residual_space + downstream_cb_page_size - round) / downstream_cb_page_size; // Grabbing all pages at once is ok if scratch_size < 3 * downstream_cb_block_size + // test_for_nonzero is an optimization: inner loops moving lots of pages don't bother if (!test_for_nonzero || npages != 0) { cb_acquire_pages(npages); } @@ -465,7 +470,7 @@ uint32_t process_relay_paged_cmd_large(uint32_t cmd_ptr, uint32_t amt_to_write = write_length; ASSERT((amt_to_write & 0x1f) == 0); - uint32_t npages = write_pages_to_dispatcher + uint32_t npages = write_pages_to_dispatcher<1, true> (downstream_data_ptr, scratch_write_addr, amt_to_write); // One page was acquired w/ the cmd in CMD_RELAY_INLINE_NOFLUSH with 16 bytes written @@ -578,7 +583,7 @@ uint32_t process_relay_paged_cmd(uint32_t cmd_ptr, scratch_write_addr = scratch_db_top[db_toggle]; uint32_t amt_to_write = amt_read - cmd->relay_paged.length_adjust; ASSERT((amt_to_write & 0x1f) == 0); - uint32_t npages = write_pages_to_dispatcher + uint32_t npages = write_pages_to_dispatcher<1, true> (downstream_data_ptr, scratch_write_addr, amt_to_write); downstream_data_ptr = round_up_pow2(downstream_data_ptr, downstream_cb_page_size); @@ -644,7 +649,7 @@ uint32_t process_relay_linear_cmd(uint32_t cmd_ptr, // Third step - write from DB scratch_write_addr = scratch_db_top[db_toggle]; uint32_t amt_to_write = amt_to_read; - uint32_t npages = write_pages_to_dispatcher + uint32_t npages = write_pages_to_dispatcher<1, true> (downstream_data_ptr, scratch_write_addr, amt_to_write); downstream_data_ptr = round_up_pow2(downstream_data_ptr, downstream_cb_page_size); From 738300aa285ea27e0cd7410704e95e19801fb4a9 Mon Sep 17 00:00:00 2001 From: Paul Keller Date: Tue, 21 May 2024 17:56:47 +0000 Subject: [PATCH 24/53] #0: New prefetcher tests for linear reads Randomized test isn't fully baked, needs infra fix. 
This includes smoke test for issues recently found --- .../perf_microbenchmark/dispatch/common.h | 2 + .../dispatch/test_prefetcher.cpp | 91 +++++++++++++++---- 2 files changed, 74 insertions(+), 19 deletions(-) diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h index 249f6bc0974c..b0611bc04e92 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h @@ -754,6 +754,8 @@ inline void gen_bare_dispatcher_unicast_write_cmd(Device *device, cmd.write_linear.length = length; cmd.write_linear.num_mcast_dests = 0; + TT_FATAL((cmd.write_linear.addr & (16 - 1)) == 0); // XXXXX L1_ALIGNMENT16 + add_bare_dispatcher_cmd(cmds, cmd); } diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp index 76571678b520..2793fbac4dc9 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp @@ -470,6 +470,25 @@ void gen_dram_write_cmd(Device *device, add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); } +void gen_wait_and_stall_cmd(Device *device, + vector& prefetch_cmds, + vector& cmd_sizes) { + + vector dispatch_cmds; + + CQDispatchCmd wait; + wait.base.cmd_id = CQ_DISPATCH_CMD_WAIT; + wait.wait.barrier = true; + wait.wait.notify_prefetch = true; + wait.wait.addr = dispatch_wait_addr_g; + wait.wait.count = 0; + add_bare_dispatcher_cmd(dispatch_cmds, wait); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + + vector empty_payload; // don't give me grief, it is just a test + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_STALL, empty_payload); +} + // This is pretty much a blit: copies from worker core's start of data back to the end of data void gen_linear_read_cmd(Device *device, vector& prefetch_cmds, @@ -482,6 +501,9 @@ void gen_linear_read_cmd(Device *device, vector dispatch_cmds; const uint32_t bank_id = 0; // No interleaved pages here. 
+ // Stall because we are reading data that was previously written + gen_wait_and_stall_cmd(device, prefetch_cmds, cmd_sizes); + gen_bare_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, length); add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE_NOFLUSH, dispatch_cmds); @@ -498,25 +520,7 @@ void gen_linear_read_cmd(Device *device, for (uint32_t i = 0; i < length_words; i++) { device_data.push_one(worker_core, device_data.at(worker_core, bank_id, offset + i)); } -} - -void gen_wait_and_stall_cmd(Device *device, - vector& prefetch_cmds, - vector& cmd_sizes) { - - vector dispatch_cmds; - - CQDispatchCmd wait; - wait.base.cmd_id = CQ_DISPATCH_CMD_WAIT; - wait.wait.barrier = true; - wait.wait.notify_prefetch = true; - wait.wait.addr = dispatch_wait_addr_g; - wait.wait.count = 0; - add_bare_dispatcher_cmd(dispatch_cmds, wait); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - - vector empty_payload; // don't give me grief, it is just a test - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_STALL, empty_payload); + device_data.pad(worker_core, bank_id, 16); // XXXX L1_ALIGNMENT } void gen_dispatcher_delay_cmd(Device *device, @@ -654,6 +658,7 @@ void gen_host_test(Device *device, uint32_t new_size = (prefetch_cmds.size() - prior_end) * sizeof(uint32_t); cmd_sizes.push_back(new_size >> dispatch_constants::PREFETCH_Q_LOG_MINSIZE); + // write host writes the command back to the host for (auto datum : dispatch_cmds) { device_data.push_one(device_data.get_host_core(), 0, datum); } @@ -664,6 +669,28 @@ void gen_host_test(Device *device, } } +void gen_rnd_linear_cmd(Device *device, + vector& prefetch_cmds, + vector& cmd_sizes, + DeviceData& device_data, + CoreCoord worker_core) { + + vector dispatch_cmds; + + // Hmm, how big a size to test? 
+ int max_linear_cmd_read_size = 20 * dispatch_buffer_page_size_g; // XXXXX 10 * + uint32_t size = std::rand() % max_linear_cmd_read_size; + size &= ~(sizeof(uint32_t) - 1); + uint32_t offset = std::rand() % dispatch_buffer_page_size_g; + offset = (offset >> 2) << 2; + device_data.relevel(CoreType::WORKER); // XXXXX shouldn't be needed + if (device_data.size_at(worker_core, 0) * sizeof(uint32_t) < max_linear_cmd_read_size + offset) { + // Not enough data yet, just bail on this cmd + return; + } + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, size, offset); +} + void gen_rnd_dram_paged_cmd(Device *device, vector& prefetch_cmds, vector& cmd_sizes, @@ -762,6 +789,11 @@ void gen_rnd_test(Device *device, CoreCoord worker_core(first_worker_g.x + x, first_worker_g.y + y); switch (cmd) { + case CQ_PREFETCH_CMD_RELAY_LINEAR: + // TODO: disabled for now + // test issue w/ handling re-leveling of results data after paged commands + //gen_rnd_linear_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core); + break; case CQ_PREFETCH_CMD_RELAY_PAGED: gen_rnd_dram_paged_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core); break; @@ -896,6 +928,23 @@ void gen_smoke_test(Device *device, gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, 8448); add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + // Check some hard page alignment sizes + dispatch_cmds.resize(0); + gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, dispatch_buffer_page_size_g); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + + dispatch_cmds.resize(0); + gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + + dispatch_cmds.resize(0); + gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, 2 * dispatch_buffer_page_size_g); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + + dispatch_cmds.resize(0); + gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, 2 * dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + // Merge 4 commands in the FetchQ dispatch_cmds.resize(0); gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, 112); @@ -991,6 +1040,10 @@ void gen_smoke_test(Device *device, // These tests copy data from earlier tests so can't run first gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, 32); gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, 65 * 1024); + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, dispatch_buffer_page_size_g); + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, 2 * dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, 2 * dispatch_buffer_page_size_g); // Test wait/stall gen_dispatcher_delay_cmd(device, prefetch_cmds, cmd_sizes, 1024 * 1024); From f2fbda657dd4d41d717f4e591e1d7179c9d63659 Mon 
Sep 17 00:00:00 2001 From: Paul Keller Date: Thu, 23 May 2024 21:05:09 +0000 Subject: [PATCH 25/53] #0: Improve test_prefetcher host write tests --- .../perf_microbenchmark/dispatch/common.h | 5 +- .../dispatch/test_prefetcher.cpp | 81 ++++++++++++------- 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h index b0611bc04e92..d6a0344f9ddf 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h @@ -372,7 +372,8 @@ inline bool DeviceData::validate_one_core(Device *device, bool DeviceData::validate_host(std::unordered_set &validated_cores, const one_core_data_t& host_data) { - log_info(tt::LogTest, "Validating data from hugepage"); + uint32_t size_bytes = host_data.data.size() * sizeof(uint32_t); + log_info(tt::LogTest, "Validating {} bytes from hugepage", size_bytes); bool failed = false; @@ -383,7 +384,7 @@ bool DeviceData::validate_host(std::unordered_set &validated_cores, bool done = false; for (int data_index = 0; data_index < host_data.data.size(); data_index++) { validated_cores.insert(this->host_core); - if (host_data.data[data_index] != results[host_data_index] && fail_count < 5000) { + if (host_data.data[data_index] != results[host_data_index] && fail_count < 20) { if (!failed) { log_fatal(tt::LogTest, "Data mismatch - First 20 host data failures: [idx] expected->read"); } diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp index 2793fbac4dc9..02d3a367e4f6 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp @@ -637,24 +637,27 @@ void gen_host_test(Device *device, vector& cmd_sizes, DeviceData& device_data) { - constexpr uint32_t data_size = 614400; + constexpr uint32_t max_data_size = DEVICE_DATA_SIZE; // Read data from a worker so we can get reasonable BW measurements // TODO: extend the DRAM mechanism for pre-fill to workers vectordata; - for (uint32_t i = 0; i < data_size / sizeof(uint32_t); i++) { + for (uint32_t i = 0; i < max_data_size / sizeof(uint32_t); i++) { data.push_back(i); } CoreCoord phys_worker_core = device->worker_core_from_logical_core(first_worker_g); llrt::write_hex_vec_to_core(device->id(), phys_worker_core, data, l1_buf_base_g); tt::Cluster::instance().l1_barrier(device->id()); - for (int count = 0; count < 50; count++) { + for (int count = 1; count < 100; count++) { + uint32_t data_size_words = std::rand() % ((max_data_size / 100 / sizeof(uint32_t)) * count) + 1; + uint32_t data_size_bytes = data_size_words * sizeof(uint32_t); + std::vector dispatch_cmds; - gen_bare_dispatcher_host_write_cmd(dispatch_cmds, data_size); + gen_bare_dispatcher_host_write_cmd(dispatch_cmds, data_size_bytes); add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE_NOFLUSH, dispatch_cmds); auto prior_end = prefetch_cmds.size(); - add_prefetcher_linear_read_cmd(device, prefetch_cmds, cmd_sizes, first_worker_g, l1_buf_base_g, data_size); + add_prefetcher_linear_read_cmd(device, prefetch_cmds, cmd_sizes, first_worker_g, l1_buf_base_g, data_size_bytes); uint32_t new_size = (prefetch_cmds.size() - prior_end) * sizeof(uint32_t); cmd_sizes.push_back(new_size >> dispatch_constants::PREFETCH_Q_LOG_MINSIZE); @@ -662,7 +665,8 @@ void 
gen_host_test(Device *device, for (auto datum : dispatch_cmds) { device_data.push_one(device_data.get_host_core(), 0, datum); } - for (auto datum : data) { + for (int i = 0; i < data_size_words; i++) { + uint32_t datum = data[i]; device_data.push_one(device_data.get_host_core(), 0, datum); } pad_host_data(device_data); @@ -1055,30 +1059,47 @@ void gen_smoke_test(Device *device, // Test host if (!use_dram_exec_buf_g) { - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, 32); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); - - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, 36); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); - - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, 1024); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); - - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); - - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, 16384); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); + for (int multiplier = 1; multiplier <= 3; multiplier++) { + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * 32); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * 36); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * 1024); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g - 2 * sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g + sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g + 
sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + } } // Test Paged DRAM Write and Read. FIXME - Needs work - hits asserts. From 506280b4eba78df12ba1ef298cc6715f5ec01f9b Mon Sep 17 00:00:00 2001 From: Paul Keller Date: Thu, 23 May 2024 21:31:42 +0000 Subject: [PATCH 26/53] #7907: Split commands into 4K packets in dispatcher --- tests/scripts/run_cpp_fd2_tests.sh | 4 + .../routing/kernels/traffic_gen_tx.cpp | 80 +-- .../impl/dispatch/kernels/cq_dispatch.cpp | 495 +++++++++--------- .../impl/dispatch/kernels/eth_tunneler.cpp | 150 +++--- .../impl/dispatch/kernels/packet_demux.cpp | 2 +- tt_metal/impl/dispatch/kernels/packet_mux.cpp | 2 +- .../impl/dispatch/kernels/packet_queue.hpp | 13 +- 7 files changed, 359 insertions(+), 387 deletions(-) diff --git a/tests/scripts/run_cpp_fd2_tests.sh b/tests/scripts/run_cpp_fd2_tests.sh index 9d9d8b61445b..84134ef5e6ca 100755 --- a/tests/scripts/run_cpp_fd2_tests.sh +++ b/tests/scripts/run_cpp_fd2_tests.sh @@ -59,11 +59,15 @@ run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 1 -i 5 -x -spre" # Smoke Test run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 1 -i 5 -x -spre -sdis" # Smoke Test +run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 2 -i 5 -x -spre -sdis" # Random Test +run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 6 -i 5 -x -spre -sdis" # Host Test if [[ $ARCH_NAME == "wormhole_b0" ]]; then # packetized path used only on multi-chip WH run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 0 -i 5 -spre -sdis -packetized_en" # TrueSmoke Test with packetized path run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 1 -i 5 -spre -sdis -packetized_en" # Smoke Test with packetized path + run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 2 -i 5 -spre -sdis -packetized_en" # Random Test with packetized path + run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 6 -i 5 -spre -sdis -packetized_en" # Host Test with packetized path fi diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen_tx.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen_tx.cpp index d43c6ba8ca21..a698fba95cd9 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen_tx.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen_tx.cpp @@ -2,10 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 +// clang-format off #include "dataflow_api.h" #include "debug/dprint.h" #include "tt_metal/impl/dispatch/kernels/packet_queue.hpp" #include "tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen.hpp" +// clang-format on constexpr uint32_t src_endpoint_id = get_compile_time_arg_val(0); constexpr uint32_t num_dest_endpoints = get_compile_time_arg_val(1); @@ -14,7 +16,7 @@ static_assert(is_power_of_2(num_dest_endpoints), "num_dest_endpoints must be a p constexpr uint32_t queue_start_addr_words = get_compile_time_arg_val(2); constexpr uint32_t queue_size_words = get_compile_time_arg_val(3); -constexpr uint32_t queue_size_bytes = queue_size_words*PACKET_WORD_SIZE_BYTES; +constexpr uint32_t queue_size_bytes = queue_size_words * PACKET_WORD_SIZE_BYTES; 
static_assert(is_power_of_2(queue_size_words), "queue_size_words must be a power of 2"); @@ -27,15 +29,13 @@ constexpr uint32_t remote_rx_x = get_compile_time_arg_val(6); constexpr uint32_t remote_rx_y = get_compile_time_arg_val(7); constexpr uint32_t remote_rx_queue_id = get_compile_time_arg_val(8); -constexpr DispatchRemoteNetworkType - tx_network_type = - static_cast(get_compile_time_arg_val(9)); +constexpr DispatchRemoteNetworkType tx_network_type = + static_cast(get_compile_time_arg_val(9)); constexpr uint32_t test_results_addr_arg = get_compile_time_arg_val(10); constexpr uint32_t test_results_size_bytes = get_compile_time_arg_val(11); -tt_l1_ptr uint32_t* const test_results = - reinterpret_cast(test_results_addr_arg); +tt_l1_ptr uint32_t* const test_results = reinterpret_cast(test_results_addr_arg); constexpr uint32_t prng_seed = get_compile_time_arg_val(12); @@ -64,10 +64,8 @@ constexpr packet_output_queue_state_t* output_queue_ptr = &output_queue; input_queue_rnd_state_t input_queue_rnd_state; - // generates packets with ranom size and payload on the input side inline bool input_queue_handler() { - if (input_queue_rnd_state.all_packets_done()) { return true; } @@ -80,19 +78,15 @@ inline bool input_queue_handler() { // Each call to input_queue_handler initializes only up to the end // of the queue buffer, so we don't need to handle wrapping. uint32_t byte_wr_addr = input_queue_ptr->get_queue_wptr_addr_bytes(); - uint32_t words_to_init = std::min(free_words, - input_queue_ptr->get_queue_words_before_wptr_wrap()); + uint32_t words_to_init = std::min(free_words, input_queue_ptr->get_queue_words_before_wptr_wrap()); uint32_t words_initialized = 0; while (words_initialized < words_to_init) { if (input_queue_rnd_state.all_packets_done()) { break; - } - else if (!input_queue_rnd_state.packet_active()) { - input_queue_rnd_state.next_packet_rnd(num_dest_endpoints, - dest_endpoint_start_id, - max_packet_size_words, - total_data_words); + } else if (!input_queue_rnd_state.packet_active()) { + input_queue_rnd_state.next_packet_rnd( + num_dest_endpoints, dest_endpoint_start_id, max_packet_size_words, total_data_words); tt_l1_ptr dispatch_packet_header_t* header_ptr = reinterpret_cast(byte_wr_addr); @@ -105,46 +99,54 @@ inline bool input_queue_handler() { words_initialized++; input_queue_rnd_state.curr_packet_words_remaining--; byte_wr_addr += PACKET_WORD_SIZE_BYTES; - } - else { + } else { uint32_t words_remaining = words_to_init - words_initialized; uint32_t num_words = std::min(words_remaining, input_queue_rnd_state.curr_packet_words_remaining); uint32_t start_val = (input_queue_rnd_state.packet_rnd_seed & 0xFFFF0000) + (input_queue_rnd_state.curr_packet_size_words - input_queue_rnd_state.curr_packet_words_remaining); - fill_packet_data(reinterpret_cast(byte_wr_addr), - num_words, - start_val); + fill_packet_data(reinterpret_cast(byte_wr_addr), num_words, start_val); words_initialized += num_words; input_queue_rnd_state.curr_packet_words_remaining -= num_words; - byte_wr_addr += num_words*PACKET_WORD_SIZE_BYTES; + byte_wr_addr += num_words * PACKET_WORD_SIZE_BYTES; } } input_queue_ptr->advance_queue_local_wptr(words_initialized); return false; } - void kernel_main() { - zero_l1_buf(test_results, test_results_size_bytes); test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_STARTED; test_results[PQ_TEST_MISC_INDEX] = 0xff000000; - test_results[PQ_TEST_MISC_INDEX+1] = 0xcc000000 | src_endpoint_id; + test_results[PQ_TEST_MISC_INDEX + 1] = 0xcc000000 | src_endpoint_id; noc_init(); - 
zero_l1_buf(reinterpret_cast(queue_start_addr_words*PACKET_WORD_SIZE_BYTES), - queue_size_words); + zero_l1_buf( + reinterpret_cast(queue_start_addr_words * PACKET_WORD_SIZE_BYTES), queue_size_words); input_queue_rnd_state.init(prng_seed, src_endpoint_id); - input_queue_ptr->init(input_queue_id, queue_start_addr_words, queue_size_words, - // remote_x, remote_y, remote_queue_id, remote_update_network_type: - 0, 0, 0, DispatchRemoteNetworkType::NONE); - - output_queue_ptr->init(output_queue_id, remote_rx_queue_start_addr_words, remote_rx_queue_size_words, - remote_rx_x, remote_rx_y, remote_rx_queue_id, tx_network_type, - input_queue_ptr, 1); + input_queue_ptr->init( + input_queue_id, + queue_start_addr_words, + queue_size_words, + // remote_x, remote_y, remote_queue_id, remote_update_network_type: + 0, + 0, + 0, + DispatchRemoteNetworkType::NONE); + + output_queue_ptr->init( + output_queue_id, + remote_rx_queue_start_addr_words, + remote_rx_queue_size_words, + remote_rx_x, + remote_rx_y, + remote_rx_queue_id, + tx_network_type, + input_queue_ptr, + 1); if (!wait_all_src_dest_ready(NULL, 0, output_queue_ptr, 1, timeout_cycles)) { test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_TIMEOUT; @@ -172,7 +174,8 @@ void kernel_main() { bool all_packets_initialized = input_queue_handler(); if (input_queue_ptr->get_curr_packet_valid()) { bool full_packet_sent; - uint32_t curr_data_words_sent = output_queue_ptr->forward_data_from_input(input_queue_id, full_packet_sent); + uint32_t curr_data_words_sent = output_queue_ptr->forward_data_from_input( + input_queue_id, full_packet_sent, input_queue.get_end_of_cmd()); data_words_sent += curr_data_words_sent; progress_timestamp = (curr_data_words_sent > 0) ? get_timestamp_32b() : progress_timestamp; } else if (all_packets_initialized) { @@ -208,18 +211,17 @@ void kernel_main() { set_64b_result(test_results, data_words_sent, PQ_TEST_WORD_CNT_INDEX); set_64b_result(test_results, cycles_elapsed, PQ_TEST_CYCLES_INDEX); set_64b_result(test_results, iter, PQ_TEST_ITER_INDEX); - set_64b_result(test_results, total_data_words, PQ_TEST_MISC_INDEX+4); - set_64b_result(test_results, num_packets, PQ_TEST_MISC_INDEX+6); + set_64b_result(test_results, total_data_words, PQ_TEST_MISC_INDEX + 4); + set_64b_result(test_results, num_packets, PQ_TEST_MISC_INDEX + 6); if (!timeout) { test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_PASS; test_results[PQ_TEST_MISC_INDEX] = 0xff00004; } else { test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_TIMEOUT; - set_64b_result(test_results, words_flushed, PQ_TEST_MISC_INDEX+10); + set_64b_result(test_results, words_flushed, PQ_TEST_MISC_INDEX + 10); // these calls lead to code size issues? 
// input_queue_ptr->dprint_object(); // output_queue_ptr->dprint_object(); } - } diff --git a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp index 8002bd017049..ea04faf8d4cd 100644 --- a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp @@ -10,15 +10,15 @@ // - # blocks must evenly divide the dispatch buffer size // - dispatch buffer base must be page size aligned +#include "debug/assert.h" +#include "debug/dprint.h" #include "tt_metal/impl/dispatch/cq_commands.hpp" #include "tt_metal/impl/dispatch/dispatch_address_map.hpp" #include "tt_metal/impl/dispatch/kernels/cq_common.hpp" #include "tt_metal/impl/dispatch/kernels/packet_queue_ctrl.hpp" -#include "debug/dprint.h" -#include "debug/assert.h" -// The command queue write interface controls writes to the completion region, host owns the completion region read interface -// Data requests from device and event states are written to the completion region +// The command queue write interface controls writes to the completion region, host owns the completion region read +// interface Data requests from device and event states are written to the completion region CQWriteInterface cq_write_interface; @@ -57,7 +57,6 @@ constexpr uint32_t dispatch_cb_size = dispatch_cb_page_size * dispatch_cb_pages; constexpr uint32_t dispatch_cb_end = dispatch_cb_base + dispatch_cb_size; constexpr uint32_t downstream_cb_end = downstream_cb_base + downstream_cb_size; - // Break buffer into blocks, 1/n of the total (dividing equally) // Do bookkeeping (release, etc) based on blocks // Note: due to the current method of release pages, up to 1 block of pages @@ -69,14 +68,17 @@ static uint32_t block_noc_writes_to_clear[dispatch_cb_blocks]; static uint32_t rd_block_idx; static uint32_t wr_block_idx; -static uint32_t cb_fence; // walks through cb page by page -static uint32_t cmd_ptr; // walks through pages in cb cmd by cmd +static uint32_t cb_fence; // walks through cb page by page +static uint32_t cmd_ptr; // walks through pages in cb cmd by cmd static uint32_t downstream_cb_data_ptr = downstream_cb_base; constexpr uint32_t l1_to_local_cache_copy_chunk = 6; -constexpr uint32_t max_write_packed_cores = 108; // GS 120 - 1 row TODO: this should be a compile time arg passed in from host -constexpr uint32_t l1_cache_size = ((max_write_packed_cores + l1_to_local_cache_copy_chunk - 1) / l1_to_local_cache_copy_chunk) * l1_to_local_cache_copy_chunk; +constexpr uint32_t max_write_packed_cores = + 108; // GS 120 - 1 row TODO: this should be a compile time arg passed in from host +constexpr uint32_t l1_cache_size = + ((max_write_packed_cores + l1_to_local_cache_copy_chunk - 1) / l1_to_local_cache_copy_chunk) * + l1_to_local_cache_copy_chunk; static uint32_t l1_cache[l1_cache_size]; @@ -105,12 +107,12 @@ void careful_copy_from_l1_to_local_cache(volatile uint32_t tt_l1_ptr *l1_ptr, ui } } -FORCE_INLINE volatile uint32_t* get_cq_completion_read_ptr() { - return reinterpret_cast(CQ_COMPLETION_READ_PTR); +FORCE_INLINE volatile uint32_t *get_cq_completion_read_ptr() { + return reinterpret_cast(CQ_COMPLETION_READ_PTR); } -FORCE_INLINE volatile uint32_t* get_cq_completion_write_ptr() { - return reinterpret_cast(CQ_COMPLETION_WRITE_PTR); +FORCE_INLINE volatile uint32_t *get_cq_completion_write_ptr() { + return reinterpret_cast(CQ_COMPLETION_WRITE_PTR); } FORCE_INLINE @@ -130,9 +132,10 @@ void completion_queue_reserve_back(uint32_t num_pages) { // so available space is distance from write 
ptr to read ptr // Toggles are equal means write ptr is ahead of read ptr // so available space is total space minus the distance from read to write ptr - available_space = completion_rd_toggle != cq_write_interface.completion_fifo_wr_toggle ? - completion_rd_ptr - cq_write_interface.completion_fifo_wr_ptr : - (completion_queue_size_16B - (cq_write_interface.completion_fifo_wr_ptr - completion_rd_ptr)); + available_space = + completion_rd_toggle != cq_write_interface.completion_fifo_wr_toggle + ? completion_rd_ptr - cq_write_interface.completion_fifo_wr_ptr + : (completion_queue_size_16B - (cq_write_interface.completion_fifo_wr_ptr - completion_rd_ptr)); } while (data_size_16B > available_space); DEBUG_STATUS("QRBD"); @@ -156,7 +159,8 @@ void completion_queue_push_back(uint32_t num_pages) { cq_write_interface.completion_fifo_wr_ptr += push_size_16B; if (cq_write_interface.completion_fifo_wr_ptr >= completion_queue_end_addr_16B) { - cq_write_interface.completion_fifo_wr_ptr = cq_write_interface.completion_fifo_wr_ptr - completion_queue_end_addr_16B + completion_queue_base_addr_16B; + cq_write_interface.completion_fifo_wr_ptr = + cq_write_interface.completion_fifo_wr_ptr - completion_queue_end_addr_16B + completion_queue_base_addr_16B; // Flip the toggle cq_write_interface.completion_fifo_wr_toggle = not cq_write_interface.completion_fifo_wr_toggle; } @@ -184,24 +188,21 @@ void process_write_host_h() { cb_fence = dispatch_cb_base; data_ptr = dispatch_cb_base; } - move_rd_to_next_block(block_noc_writes_to_clear, - rd_block_idx); + move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx); } // Wait for dispatcher to supply a page (this won't go beyond the buffer end) - uint32_t n_pages = cb_acquire_pages(cb_fence, - block_next_start_addr, - rd_block_idx);; + uint32_t n_pages = cb_acquire_pages( + cb_fence, block_next_start_addr, rd_block_idx); + ; cb_fence += n_pages * dispatch_cb_page_size; // Release pages for prefetcher // Since we gate how much we acquire to < 1/4 the buffer, this should be called enough - cb_block_release_pages(block_noc_writes_to_clear, - wr_block_idx); + cb_block_release_pages< + upstream_noc_xy, + upstream_dispatch_cb_sem_id, + dispatch_cb_blocks, + dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx); } uint32_t available_data = cb_fence - data_ptr; uint32_t xfer_size = (length > available_data) ? 
available_data : length; @@ -226,7 +227,9 @@ void process_write_host_h() { // We flush to ensure the ptr has been read out of l1 before we update it again completion_queue_push_back(npages); noc_async_writes_flushed(); - block_noc_writes_to_clear[rd_block_idx]+=(xfer_size + NOC_MAX_BURST_SIZE - 1) / NOC_MAX_BURST_SIZE; // XXXXX maybe just write the noc internal api counter + block_noc_writes_to_clear[rd_block_idx] += + (xfer_size + NOC_MAX_BURST_SIZE - 1) / + NOC_MAX_BURST_SIZE; // XXXXX maybe just write the noc internal api counter length -= xfer_size; data_ptr += xfer_size; @@ -234,59 +237,58 @@ void process_write_host_h() { cmd_ptr = data_ptr; } -template -void relay_to_next_cb(uint32_t data_ptr, - uint32_t length) { - +// Relay, potentially through the mux/dmux/tunneller path +// Code below sends 1 page worth of data except at the end of a cmd +// This means the downstream buffers are always page aligned, simplifies wrap handling +template +void relay_to_next_cb(uint32_t data_ptr, uint32_t length) { static_assert( preamble_size == 0 || preamble_size == sizeof(dispatch_packet_header_t), "Dispatcher preamble size must be 0 or sizeof(dispatch_packet_header_t)"); DPRINT << "relay_to_next_cb: " << data_ptr << " " << cb_fence << " " << length << ENDL(); - bool page_acquired = false; - // The downstream packetizing stage will initialize the other fields, but it needs info on - // the length of the transfer to be packetized. - if (preamble_size > 0) { - cb_acquire_pages(1); // XXXX optimize, take all availabl - page_acquired = true; - ASSERT(downstream_cb_data_ptr != downstream_cb_end); - - uint64_t downstream_noc_addr = get_noc_addr_helper(downstream_noc_xy, downstream_cb_data_ptr); - noc_inline_dw_write(downstream_noc_addr, length + preamble_size); - block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter - downstream_cb_data_ptr += preamble_size; - } - // First page should be valid since it has the command ASSERT(data_ptr <= dispatch_cb_end - dispatch_cb_page_size); ASSERT(data_ptr <= cb_fence - dispatch_cb_page_size); - uint32_t extra = preamble_size; while (length > 0) { - ASSERT (downstream_cb_end > downstream_cb_data_ptr); + ASSERT(downstream_cb_end > downstream_cb_data_ptr); + + cb_acquire_pages(1); + + uint32_t xfer_size; + bool not_end_of_cmd; + if (length > dispatch_cb_page_size - preamble_size) { + xfer_size = dispatch_cb_page_size - preamble_size; + not_end_of_cmd = true; + } else { + xfer_size = length; + not_end_of_cmd = false; + } - uint32_t xfer_size = (length > dispatch_cb_page_size - extra) ? 
- dispatch_cb_page_size - extra : - length; uint64_t dst = get_noc_addr_helper(downstream_noc_xy, downstream_cb_data_ptr); + if (preamble_size > 0) { + uint32_t flag; + noc_inline_dw_write(dst, xfer_size + preamble_size + not_end_of_cmd); + block_noc_writes_to_clear[rd_block_idx]++; + downstream_cb_data_ptr += preamble_size; + dst = get_noc_addr_helper(downstream_noc_xy, downstream_cb_data_ptr); + ASSERT(downstream_cb_data_ptr < downstream_cb_end); + } + // Get a page if needed if (data_ptr + xfer_size > cb_fence) { // Check for block completion if (cb_fence == block_next_start_addr[rd_block_idx]) { // Check for dispatch_cb wrap if (rd_block_idx == dispatch_cb_blocks - 1) { - // We can be misalgined when orphan_size is non=zero - // Code could be structured to stay aligned after wrap, - // but instead making this behave like other routines - uint32_t orphan_size = preamble_size; - ASSERT(dispatch_cb_end - data_ptr == preamble_size); + ASSERT(cb_fence == dispatch_cb_end); + uint32_t orphan_size = cb_fence - data_ptr; if (orphan_size != 0) { - cb_acquire_pages(1); // XXXX optimize, take all availabl noc_async_write(data_ptr, dst, orphan_size); block_noc_writes_to_clear[rd_block_idx]++; - page_acquired = true; length -= orphan_size; xfer_size -= orphan_size; downstream_cb_data_ptr += orphan_size; @@ -299,34 +301,26 @@ void relay_to_next_cb(uint32_t data_ptr, data_ptr = dispatch_cb_base; } - move_rd_to_next_block(block_noc_writes_to_clear, - rd_block_idx); + move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx); } // Wait for dispatcher to supply a page (this won't go beyond the buffer end) - uint32_t n_pages = cb_acquire_pages(cb_fence, - block_next_start_addr, - rd_block_idx); + uint32_t n_pages = cb_acquire_pages( + cb_fence, block_next_start_addr, rd_block_idx); cb_fence += n_pages * dispatch_cb_page_size; // Release pages for prefetcher // Since we gate how much we acquire to < 1/4 the buffer, this should be called enough - cb_block_release_pages(block_noc_writes_to_clear, - wr_block_idx); + cb_block_release_pages< + upstream_noc_xy, + upstream_dispatch_cb_sem_id, + dispatch_cb_blocks, + dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx); } - // Get downstream page - if (page_acquired == false) { - cb_acquire_pages(1); // XXXX optimize, take all available - } noc_async_write(data_ptr, dst, xfer_size); - block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter - cb_release_pages(1); // XXXX optimize, take all available + block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter + cb_release_pages(1); // XXXX optimize, take all available length -= xfer_size; data_ptr += xfer_size; @@ -334,8 +328,6 @@ void relay_to_next_cb(uint32_t data_ptr, if (downstream_cb_data_ptr == downstream_cb_end) { downstream_cb_data_ptr = downstream_cb_base; } - page_acquired = false; - extra = 0; } // Move to next page @@ -348,7 +340,6 @@ void relay_to_next_cb(uint32_t data_ptr, } void process_write_host_d() { - volatile tt_l1_ptr CQDispatchCmd *cmd = (volatile tt_l1_ptr CQDispatchCmd *)cmd_ptr; // Remember: host transfer command includes the command in the payload, don't add it here uint32_t length = cmd->write_linear_host.length; @@ -358,7 +349,6 @@ void process_write_host_d() { } void relay_write_h() { - volatile tt_l1_ptr CQDispatchCmd *cmd = (volatile tt_l1_ptr CQDispatchCmd *)cmd_ptr; uint32_t length = sizeof(CQDispatchCmd) + cmd->write_linear.length; uint32_t data_ptr = cmd_ptr; @@ -368,7 +358,7 @@ 
void relay_write_h() { // Note that for non-paged writes, the number of writes per page is always 1 // This means each noc_write frees up a page -template +template void process_write_linear(uint32_t num_mcast_dests) { volatile tt_l1_ptr CQDispatchCmd *cmd = (volatile tt_l1_ptr CQDispatchCmd *)cmd_ptr; @@ -376,7 +366,6 @@ void process_write_linear(uint32_t num_mcast_dests) { uint32_t dst_addr = cmd->write_linear.addr; uint32_t length = cmd->write_linear.length; uint32_t data_ptr = cmd_ptr + sizeof(CQDispatchCmd); - DPRINT << "dispatch_write: " << length << " num_mcast_dests: " << num_mcast_dests << ENDL(); while (length != 0) { uint32_t xfer_size = (length > dispatch_cb_page_size) ? dispatch_cb_page_size : length; uint64_t dst = get_noc_addr_helper(dst_noc, dst_addr); @@ -389,8 +378,9 @@ void process_write_linear(uint32_t num_mcast_dests) { if (rd_block_idx == dispatch_cb_blocks - 1) { uint32_t orphan_size = dispatch_cb_end - data_ptr; if (orphan_size != 0) { - if constexpr (multicast){ - noc_async_write_multicast(data_ptr, dst, orphan_size, num_mcast_dests); + if constexpr (multicast) { + noc_async_write_multicast( + data_ptr, dst, orphan_size, num_mcast_dests); } else { noc_async_write(data_ptr, dst, orphan_size); } @@ -404,33 +394,29 @@ void process_write_linear(uint32_t num_mcast_dests) { dst = get_noc_addr_helper(dst_noc, dst_addr); } - move_rd_to_next_block(block_noc_writes_to_clear, - rd_block_idx); + move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx); } // Wait for dispatcher to supply a page (this won't go beyond the buffer end) - uint32_t n_pages = cb_acquire_pages(cb_fence, - block_next_start_addr, - rd_block_idx); + uint32_t n_pages = cb_acquire_pages( + cb_fence, block_next_start_addr, rd_block_idx); cb_fence += n_pages * dispatch_cb_page_size; // Release pages for prefetcher // Since we gate how much we acquire to < 1/4 the buffer, this should be called enough - cb_block_release_pages(block_noc_writes_to_clear, - wr_block_idx); + cb_block_release_pages< + upstream_noc_xy, + upstream_dispatch_cb_sem_id, + dispatch_cb_blocks, + dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx); } - if constexpr (multicast){ + if constexpr (multicast) { noc_async_write_multicast(data_ptr, dst, xfer_size, num_mcast_dests); } else { noc_async_write(data_ptr, dst, xfer_size); } - block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter + block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter length -= xfer_size; data_ptr += xfer_size; @@ -449,7 +435,7 @@ void process_write() { } } -template +template void process_write_paged() { volatile tt_l1_ptr CQDispatchCmd *cmd = (volatile tt_l1_ptr CQDispatchCmd *)cmd_ptr; @@ -462,15 +448,19 @@ void process_write_paged() { InterleavedAddrGen addr_gen; addr_gen.bank_base_address = base_addr; addr_gen.page_size = page_size; - uint64_t dst_addr_offset = 0; // Offset into page. + uint64_t dst_addr_offset = 0; // Offset into page. 
- DPRINT << "process_write_paged - pages: " << pages << " page_size: " << page_size << " dispatch_cb_page_size: " << dispatch_cb_page_size; + DPRINT << "process_write_paged - pages: " << pages << " page_size: " << page_size + << " dispatch_cb_page_size: " << dispatch_cb_page_size; DPRINT << " start_page: " << page_id << " base_addr: " << HEX() << base_addr << DEC() << ENDL(); while (write_length != 0) { - // TODO #7360: Have more performant handling when page_size > dispatch_cb_page_size by not doing multiple writes for one buffer page - uint32_t xfer_size = page_size > dispatch_cb_page_size ? min(dispatch_cb_page_size, page_size - dst_addr_offset) : page_size; - uint64_t dst = addr_gen.get_noc_addr(page_id, dst_addr_offset); // XXXX replace this w/ walking the banks to save mul on GS + // TODO #7360: Have more performant handling when page_size > dispatch_cb_page_size by not doing multiple writes + // for one buffer page + uint32_t xfer_size = + page_size > dispatch_cb_page_size ? min(dispatch_cb_page_size, page_size - dst_addr_offset) : page_size; + uint64_t dst = addr_gen.get_noc_addr( + page_id, dst_addr_offset); // XXXX replace this w/ walking the banks to save mul on GS // Get a Dispatch page if needed if (data_ptr + xfer_size > cb_fence) { @@ -490,31 +480,28 @@ void process_write_paged() { data_ptr = dispatch_cb_base; dst = addr_gen.get_noc_addr(page_id, dst_addr_offset); } - move_rd_to_next_block(block_noc_writes_to_clear, - rd_block_idx); + move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx); } // Wait for dispatcher to supply a page (this won't go beyond the buffer end) - uint32_t n_pages = cb_acquire_pages(cb_fence, - block_next_start_addr, - rd_block_idx); + uint32_t n_pages = cb_acquire_pages( + cb_fence, block_next_start_addr, rd_block_idx); cb_fence += n_pages * dispatch_cb_page_size; // Release pages for prefetcher // Since we gate how much we acquire to < 1/4 the buffer, this should be called enough - cb_block_release_pages(block_noc_writes_to_clear, - wr_block_idx); + cb_block_release_pages< + upstream_noc_xy, + upstream_dispatch_cb_sem_id, + dispatch_cb_blocks, + dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx); } noc_async_write(data_ptr, dst, xfer_size); - block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter + block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter - // If paged write is not completed for a page (dispatch_cb_page_size < page_size) then add offset, otherwise incr page_id. + // If paged write is not completed for a page (dispatch_cb_page_size < page_size) then add offset, otherwise + // incr page_id. if (dst_addr_offset + xfer_size < page_size) { dst_addr_offset += xfer_size; } else { @@ -542,7 +529,7 @@ void process_write_paged() { // // Since all subcmds all appear in the first page and given the size restrictions // this command can't be too many pages. All pages are released at the end -template +template void process_write_packed(uint32_t flags) { volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; @@ -550,8 +537,8 @@ void process_write_packed(uint32_t flags) { ASSERT(count <= (mcast ? 
max_write_packed_cores / 2 : max_write_packed_cores)); constexpr uint32_t sub_cmd_size = sizeof(WritePackedSubCmd); // Copying in a burst is about a 30% net gain vs reading one value per loop below - careful_copy_from_l1_to_local_cache((volatile uint32_t tt_l1_ptr*)(cmd_ptr + sizeof(CQDispatchCmd)), - count * sub_cmd_size / sizeof(uint32_t)); + careful_copy_from_l1_to_local_cache( + (volatile uint32_t tt_l1_ptr *)(cmd_ptr + sizeof(CQDispatchCmd)), count * sub_cmd_size / sizeof(uint32_t)); uint32_t xfer_size = cmd->write_packed.size; uint32_t dst_addr = cmd->write_packed.addr; @@ -560,7 +547,8 @@ void process_write_packed(uint32_t flags) { uint32_t data_ptr = cmd_ptr + sizeof(CQDispatchCmd) + count * sizeof(WritePackedSubCmd); data_ptr = round_up_pow2(data_ptr, L1_NOC_ALIGNMENT); - uint32_t stride = (flags & CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE) ? 0 : round_up_pow2(xfer_size, L1_NOC_ALIGNMENT); + uint32_t stride = + (flags & CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE) ? 0 : round_up_pow2(xfer_size, L1_NOC_ALIGNMENT); DPRINT << data_ptr << " " << cmd_ptr << " " << xfer_size << " " << dispatch_cb_page_size << ENDL(); ASSERT(stride != 0 || data_ptr - cmd_ptr + xfer_size <= dispatch_cb_page_size); @@ -573,9 +561,7 @@ void process_write_packed(uint32_t flags) { WritePackedSubCmd *sub_cmd_ptr = (WritePackedSubCmd *)l1_cache; while (count != 0) { uint32_t dst_noc = sub_cmd_ptr->noc_xy_addr; - uint32_t num_dests = mcast ? - ((CQDispatchWritePackedMulticastSubCmd *)sub_cmd_ptr)->num_mcast_dests : - 1; + uint32_t num_dests = mcast ? ((CQDispatchWritePackedMulticastSubCmd *)sub_cmd_ptr)->num_mcast_dests : 1; sub_cmd_ptr++; uint64_t dst = get_noc_addr_helper(dst_noc, dst_addr); // Get a page if needed @@ -601,16 +587,12 @@ void process_write_packed(uint32_t flags) { noc_nonposted_writes_acked[noc_index] += mcasts; writes = 0; mcasts = 0; - move_rd_to_next_block(block_noc_writes_to_clear, - rd_block_idx); + move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx); } // Wait for dispatcher to supply a page (this won't go beyond the buffer end) - uint32_t n_pages = cb_acquire_pages(cb_fence, - block_next_start_addr, - rd_block_idx); + uint32_t n_pages = cb_acquire_pages( + cb_fence, block_next_start_addr, rd_block_idx); cb_fence += n_pages * dispatch_cb_page_size; // This is done here so the common case doesn't have to restore the pointers @@ -644,17 +626,16 @@ void process_write_packed(uint32_t flags) { noc_nonposted_writes_acked[noc_index] += mcasts; // Release pages for prefetcher // write_packed releases pages at the end so the first page (w/ the sub_cmds) remains valid - cb_block_release_pages(block_noc_writes_to_clear, - wr_block_idx); + cb_block_release_pages< + upstream_noc_xy, + upstream_dispatch_cb_sem_id, + dispatch_cb_blocks, + dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx); cmd_ptr = data_ptr; } static uint32_t process_debug_cmd(uint32_t cmd_ptr) { - volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; uint32_t checksum = 0; uint32_t *data = (uint32_t *)((uint32_t)cmd + (uint32_t)sizeof(CQDispatchCmd)); @@ -691,8 +672,7 @@ static void process_wait() { } DEBUG_STATUS("PWW"); - volatile tt_l1_ptr uint32_t* sem_addr = - reinterpret_cast(addr); + volatile tt_l1_ptr uint32_t *sem_addr = reinterpret_cast(addr); DPRINT << " DISPATCH WAIT " << HEX() << addr << DEC() << " count " << count << ENDL(); #if defined(COMPILE_FOR_IDLE_ERISC) uint32_t heartbeat = 0; @@ -718,57 +698,54 @@ static void process_wait() { } static void 
process_delay_cmd() { - volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; uint32_t count = cmd->delay.delay; for (volatile uint32_t i = 0; i < count; i++); cmd_ptr += sizeof(CQDispatchCmd); } -static inline bool process_cmd_d(uint32_t& cmd_ptr) { - +static inline bool process_cmd_d(uint32_t &cmd_ptr) { bool done = false; - re_run_command: +re_run_command: volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; switch (cmd->base.cmd_id) { - case CQ_DISPATCH_CMD_WRITE_LINEAR: - DEBUG_STATUS("DWB"); - DPRINT << "cmd_write\n"; - process_write(); - DEBUG_STATUS("DWD"); - break; - - case CQ_DISPATCH_CMD_WRITE_LINEAR_H: - DPRINT << "cmd_write_linear_h\n"; - if (is_h_variant) { + case CQ_DISPATCH_CMD_WRITE_LINEAR: + DEBUG_STATUS("DWB"); + DPRINT << "cmd_write\n"; process_write(); - } else { - relay_write_h(); - } - break; + DEBUG_STATUS("DWD"); + break; - case CQ_DISPATCH_CMD_WRITE_LINEAR_H_HOST: - DPRINT << "cmd_write_linear_h_host\n"; - if (is_h_variant) { - process_write_host_h(); - } else { - process_write_host_d(); - } - break; + case CQ_DISPATCH_CMD_WRITE_LINEAR_H: + DPRINT << "cmd_write_linear_h\n"; + if (is_h_variant) { + process_write(); + } else { + relay_write_h(); + } + break; - case CQ_DISPATCH_CMD_WRITE_PAGED: - DPRINT << "cmd_write_paged is_dram: " << (uint32_t) cmd->write_paged.is_dram << ENDL(); - if (cmd->write_paged.is_dram) { - process_write_paged(); - } else { - process_write_paged(); - } - break; + case CQ_DISPATCH_CMD_WRITE_LINEAR_H_HOST: + DPRINT << "cmd_write_linear_h_host\n"; + if (is_h_variant) { + process_write_host_h(); + } else { + process_write_host_d(); + } + break; + + case CQ_DISPATCH_CMD_WRITE_PAGED: + DPRINT << "cmd_write_paged is_dram: " << (uint32_t)cmd->write_paged.is_dram << ENDL(); + if (cmd->write_paged.is_dram) { + process_write_paged(); + } else { + process_write_paged(); + } + break; - case CQ_DISPATCH_CMD_WRITE_PACKED: - { + case CQ_DISPATCH_CMD_WRITE_PACKED: { DPRINT << "cmd_write_packed" << ENDL(); uint32_t flags = cmd->write_packed.flags; if (flags & CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST) { @@ -776,92 +753,90 @@ static inline bool process_cmd_d(uint32_t& cmd_ptr) { } else { process_write_packed(flags); } - } - break; - - case CQ_DISPATCH_CMD_WAIT: - DPRINT << "cmd_wait" << ENDL(); - process_wait(); - break; - case CQ_DISPATCH_CMD_GO: - DPRINT << "cmd_go" << ENDL(); - break; - - case CQ_DISPATCH_CMD_SINK: - DPRINT << "cmd_sink" << ENDL(); - break; - - case CQ_DISPATCH_CMD_DEBUG: - DPRINT << "cmd_debug" << ENDL(); - cmd_ptr = process_debug_cmd(cmd_ptr); - goto re_run_command; - break; - - case CQ_DISPATCH_CMD_DELAY: - DPRINT << "cmd_delay" << ENDL(); - process_delay_cmd(); - break; - - case CQ_DISPATCH_CMD_TERMINATE: - DPRINT << "dispatch terminate\n"; - if (is_d_variant && !is_h_variant) { - relay_to_next_cb(cmd_ptr, sizeof(CQDispatchCmd)); - } - cmd_ptr += sizeof(CQDispatchCmd); - done = true; - break; - - default: - DPRINT << "dispatcher_d invalid command:" << cmd_ptr << " " << cb_fence << " " << dispatch_cb_base << " " << dispatch_cb_end << " " << rd_block_idx << " " << "xx" << ENDL(); - DPRINT << HEX() << *(uint32_t*)cmd_ptr << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+1) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+2) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+3) << ENDL(); - DEBUG_STATUS("!CMD"); - ASSERT(0); + } break; + + case CQ_DISPATCH_CMD_WAIT: + DPRINT << "cmd_wait" << ENDL(); + process_wait(); + break; + + case CQ_DISPATCH_CMD_GO: DPRINT << 
"cmd_go" << ENDL(); break; + + case CQ_DISPATCH_CMD_SINK: DPRINT << "cmd_sink" << ENDL(); break; + + case CQ_DISPATCH_CMD_DEBUG: + DPRINT << "cmd_debug" << ENDL(); + cmd_ptr = process_debug_cmd(cmd_ptr); + goto re_run_command; + break; + + case CQ_DISPATCH_CMD_DELAY: + DPRINT << "cmd_delay" << ENDL(); + process_delay_cmd(); + break; + + case CQ_DISPATCH_CMD_TERMINATE: + DPRINT << "dispatch terminate\n"; + if (is_d_variant && !is_h_variant) { + relay_to_next_cb(cmd_ptr, sizeof(CQDispatchCmd)); + } + cmd_ptr += sizeof(CQDispatchCmd); + done = true; + break; + + default: + DPRINT << "dispatcher_d invalid command:" << cmd_ptr << " " << cb_fence << " " << dispatch_cb_base << " " + << dispatch_cb_end << " " << rd_block_idx << " " + << "xx" << ENDL(); + DPRINT << HEX() << *(uint32_t *)cmd_ptr << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 1) << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 2) << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 3) << ENDL(); + DEBUG_STATUS("!CMD"); + ASSERT(0); } return done; } -static inline bool process_cmd_h(uint32_t& cmd_ptr) { - +static inline bool process_cmd_h(uint32_t &cmd_ptr) { bool done = false; volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; switch (cmd->base.cmd_id) { - case CQ_DISPATCH_CMD_WRITE_LINEAR_H: - DPRINT << "dispatch_h write_linear_h\n"; - process_write(); - break; - - case CQ_DISPATCH_CMD_WRITE_LINEAR_H_HOST: - DPRINT << "dispatch_h linear_h_host\n"; - process_write_host_h(); - break; - - case CQ_DISPATCH_CMD_TERMINATE: - DPRINT << "dispatch_h terminate\n"; - cmd_ptr += sizeof(CQDispatchCmd); - done = true; - break; - - default: - DPRINT << "dispatcher_h invalid command:" << cmd_ptr << " " << cb_fence << " " << " " << dispatch_cb_base << " " << dispatch_cb_end << " " << rd_block_idx << " " << "xx" << ENDL(); - DPRINT << HEX() << *(uint32_t*)cmd_ptr << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+1) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+2) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+3) << ENDL(); - DEBUG_STATUS("!CMD"); - ASSERT(0); + case CQ_DISPATCH_CMD_WRITE_LINEAR_H: + DPRINT << "dispatch_h write_linear_h\n"; + process_write(); + break; + + case CQ_DISPATCH_CMD_WRITE_LINEAR_H_HOST: + DPRINT << "dispatch_h linear_h_host\n"; + process_write_host_h(); + break; + + case CQ_DISPATCH_CMD_TERMINATE: + DPRINT << "dispatch_h terminate\n"; + cmd_ptr += sizeof(CQDispatchCmd); + done = true; + break; + + default: + DPRINT << "dispatcher_h invalid command:" << cmd_ptr << " " << cb_fence << " " + << " " << dispatch_cb_base << " " << dispatch_cb_end << " " << rd_block_idx << " " + << "xx" << ENDL(); + DPRINT << HEX() << *(uint32_t *)cmd_ptr << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 1) << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 2) << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 3) << ENDL(); + DEBUG_STATUS("!CMD"); + ASSERT(0); } return done; } void kernel_main() { - DPRINT << "dispatch_" << is_h_variant << is_d_variant << ": start" << ENDL(); static_assert(is_d_variant || split_dispatch_page_preamble_size == 0); @@ -891,27 +866,22 @@ void kernel_main() { dispatch_cb_blocks, dispatch_cb_log_page_size, my_noc_xy, - my_dispatch_cb_sem_id>(cmd_ptr, - cb_fence, - block_noc_writes_to_clear, - block_next_start_addr, - rd_block_idx); + my_dispatch_cb_sem_id>( + cmd_ptr, cb_fence, block_noc_writes_to_clear, block_next_start_addr, rd_block_idx); } - done = is_d_variant ? 
- process_cmd_d(cmd_ptr) : - process_cmd_h(cmd_ptr); + done = is_d_variant ? process_cmd_d(cmd_ptr) : process_cmd_h(cmd_ptr); // Move to next page cmd_ptr = round_up_pow2(cmd_ptr, dispatch_cb_page_size); // XXXXX move this inside while loop waiting for get_dispatch_cb_page above // XXXXX can potentially clear a partial block when stalled w/ some more bookkeeping - cb_block_release_pages(block_noc_writes_to_clear, - wr_block_idx); + cb_block_release_pages< + upstream_noc_xy, + upstream_dispatch_cb_sem_id, + dispatch_cb_blocks, + dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx); } noc_async_write_barrier(); @@ -934,7 +904,8 @@ void kernel_main() { // We're 1 block behind cb_release_pages(dispatch_cb_pages_per_block); } - uint32_t npages = dispatch_cb_pages_per_block - ((block_next_start_addr[rd_block_idx] - cmd_ptr) >> dispatch_cb_log_page_size); + uint32_t npages = + dispatch_cb_pages_per_block - ((block_next_start_addr[rd_block_idx] - cmd_ptr) >> dispatch_cb_log_page_size); cb_release_pages(npages); // Confirm expected number of pages, spinning here is a leak diff --git a/tt_metal/impl/dispatch/kernels/eth_tunneler.cpp b/tt_metal/impl/dispatch/kernels/eth_tunneler.cpp index 971afc15f8d6..8453cca33c48 100644 --- a/tt_metal/impl/dispatch/kernels/eth_tunneler.cpp +++ b/tt_metal/impl/dispatch/kernels/eth_tunneler.cpp @@ -2,10 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 +// clang-format off #include "dataflow_api.h" #include "debug/dprint.h" #include "tt_metal/impl/dispatch/kernels/packet_queue.hpp" #include "tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen.hpp" +// clang-format on #define NUM_BIDIR_TUNNELS 1 #define NUM_TUNNEL_QUEUES (NUM_BIDIR_TUNNELS * 2) @@ -17,103 +19,88 @@ constexpr uint32_t endpoint_id_start_index = get_compile_time_arg_val(0); constexpr uint32_t tunnel_lanes = get_compile_time_arg_val(1); constexpr uint32_t in_queue_start_addr_words = get_compile_time_arg_val(2); constexpr uint32_t in_queue_size_words = get_compile_time_arg_val(3); -constexpr uint32_t in_queue_size_bytes = in_queue_size_words*PACKET_WORD_SIZE_BYTES; +constexpr uint32_t in_queue_size_bytes = in_queue_size_words * PACKET_WORD_SIZE_BYTES; static_assert(is_power_of_2(in_queue_size_words), "in_queue_size_words must be a power of 2"); static_assert(tunnel_lanes <= NUM_TUNNEL_QUEUES, "cannot have more than 2 tunnel directions."); static_assert(tunnel_lanes, "tunnel directions cannot be 0. 1 => Unidirectional. 
2 => Bidirectional"); -constexpr uint32_t remote_receiver_x[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(4) & 0xFF), - (get_compile_time_arg_val(5) & 0xFF) - }; - -constexpr uint32_t remote_receiver_y[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(4) >> 8) & 0xFF, - (get_compile_time_arg_val(5) >> 8) & 0xFF - }; - -constexpr uint32_t remote_receiver_queue_id[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(4) >> 16) & 0xFF, - (get_compile_time_arg_val(5) >> 16) & 0xFF - }; - -constexpr DispatchRemoteNetworkType remote_receiver_network_type[NUM_TUNNEL_QUEUES] = - { - static_cast((get_compile_time_arg_val(4) >> 24) & 0xFF), - static_cast((get_compile_time_arg_val(5) >> 24) & 0xFF) - }; - -constexpr uint32_t remote_receiver_queue_start_addr_words[NUM_TUNNEL_QUEUES] = - { - get_compile_time_arg_val(6), - get_compile_time_arg_val(8) - }; - -constexpr uint32_t remote_receiver_queue_size_words[NUM_TUNNEL_QUEUES] = - { - get_compile_time_arg_val(7), - get_compile_time_arg_val(9) - }; - -static_assert(is_power_of_2(remote_receiver_queue_size_words[0]), "remote_receiver_queue_size_words must be a power of 2"); -static_assert(is_power_of_2(remote_receiver_queue_size_words[1]), "remote_receiver_queue_size_words must be a power of 2"); - -constexpr uint32_t remote_sender_x[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(10) & 0xFF), - (get_compile_time_arg_val(11) & 0xFF) - }; - -constexpr uint32_t remote_sender_y[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(10) >> 8) & 0xFF, - (get_compile_time_arg_val(11) >> 8) & 0xFF - }; - -constexpr uint32_t remote_sender_queue_id[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(10) >> 16) & 0xFF, - (get_compile_time_arg_val(11) >> 16) & 0xFF - }; - -constexpr DispatchRemoteNetworkType remote_sender_network_type[NUM_TUNNEL_QUEUES] = - { - static_cast((get_compile_time_arg_val(10) >> 24) & 0xFF), - static_cast((get_compile_time_arg_val(11) >> 24) & 0xFF) - }; +constexpr uint32_t remote_receiver_x[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(4) & 0xFF), (get_compile_time_arg_val(5) & 0xFF)}; + +constexpr uint32_t remote_receiver_y[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(4) >> 8) & 0xFF, (get_compile_time_arg_val(5) >> 8) & 0xFF}; + +constexpr uint32_t remote_receiver_queue_id[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(4) >> 16) & 0xFF, (get_compile_time_arg_val(5) >> 16) & 0xFF}; + +constexpr DispatchRemoteNetworkType remote_receiver_network_type[NUM_TUNNEL_QUEUES] = { + static_cast((get_compile_time_arg_val(4) >> 24) & 0xFF), + static_cast((get_compile_time_arg_val(5) >> 24) & 0xFF)}; + +constexpr uint32_t remote_receiver_queue_start_addr_words[NUM_TUNNEL_QUEUES] = { + get_compile_time_arg_val(6), get_compile_time_arg_val(8)}; + +constexpr uint32_t remote_receiver_queue_size_words[NUM_TUNNEL_QUEUES] = { + get_compile_time_arg_val(7), get_compile_time_arg_val(9)}; + +static_assert( + is_power_of_2(remote_receiver_queue_size_words[0]), "remote_receiver_queue_size_words must be a power of 2"); +static_assert( + is_power_of_2(remote_receiver_queue_size_words[1]), "remote_receiver_queue_size_words must be a power of 2"); + +constexpr uint32_t remote_sender_x[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(10) & 0xFF), (get_compile_time_arg_val(11) & 0xFF)}; + +constexpr uint32_t remote_sender_y[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(10) >> 8) & 0xFF, (get_compile_time_arg_val(11) >> 8) & 0xFF}; + +constexpr uint32_t remote_sender_queue_id[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(10) >> 
16) & 0xFF, (get_compile_time_arg_val(11) >> 16) & 0xFF}; + +constexpr DispatchRemoteNetworkType remote_sender_network_type[NUM_TUNNEL_QUEUES] = { + static_cast((get_compile_time_arg_val(10) >> 24) & 0xFF), + static_cast((get_compile_time_arg_val(11) >> 24) & 0xFF)}; constexpr uint32_t test_results_buf_addr_arg = get_compile_time_arg_val(12); constexpr uint32_t test_results_buf_size_bytes = get_compile_time_arg_val(13); -tt_l1_ptr uint32_t* const test_results = - reinterpret_cast(test_results_buf_addr_arg); +tt_l1_ptr uint32_t* const test_results = reinterpret_cast(test_results_buf_addr_arg); constexpr uint32_t timeout_cycles = get_compile_time_arg_val(14); void kernel_main() { - rtos_context_switch_ptr = (void (*)())RtosTable[0]; noc_init(); test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_STARTED; test_results[PQ_TEST_MISC_INDEX] = 0xff000000; - test_results[PQ_TEST_MISC_INDEX+1] = 0xbb000000; - test_results[PQ_TEST_MISC_INDEX+2] = 0xAABBCCDD; - test_results[PQ_TEST_MISC_INDEX+3] = 0xDDCCBBAA; - test_results[PQ_TEST_MISC_INDEX+4] = endpoint_id_start_index; + test_results[PQ_TEST_MISC_INDEX + 1] = 0xbb000000; + test_results[PQ_TEST_MISC_INDEX + 2] = 0xAABBCCDD; + test_results[PQ_TEST_MISC_INDEX + 3] = 0xDDCCBBAA; + test_results[PQ_TEST_MISC_INDEX + 4] = endpoint_id_start_index; for (uint32_t i = 0; i < tunnel_lanes; i++) { - input_queues[i].init(i, in_queue_start_addr_words + i*in_queue_size_words, in_queue_size_words, - remote_sender_x[i], remote_sender_y[i], remote_sender_queue_id[i], remote_sender_network_type[i]); + input_queues[i].init( + i, + in_queue_start_addr_words + i * in_queue_size_words, + in_queue_size_words, + remote_sender_x[i], + remote_sender_y[i], + remote_sender_queue_id[i], + remote_sender_network_type[i]); } for (uint32_t i = 0; i < tunnel_lanes; i++) { - output_queues[i].init(i + NUM_TUNNEL_QUEUES, remote_receiver_queue_start_addr_words[i], remote_receiver_queue_size_words[i], - remote_receiver_x[i], remote_receiver_y[i], remote_receiver_queue_id[i], remote_receiver_network_type[i], - &input_queues[i], 1); + output_queues[i].init( + i + NUM_TUNNEL_QUEUES, + remote_receiver_queue_start_addr_words[i], + remote_receiver_queue_size_words[i], + remote_receiver_x[i], + remote_receiver_y[i], + remote_receiver_queue_id[i], + remote_receiver_network_type[i], + &input_queues[i], + 1); } if (!wait_all_src_dest_ready(input_queues, tunnel_lanes, output_queues, tunnel_lanes, timeout_cycles)) { @@ -142,10 +129,11 @@ void kernel_main() { for (uint32_t i = 0; i < tunnel_lanes; i++) { if (input_queues[i].get_curr_packet_valid()) { bool full_packet_sent; - uint32_t words_sent = output_queues[i].forward_data_from_input(0, full_packet_sent); - //data_words_sent += words_sent; - //if ((words_sent > 0) && (timeout_cycles > 0)) { - progress_timestamp = get_timestamp_32b(); + uint32_t words_sent = + output_queues[i].forward_data_from_input(0, full_packet_sent, input_queues[i].get_end_of_cmd()); + // data_words_sent += words_sent; + // if ((words_sent > 0) && (timeout_cycles > 0)) { + progress_timestamp = get_timestamp_32b(); //} } output_queues[i].prev_words_in_flight_check_flush(); @@ -156,8 +144,8 @@ void kernel_main() { all_outputs_finished &= output_finished; } - //need to optimize this. - //context switch to base fw is very costly. + // need to optimize this. + // context switch to base fw is very costly. 
internal_::risc_context_switch(); } diff --git a/tt_metal/impl/dispatch/kernels/packet_demux.cpp b/tt_metal/impl/dispatch/kernels/packet_demux.cpp index 9fa19a887649..7c915f73766b 100644 --- a/tt_metal/impl/dispatch/kernels/packet_demux.cpp +++ b/tt_metal/impl/dispatch/kernels/packet_demux.cpp @@ -235,7 +235,7 @@ void kernel_main() { uint32_t dest = input_queue.get_curr_packet_dest(); uint8_t output_queue_id = dest_output_queue_id(dest); bool full_packet_sent; - uint32_t words_sent = output_queues[output_queue_id].forward_data_from_input(0, full_packet_sent); + uint32_t words_sent = output_queues[output_queue_id].forward_data_from_input(0, full_packet_sent, input_queue.get_end_of_cmd()); data_words_sent += words_sent; if ((words_sent > 0) && (timeout_cycles > 0)) { progress_timestamp = get_timestamp_32b(); diff --git a/tt_metal/impl/dispatch/kernels/packet_mux.cpp b/tt_metal/impl/dispatch/kernels/packet_mux.cpp index 515951018eb3..a97984306372 100644 --- a/tt_metal/impl/dispatch/kernels/packet_mux.cpp +++ b/tt_metal/impl/dispatch/kernels/packet_mux.cpp @@ -185,7 +185,7 @@ void kernel_main() { } if (input_queues[curr_input].get_curr_packet_valid()) { bool full_packet_sent; - uint32_t words_sent = output_queue.forward_data_from_input(curr_input, full_packet_sent); + uint32_t words_sent = output_queue.forward_data_from_input(curr_input, full_packet_sent, input_queues[curr_input].get_end_of_cmd()); data_words_sent += words_sent; if ((words_sent > 0) && (timeout_cycles > 0)) { progress_timestamp = get_timestamp_32b(); diff --git a/tt_metal/impl/dispatch/kernels/packet_queue.hpp b/tt_metal/impl/dispatch/kernels/packet_queue.hpp index 0be258377269..bf4e9a294fb3 100644 --- a/tt_metal/impl/dispatch/kernels/packet_queue.hpp +++ b/tt_metal/impl/dispatch/kernels/packet_queue.hpp @@ -410,6 +410,7 @@ class packet_input_queue_state_t : public packet_queue_state_t { uint16_t curr_packet_src; uint16_t curr_packet_dest; uint32_t curr_packet_size_words; + uint32_t end_of_cmd; uint32_t curr_packet_words_sent; uint32_t curr_packet_tag; uint16_t curr_packet_flags; @@ -423,7 +424,9 @@ class packet_input_queue_state_t : public packet_queue_state_t { (this->queue_start_addr_words + this->get_queue_rptr_sent_offset_words())*PACKET_WORD_SIZE_BYTES ); this->curr_packet_header_ptr = next_packet_header_ptr; - uint32_t packet_size_bytes = next_packet_header_ptr->packet_size_bytes; + uint32_t packet_size_and_flags = next_packet_header_ptr->packet_size_bytes; + uint32_t packet_size_bytes = packet_size_and_flags & 0xFFFFFFFE; + this->end_of_cmd = !(packet_size_and_flags & 1); this->curr_packet_size_words = packet_size_bytes/PACKET_WORD_SIZE_BYTES; if (packet_size_bytes % PACKET_WORD_SIZE_BYTES) { this->curr_packet_size_words++; @@ -489,6 +492,10 @@ class packet_input_queue_state_t : public packet_queue_state_t { this->reset_ready_flag(); } + inline uint32_t get_end_of_cmd() const { + return this->end_of_cmd; + } + inline bool is_packetizer_input() const { return this->cb_mode; } @@ -863,7 +870,7 @@ class packet_output_queue_state_t : public packet_queue_state_t { return num_words_to_forward; } - inline uint32_t forward_data_from_input(uint32_t input_queue_index, bool& full_packet_sent) { + inline uint32_t forward_data_from_input(uint32_t input_queue_index, bool& full_packet_sent, uint32_t end_of_cmd) { packet_input_queue_state_t* input_queue_ptr = &(this->input_queue_status.input_queue_array[input_queue_index]); uint32_t num_words_to_forward = this->get_num_words_to_send(input_queue_index); @@ -894,7 +901,7 @@ class 
packet_output_queue_state_t : public packet_queue_state_t { this->remote_wptr_update(num_words_to_forward); } else { this->unpacketizer_page_words_sent += num_words_to_forward; - if (full_packet_sent) { + if (full_packet_sent && end_of_cmd) { uint32_t unpacketizer_page_words_sent_past_page_bound = this->unpacketizer_page_words_sent & (this->cb_mode_page_size_words - 1); if (unpacketizer_page_words_sent_past_page_bound > 0) { From 1869a59a3b71f877eb4775439bd1c54f3d1a6f3f Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Tue, 4 Jun 2024 20:42:46 +0000 Subject: [PATCH 27/53] #6448: re-enable all-gather bidir for dim 0,1 --- .../unit_testing/misc/test_all_gather.py | 8 +++++++- tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py index 769adb144b10..5d6a12971ef2 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py @@ -224,7 +224,6 @@ def test_all_gather_on_t3000_post_commit_looping( [ (4, 2, [4, 1, 33, 256], 0, ttl.tensor.Layout.ROW_MAJOR), (8, 1, [8, 1, 33, 256], 0, ttl.tensor.Layout.ROW_MAJOR), - # (8, 1, [8, 1, 256, 32], 0, ttl.tensor.Layout.TILE), (8, 1, [8, 8, 256, 384], 1, ttl.tensor.Layout.ROW_MAJOR), (4, 2, [8, 8, 256, 384], 1, ttl.tensor.Layout.ROW_MAJOR), (4, 2, [8, 8, 256, 384], 1, ttl.tensor.Layout.TILE), @@ -259,6 +258,8 @@ def test_all_gather_on_t3000_post_commit_looping( (8, 1, [1, 1, 1024, 256], 3, ttl.tensor.Layout.TILE), (8, 1, [1, 1, 256, 2048], 2, ttl.tensor.Layout.TILE), (8, 1, [1, 1, 256, 8192], 2, ttl.tensor.Layout.TILE), # double on reduction dim for 8 chip + (8, 1, [8, 1, 256, 32], 0, ttl.tensor.Layout.TILE), + (8, 1, [8, 8, 128, 4096], 1, ttl.tensor.Layout.TILE), ], ) @pytest.mark.parametrize( @@ -424,6 +425,11 @@ def test_line_all_gather_on_t3000_post_commit( ([8, 8, 256, 384], 3, ttl.tensor.Layout.TILE), ([8, 8, 256, 768], 3, ttl.tensor.Layout.ROW_MAJOR), ([8, 8, 256, 768], 3, ttl.tensor.Layout.TILE), + ([8, 8, 1024, 4096], 1, ttl.tensor.Layout.TILE), + ([8, 8, 2048, 4096], 1, ttl.tensor.Layout.TILE), + ([8, 8, 128, 4096], 1, ttl.tensor.Layout.ROW_MAJOR), + ([8, 8, 1024, 4096], 1, ttl.tensor.Layout.ROW_MAJOR), + ([8, 8, 2048, 4096], 1, ttl.tensor.Layout.ROW_MAJOR), # Only for BFP8B # ([1, 1, 640, 32768], 3, ttl.tensor.Layout.TILE), # MLP AllGather. Llama 2 decode attn, mlp. 
Llama2, Falcon 40B decode mlp attn diff --git a/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp b/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp index 32debd44f72b..964e67305b16 100644 --- a/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp +++ b/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp @@ -47,7 +47,7 @@ class AllGatherConfig { erisc_handshake_address(round_up(eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, 16)), topology(topology), - enable_bidirectional(/*false*/topology == all_gather_op::Topology::Ring && dim != 0 && dim != 1), + enable_bidirectional(topology == all_gather_op::Topology::Ring), input_is_dram(input_tensor.buffer()->buffer_type() == BufferType::DRAM), output_is_dram(output_tensor.buffer()->buffer_type() == BufferType::DRAM), From 6e889fbcfe6f9a0ea71f1f5e7a29b85fec6b4fd3 Mon Sep 17 00:00:00 2001 From: David Ma Date: Tue, 4 Jun 2024 16:30:00 +0000 Subject: [PATCH 28/53] #8890: Reduce size of *_src_format constexprs --- tt_metal/hw/ckernels/grayskull/metal/llk_io/llk_outputs.h | 4 ++-- tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h | 4 ++-- tt_metal/jit_build/genfiles.cpp | 5 ++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_io/llk_outputs.h b/tt_metal/hw/ckernels/grayskull/metal/llk_io/llk_outputs.h index 7558f53219ae..a9c8bf6258f6 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_io/llk_outputs.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_io/llk_outputs.h @@ -18,12 +18,12 @@ inline const uint32_t get_output_base_id() return (OUTPUT_BASE_ID); } -inline const uint32_t get_output_src_format(const std::uint32_t output_id) +inline const unsigned char get_output_src_format(const std::uint32_t output_id) { return pack_src_format[output_id]; } -inline const uint32_t get_output_dst_format(const std::uint32_t output_id) +inline const unsigned char get_output_dst_format(const std::uint32_t output_id) { return pack_dst_format[output_id]; } diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h index b92af5b8ddc3..74c71eb97519 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h @@ -18,12 +18,12 @@ inline const uint32_t get_output_base_id() return (OUTPUT_BASE_ID); } -inline const uint32_t get_output_src_format(const std::uint32_t output_id) +inline const unsigned char get_output_src_format(const std::uint32_t output_id) { return pack_src_format[output_id]; } -inline const uint32_t get_output_dst_format(const std::uint32_t output_id) +inline const unsigned char get_output_dst_format(const std::uint32_t output_id) { return pack_dst_format[output_id]; } diff --git a/tt_metal/jit_build/genfiles.cpp b/tt_metal/jit_build/genfiles.cpp index 9c244ddd913f..b805e5ffa1e9 100644 --- a/tt_metal/jit_build/genfiles.cpp +++ b/tt_metal/jit_build/genfiles.cpp @@ -199,9 +199,8 @@ generate_pack_data_formats(tt_hlk_desc& desc, DataFormat unpack_conditional_dst_ static void emit_pack_data_formats(std::string pack_data_format_descs, std::vector src_formats_all_cbs, std::vector dst_formats_all_cbs) { ofstream file_stream; file_stream.open(pack_data_format_descs); - // TODO: we should be emitting "unsigned char", no reason to use 4B per data format - file_stream << create_formats_array_string("constexpr std::int32_t", "pack_src_format", NUM_CIRCULAR_BUFFERS, data_format_vec_to_string(src_formats_all_cbs)); - file_stream << 
create_formats_array_string("constexpr std::int32_t", "pack_dst_format", NUM_CIRCULAR_BUFFERS, data_format_vec_to_string(dst_formats_all_cbs)); + file_stream << create_formats_array_string("constexpr unsigned char", "pack_src_format", NUM_CIRCULAR_BUFFERS, data_format_vec_to_string(src_formats_all_cbs)); + file_stream << create_formats_array_string("constexpr unsigned char", "pack_dst_format", NUM_CIRCULAR_BUFFERS, data_format_vec_to_string(dst_formats_all_cbs)); // budabackend-style format array // file_stream << create_formats_array_string("const std::int32_t", "pack_src_format", 16, data_format_vec_to_string(src_formats)); From 6a009655e529490aba893f6f8281e66a3ed291c7 Mon Sep 17 00:00:00 2001 From: yugaoT Date: Tue, 4 Jun 2024 16:52:57 +0000 Subject: [PATCH 29/53] #0: merge all kernels into one group --- .../misc/test_matmul_dram_sharded.py | 184 ++++++++++++++++++ ...m_large_block_zm_fused_bias_activation.cpp | 8 + ...mm_tile_layout_in0_sender_dram_sharded.cpp | 33 +++- ...mm_tile_layout_in1_sender_dram_sharded.cpp | 24 ++- ...ulti_core_reuse_dram_sharded_optimized.cpp | 102 +++++++--- 5 files changed, 307 insertions(+), 44 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py index 0f5e1bb50e3f..ed50144d26a5 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py @@ -279,3 +279,187 @@ def test_matmul_in1_dram_sharded_with_program_cache( ttl.tensor.Tensor(py_dummy_tensor, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, mem_config) ) assert device.num_program_cache_entries() == 3 + + +def run_test_matmul_in1_dram_sharded_mm_chain( + device, + in0_sharded, + out_sharded, + in1_in_dram, + M, + K, + N, + fidelity, + has_bias, + activation, + grid_size, + in0_dtype, + in1_dtype, + out_dtype, + function_level_defaults, + use_program_cache, +): + if is_grayskull() and (N == 4096 or K == 32768): + pytest.skip("Skipping too large tensor test on Grayskull") + + if is_grayskull(): + N_padded = N + num_banks = 8 + else: + N_padded = pad_to_dram_banks(N) + num_banks = 12 + + in0_shape = [1, 1, M, K] + in1_shape = [1, 1, K, N] + in1_shard_shape = [K, N_padded // num_banks] + num_cores = grid_size[0] * grid_size[1] + + in0_block_h = M // 32 + in0_block_w = K // num_cores // 32 + out_block_h = M // 32 + out_block_w = N // num_cores // 32 + + out_subblock_h, out_subblock_w, _ = find_max_subblock(out_block_h, out_block_w) + + logger.debug("N_padded " + str(N_padded)) + logger.debug("in0 block h w " + str(in0_block_h * 32) + " " + str(in0_block_w * 32)) + logger.debug("in1 block h w " + str(in0_block_w * 32) + " " + str(out_block_w * 32)) + logger.debug("out block h w " + str(out_block_h * 32) + " " + str(out_block_w * 32)) + logger.debug("out subblock h w " + str(out_subblock_h * 32) + " " + str(out_subblock_w * 32)) + + sharded_mem_config = ttl.tensor.MemoryConfig( + memory_layout=ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, + buffer_type=ttl.tensor.BufferType.L1, + ) + + in0 = torch.randn(in0_shape).bfloat16().float() + in1 = torch.randn(in1_shape).bfloat16().float() + + in0_shard_grid = (grid_size[0] - 1, grid_size[1] - 1) + in0_shard_shape = [M, int(in0_block_w * 32)] + in0_shard_grid = ttl.tensor.CoreRangeSet({ttl.tensor.CoreRange(ttl.tensor.CoreCoord(0, 0), in0_shard_grid)}) + in0_shard_spec = ttl.tensor.ShardSpec(in0_shard_grid, 
in0_shard_shape, ttl.tensor.ShardOrientation.ROW_MAJOR, False) + in0_mem_config = ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, ttl.tensor.BufferType.L1, in0_shard_spec + ) + in0_t = torch2tt_tensor(in0, device, tt_memory_config=in0_mem_config, tt_dtype=in0_dtype) + + in1_shard_grid = ttl.tensor.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1) + in1_shard_grid = ttl.tensor.CoreRangeSet({ttl.tensor.CoreRange(ttl.tensor.CoreCoord(0, 0), in1_shard_grid)}) + in1_shard_spec = ttl.tensor.ShardSpec(in1_shard_grid, in1_shard_shape, ttl.tensor.ShardOrientation.ROW_MAJOR, False) + in1_mem_config = ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, ttl.tensor.BufferType.DRAM, in1_shard_spec + ) + in1_t = torch2tt_tensor(in1, device, tt_memory_config=in1_mem_config, tt_dtype=in1_dtype) + + program_config = ttl.operations.primary.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig( + in0_block_w=in0_block_w // 4, + out_subblock_h=out_subblock_h, + out_subblock_w=out_subblock_w, + per_core_M=out_block_h, + per_core_N=out_block_w, + fuse_batch=True, + fused_activation=None, + ) + + if is_grayskull(): + compute_kernel_config = ttl.tensor.GrayskullComputeKernelConfig( + math_fidelity=fidelity, + math_approx_mode=True, + ) + else: + compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( + math_fidelity=fidelity, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + # 1st mm + output_t = ttl.operations.primary.matmul( + in0_t, + in1_t, + program_config=program_config, + output_mem_config=sharded_mem_config, + output_dtype=out_dtype, + compute_kernel_config=compute_kernel_config, + ) + + for _ in range(200): + output_t = ttl.operations.primary.matmul( + in0_t, + in1_t, + program_config=program_config, + output_mem_config=sharded_mem_config, + output_dtype=out_dtype, + compute_kernel_config=compute_kernel_config, + ) + + output_t = output_t.cpu().to(ttl.tensor.Layout.ROW_MAJOR) + + pt_out = in0 @ in1 + + tt_out = tt2torch_tensor(output_t) + + print(tt_out) + print(pt_out) + + passing, output = comp_pcc(pt_out, tt_out) + logger.info(output) + assert True + + +@pytest.mark.parametrize( + "fidelity", + [ + ttl.tensor.MathFidelity.HiFi2, + ], + ids=[ + "HiFi2", + ], +) +@pytest.mark.parametrize( + "has_bias", + [ + False, + ], + ids=["no_bias"], +) +@pytest.mark.parametrize( + "in0_dtype, in1_dtype, out_dtype", + [ + (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B, ttl.tensor.DataType.BFLOAT16), + ], +) +def test_matmul_in1_dram_sharded_with_mm_chain( + device, + fidelity, + has_bias, + in0_dtype, + in1_dtype, + out_dtype, + function_level_defaults, + use_program_cache, +): + M = 32 + K = 4096 + N = 4096 + grid_size = (8, 2) + run_test_matmul_in1_dram_sharded_mm_chain( + device, + True, + True, + True, + M, + K, + N, + fidelity, + has_bias, + None, + grid_size, + in0_dtype, + in1_dtype, + out_dtype, + function_level_defaults, + use_program_cache, + ) diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp b/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp index ca37e340a703..ede0790e5ea6 100644 --- a/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp @@ -68,6 +68,14 @@ inline void reblock_and_untilize( } void MAIN { + // RUNTIME ARGS + #ifdef MATMUL_DRAM_SHARDED + const 
bool is_worker_core = get_arg_val(0) == 1; + // if not worker core, skip + if (not is_worker_core) { + return; + } + #endif constexpr uint32_t in0_block_w = get_compile_time_arg_val(0); // inner block size in tiles constexpr uint32_t in0_num_subblocks = get_compile_time_arg_val(1); // outer row block size (in inner row blocks) diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp index 109def5e9bcd..bbe72e1f48ab 100644 --- a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp @@ -31,7 +31,11 @@ void kernel_main() { constexpr uint32_t num_storage_cores = num_blocks / num_blocks_per_shard; // RUNTIME ARGS - const bool is_worker_core = get_arg_val(0) == 1; + const uint32_t worker_core_type = get_arg_val(0); + // if not worker core, skip + if (worker_core_type == 0) { + return; + } const uint32_t sender_id = get_arg_val(1); volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(2)); volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(2 + num_storage_cores)); @@ -71,7 +75,7 @@ void kernel_main() { uint32_t local_read_addr = get_read_ptr(cb_id_in2); - if (not is_worker_core) { + if (worker_core_type == 1) { // mcast sender + no compute for (uint32_t i = 0; i < num_blocks_per_shard; ++i) { const uint32_t block_id = sender_block_id + i; @@ -101,7 +105,8 @@ void kernel_main() { local_read_addr += in0_block_size_bytes; } - } else { + } else if (worker_core_type == 2) { // mcast sender + compute + for(uint32_t block = 0; block < num_blocks; ++block) { const uint32_t block_id = block / num_blocks_per_shard; @@ -138,5 +143,27 @@ void kernel_main() { cb_push_back(cb_id_in0, in0_block_num_tiles); } + } else { // mcast receiver + compute + + for(uint32_t block = 0; block < num_blocks; ++block) { + const uint32_t block_id = block / num_blocks_per_shard; + + // get the mcast sender noc + uint64_t in0_mcast_sender_semaphore_noc_addr = get_noc_addr(in0_mcast_sender_noc_x[block_id], in0_mcast_sender_noc_y[block_id], in0_mcast_sender_semaphore_addr); + + // Operand 0 + cb_reserve_back(cb_id_in0, in0_block_num_tiles); + + // Set in0 semaphore value to INVALID + noc_semaphore_set(in0_mcast_receiver_semaphore_addr_ptr, INVALID); + + // Atomic increment source core counter + noc_semaphore_inc(in0_mcast_sender_semaphore_noc_addr, 1); + + // wait on in0 semaphore value to become VALID (set by mcast sender after it multicasts data) + noc_semaphore_wait(in0_mcast_receiver_semaphore_addr_ptr, VALID); + + cb_push_back(cb_id_in0, in0_block_num_tiles); + } } } diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp index 5bde1c06534c..0546a8db1c0e 100644 --- a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp @@ -9,17 +9,23 @@ void kernel_main() { // RUNTIME ARGS - const uint32_t in1_tensor_addr = get_arg_val(0); + const bool is_worker_core = get_arg_val(0) == 1; + // if not worker core, skip + if (not is_worker_core) { + 
return; + } + + const uint32_t in1_tensor_addr = get_arg_val(1); #ifdef FUSE_BIAS - const uint32_t in3_tensor_addr = get_arg_val(1); + const uint32_t in3_tensor_addr = get_arg_val(2); #endif - const uint32_t dram_bank_id = get_arg_val(2); - const uint32_t vc = get_arg_val(3); - const uint32_t num_shard_to_write_back = get_arg_val(4); - const uint32_t reshard_tensor_start_offset = get_arg_val(5); - volatile tt_l1_ptr uint32_t * per_core_N_reshard_bytes = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(6)); - volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(7)); - volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(8)); + const uint32_t dram_bank_id = get_arg_val(3); + const uint32_t vc = get_arg_val(4); + const uint32_t num_shard_to_write_back = get_arg_val(5); + const uint32_t reshard_tensor_start_offset = get_arg_val(6); + volatile tt_l1_ptr uint32_t * per_core_N_reshard_bytes = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(7)); + volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(8)); + volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(9)); // COMPILE TIME ARGS diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp index 38efb0589eff..9b8b3200eaa3 100644 --- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp @@ -503,6 +503,13 @@ operation::ProgramWithCallbacks create_program_dram_sharded( log_debug("all_cores: {}", core); } + // grid bounding box + CoreRange bounding_box = all_cores.bounding_box(); + std::set bounding_box_set; bounding_box_set.insert(bounding_box); + CoreRangeSet all_cores_in_rect_grid(bounding_box_set); + std::vector all_cores_in_rect_grid_vec = corerange_to_cores(all_cores_in_rect_grid); + log_debug("bounding_box: {}", bounding_box); + // Mcast args auto in0_mcast_sender_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID); auto in0_mcast_receiver_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID); @@ -581,16 +588,6 @@ operation::ProgramWithCallbacks create_program_dram_sharded( in1_sender_writer_compile_time_args.push_back(bias_buffer_num_pages); in1_sender_writer_compile_time_args.push_back((std::uint32_t)1); } - std::vector in0_receiver_compile_time_args = { - // in0 block args - (std::uint32_t)in0_block_w * per_core_M, // in0_block_num_tiles - // in0/in1 common args - (std::uint32_t)num_blocks, // num_blocks - // in0 mcast args - (std::uint32_t)in0_mcast_sender_semaphore, - (std::uint32_t)in0_mcast_receiver_semaphore, - // - (std::uint32_t)num_blocks_per_shard}; std::map mm_kernel_defines; std::map mm_kernel_in0_sender_define; @@ -625,11 +622,12 @@ operation::ProgramWithCallbacks create_program_dram_sharded( if (skip_write_back) { mm_kernel_in1_sender_writer_defines["SKIP_WRITE_BACK"] = "1"; } + mm_kernel_defines["MATMUL_DRAM_SHARDED"] = "1"; auto mm_kernel_in0_sender_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp", - mcast_senders, + 
all_cores_in_rect_grid, tt_metal::DataMovementConfig{ .processor = tt_metal::DataMovementProcessor::RISCV_1, .noc = in0_noc, @@ -639,22 +637,13 @@ operation::ProgramWithCallbacks create_program_dram_sharded( auto mm_kernel_in1_sender_writer_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp", - all_worker_cores, + all_cores_in_rect_grid, tt_metal::DataMovementConfig{ .processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = in1_noc, .compile_args = in1_sender_writer_compile_time_args, .defines = mm_kernel_in1_sender_writer_defines}); - KernelHandle mm_kernel_in0_receiver_id = tt_metal::CreateKernel( - program, - "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver_dram_sharded.cpp", - mcast_receivers, - tt_metal::DataMovementConfig{ - .processor = tt_metal::DataMovementProcessor::RISCV_1, - .noc = in0_noc, - .compile_args = in0_receiver_compile_time_args}); - // Compute kernel compile time args uint32_t in0_subblock_num_tiles = out_subblock_h * in0_block_w; @@ -687,7 +676,8 @@ operation::ProgramWithCallbacks create_program_dram_sharded( auto mm_kernel = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp", - all_worker_cores, + // all_worker_cores, + all_cores_in_rect_grid, tt_metal::ComputeConfig{ .math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, @@ -850,14 +840,15 @@ operation::ProgramWithCallbacks create_program_dram_sharded( for (auto core : mcast_senders_coords) { std::vector mm_in0_sender_args; - bool is_worker_core; + // mcast sender - 1, mcast sender + compute core - 2 + uint32_t worker_core_type; if (find(storage_worker_common.begin(), storage_worker_common.end(), core) != storage_worker_common.end()) { - is_worker_core = true; + worker_core_type = 2; } else { - is_worker_core = false; + worker_core_type = 1; } - mm_in0_sender_args.push_back((std::uint32_t)is_worker_core); + mm_in0_sender_args.push_back((std::uint32_t)worker_core_type); mm_in0_sender_args.push_back((std::uint32_t)sender_id); mm_in0_sender_args.insert( mm_in0_sender_args.end(), in0_mcast_sender_noc_x.begin(), in0_mcast_sender_noc_x.end()); @@ -876,12 +867,30 @@ operation::ProgramWithCallbacks create_program_dram_sharded( // in0 receivers rt args std::vector mm_in0_receiver_args; + // mcast receiver - 3 + uint32_t worker_core_type = 3; + mm_in0_receiver_args.push_back((std::uint32_t)worker_core_type); + mm_in0_receiver_args.push_back((std::uint32_t) 0); mm_in0_receiver_args.insert( mm_in0_receiver_args.end(), in0_mcast_sender_noc_x.begin(), in0_mcast_sender_noc_x.end()); mm_in0_receiver_args.insert( mm_in0_receiver_args.end(), in0_mcast_sender_noc_y.begin(), in0_mcast_sender_noc_y.end()); - tt_metal::SetRuntimeArgs(program, mm_kernel_in0_receiver_id, core, mm_in0_receiver_args); - reader_kernel_ids.push_back(mm_kernel_in0_receiver_id); + + tt_metal::SetRuntimeArgs(program, mm_kernel_in0_sender_id, core, mm_in0_receiver_args); + reader_kernel_ids.push_back(mm_kernel_in0_sender_id); + } + + for (auto core : all_cores_in_rect_grid_vec) { + if (std::find(mcast_senders_coords.begin(), mcast_senders_coords.end(), core) == mcast_senders_coords.end() and + std::find(mcast_receiver_coords.begin(), mcast_receiver_coords.end(), core) == mcast_receiver_coords.end()) { + // in0 receivers rt args + std::vector mm_in0_idle_args; + // idle core - 0 + uint32_t worker_core_type = 0; + 
mm_in0_idle_args.push_back((std::uint32_t)worker_core_type); + + tt_metal::SetRuntimeArgs(program, mm_kernel_in0_sender_id, core, mm_in0_idle_args); + } } uint32_t bank_id = 0; @@ -894,11 +903,40 @@ operation::ProgramWithCallbacks create_program_dram_sharded( uint32_t curr_worker_core = 0; uint32_t curr_storage_core = 0; + // for all the cores in the rect grid, we send one rt arg to determine if they are worker core + for (uint32_t i = 0; i < all_cores_in_rect_grid_vec.size(); ++i) { + auto core = all_cores_in_rect_grid_vec[i]; + + if (all_worker_cores.ranges().find(core) == all_worker_cores.ranges().end()) { // not worker + // in1 reader rt args + bool is_worker_core = false; + std::vector mm_in1_sender_writer_args; + mm_in1_sender_writer_args.push_back((std::uint32_t) is_worker_core); + + tt_metal::SetRuntimeArgs(program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); + + // compute rt args + std::vector mm_compute_args; + mm_compute_args.push_back((std::uint32_t) is_worker_core); + + tt_metal::SetRuntimeArgs(program, mm_kernel, core, mm_compute_args); + } else { + // compute rt args + bool is_worker_core = true; + std::vector mm_compute_args; + mm_compute_args.push_back((std::uint32_t) is_worker_core); + + tt_metal::SetRuntimeArgs(program, mm_kernel, core, mm_compute_args); + } + } + for (uint32_t i = 0; i < all_worker_cores_ordered.size(); ++i) { auto core = all_worker_cores_ordered[i]; // in1 reader rt args + bool is_worker_core = true; std::vector mm_in1_sender_writer_args; + mm_in1_sender_writer_args.push_back((std::uint32_t) is_worker_core); mm_in1_sender_writer_args.push_back(in1_buffer->address()); if (bias_buffer != nullptr) { mm_in1_sender_writer_args.push_back(bias_buffer->address()); @@ -1014,7 +1052,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded( } } - mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 4, num_iter); + mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 5, num_iter); } tt_metal::SetRuntimeArgs(program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); @@ -1044,11 +1082,11 @@ operation::ProgramWithCallbacks create_program_dram_sharded( auto core = all_worker_cores_ordered[i]; auto writer_kernel_id = writer_kernel_ids[i]; auto& writer_runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - writer_runtime_args[0] = src_buffer_b->address(); + writer_runtime_args[1] = src_buffer_b->address(); if (bias_tensor.has_value()) { - writer_runtime_args[1] = bias_tensor.value().buffer()->address(); + writer_runtime_args[2] = bias_tensor.value().buffer()->address(); } else { - writer_runtime_args[1] = 0; + writer_runtime_args[2] = 0; } } }; From 79d283fc0e7c18c554f1fd27808f9394f00ce792 Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Tue, 4 Jun 2024 21:59:39 +0000 Subject: [PATCH 30/53] #7724: Disable a test to reduce runtime --- .../streams/test_autonomous_relay_streams.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp index 2c963a0796d0..1281b2414ef8 100644 --- a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp @@ -656,7 +656,7 @@ TEST_F(CommandQueueFixture, TestAutonomousRelayStreams) { } std::srand(0); - uint32_t num_loop_iterations 
= 10; + uint32_t num_loop_iterations = 2; uint32_t num_messages_to_send = 1'000'000; uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; uint32_t relay_stream_buffer_size_bytes = 16 * 1024; @@ -733,7 +733,7 @@ TEST_F(CommandQueueFixture, TestAutonomousRelayStreamsSmallPackets) { return; } -TEST_F(CommandQueueFixture, TestAutonomousRelayStreamsLoopingShort) { +TEST_F(CommandQueueFixture, DISABLED_TestAutonomousRelayStreamsLoopingShort) { auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); auto num_devices = tt::tt_metal::GetNumAvailableDevices(); if (arch == tt::ARCH::GRAYSKULL) { From 3ccf9ef057187d37d04cc625a1b651c6bb30bdab Mon Sep 17 00:00:00 2001 From: Joseph Chu Date: Tue, 4 Jun 2024 02:35:05 +0000 Subject: [PATCH 31/53] #9088: support for multi-device galaxy device_ids --- tt_eager/tensor/tensor.hpp | 2 +- tt_eager/tensor/types.hpp | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index d29c0730942a..16c9665d2c35 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -325,7 +325,7 @@ struct Tensor { return buffer->device(); } else if (this->storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) { auto &storage = std::get(this->get_storage()); - return storage.get_buffer_for_device_id(0)->device(); + return this->get_workers().at(0); } else { TT_THROW("Cannot get the device from a tensor with host storage"); } diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index 9c71b6f0d777..81247e39c871 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -503,16 +503,17 @@ struct MultiDeviceHostStorage { const MemoryConfig memory_config() const { std::lock_guard lock(mtx); - if (this->buffers.at(0).get() == nullptr) { + auto first_device_id = this->ordered_device_ids.at(0); + if (this->buffers.at(first_device_id).get() == nullptr) { TT_THROW("MemoryConfig can only be obtained if the buffer is not null"); } std::optional shard_spec = std::nullopt; - if (is_sharded(this->buffers.at(0)->buffer_layout())) { - shard_spec = this->buffers.at(0)->shard_spec().tensor_shard_spec; + if (is_sharded(this->buffers.at(first_device_id)->buffer_layout())) { + shard_spec = this->buffers.at(first_device_id)->shard_spec().tensor_shard_spec; } return MemoryConfig{ - .memory_layout = this->buffers.at(0)->buffer_layout(), - .buffer_type = this->buffers.at(0)->buffer_type(), + .memory_layout = this->buffers.at(first_device_id)->buffer_layout(), + .buffer_type = this->buffers.at(first_device_id)->buffer_type(), .shard_spec = shard_spec}; } From 54b93f2f00cbc6f22af0ed449b79b274afb5ad90 Mon Sep 17 00:00:00 2001 From: Joseph Chu Date: Mon, 3 Jun 2024 19:25:08 +0000 Subject: [PATCH 32/53] #9088: update falcon7b to support wh 7x8 and 8x8 core grid --- .../tests/multi_chip/test_falcon_attention.py | 1 + models/demos/ttnn_falcon7b/tt/falcon_attention.py | 15 ++++++++------- models/demos/ttnn_falcon7b/tt/falcon_decoder.py | 1 + 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py index 1eb0382ce263..a3ebb92457d8 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py @@ -154,6 +154,7 @@ def test_falcon_attention( configuration.max_position_embeddings, model_config, parameters=parameters, + 
core_grid=device_mesh.get_devices()[0].core_grid, ) tt_out, tt_layer_present = tt_FalconAttention_model( diff --git a/models/demos/ttnn_falcon7b/tt/falcon_attention.py b/models/demos/ttnn_falcon7b/tt/falcon_attention.py index 63fb859b7599..51921c0c45c8 100644 --- a/models/demos/ttnn_falcon7b/tt/falcon_attention.py +++ b/models/demos/ttnn_falcon7b/tt/falcon_attention.py @@ -24,6 +24,7 @@ def __init__( max_position_embeddings: int = 2048, model_config=None, parameters=None, + core_grid=None, ): super().__init__() self.hidden_size = hidden_size @@ -49,11 +50,7 @@ def __init__( ) self.scalar = 1 / math.sqrt(self.head_dim) - - if is_wormhole_b0(): - self.core_grid = ttnn.CoreGrid(y=7, x=8) - else: - self.core_grid = ttnn.CoreGrid(y=9, x=12) + self.core_grid = core_grid def __call__( self, @@ -165,7 +162,9 @@ def __call__( attn_weights = ttnn.experimental.operations.primary.transformers.attn_matmul( query_layer, key_layer_transposed, - compute_with_storage_grid_size=ttnn.experimental.tensor.CoreCoord(8, 7), + compute_with_storage_grid_size=ttnn.experimental.tensor.CoreCoord( + self.core_grid.x, self.core_grid.y + ), output_mem_config=self.model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"], output_dtype=self.model_config["PRE_SOFTMAX_MM_OUTPUT_DTYPE"], # Must be BFLOAT16 ) @@ -228,7 +227,9 @@ def __call__( attn_output = ttnn.experimental.operations.primary.transformers.attn_matmul( attn_weights, value_layer, - compute_with_storage_grid_size=ttnn.experimental.tensor.CoreCoord(8, 7), + compute_with_storage_grid_size=ttnn.experimental.tensor.CoreCoord( + self.core_grid.x, self.core_grid.y + ), output_mem_config=self.model_config["POST_SOFTMAX_MM_OUTPUT_MEMCFG"], output_dtype=self.model_config["POST_SOFTMAX_MM_OUTPUT_DTYPE"], # Must be BFLOAT16 ) diff --git a/models/demos/ttnn_falcon7b/tt/falcon_decoder.py b/models/demos/ttnn_falcon7b/tt/falcon_decoder.py index fed5b893129e..045011db439f 100644 --- a/models/demos/ttnn_falcon7b/tt/falcon_decoder.py +++ b/models/demos/ttnn_falcon7b/tt/falcon_decoder.py @@ -31,6 +31,7 @@ def __init__( max_position_embeddings=config.max_position_embeddings, model_config=model_config, parameters=parameters.self_attention, + core_grid=device.get_devices()[0].core_grid, ) self.mlp = TtFalconMLP(model_config, parameters=parameters.mlp) From 03c757ece85aa1a78857b47e11668976ff8ed852 Mon Sep 17 00:00:00 2001 From: Joseph Chu Date: Tue, 4 Jun 2024 02:22:16 +0000 Subject: [PATCH 33/53] #9088: support multi-device mesh with single device --- conftest.py | 9 -------- .../tests/multi_chip/test_falcon_mlp.py | 1 + tests/ttnn/unit_tests/test_multi_device.py | 2 ++ .../unit_tests/test_multi_device_async.py | 7 ++---- .../unit_tests/test_multi_device_trace.py | 6 +++++ tt_eager/tensor/tensor.cpp | 2 ++ tt_eager/tensor/tensor_utils.cpp | 22 +++++++++++-------- ttnn/cpp/ttnn/multi_device.hpp | 2 ++ 8 files changed, 28 insertions(+), 23 deletions(-) diff --git a/conftest.py b/conftest.py index 7df64b2c7505..6c617cc1e7a8 100644 --- a/conftest.py +++ b/conftest.py @@ -326,9 +326,6 @@ def device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0): except (ValueError, AttributeError): num_devices_requested = len(device_ids) - if num_devices_requested <= 1: - pytest.skip("Requires multiple devices to run") - device_mesh = ttnn.open_device_mesh(ttnn.DeviceGrid(1, num_devices_requested), device_ids[:num_devices_requested]) logger.debug(f"multidevice with {device_mesh.get_num_devices()} devices is created") @@ -354,9 +351,6 @@ def pcie_device_mesh(request, silicon_arch_name, 
silicon_arch_wormhole_b0): except (ValueError, AttributeError): num_pcie_devices_requested = len(device_ids) - if num_pcie_devices_requested <= 1: - pytest.skip("Requires multiple devices to run") - device_mesh = ttnn.open_device_mesh( ttnn.DeviceGrid(1, num_pcie_devices_requested), device_ids[:num_pcie_devices_requested] ) @@ -386,9 +380,6 @@ def t3k_device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0): except (ValueError, AttributeError): num_devices_requested = len(device_ids) - if num_devices_requested <= 1: - pytest.skip("Requires multiple devices to run") - device_mesh = ttnn.open_device_mesh(ttnn.DeviceGrid(1, num_devices_requested), device_ids[:num_devices_requested]) logger.debug(f"multidevice with {device_mesh.get_num_devices()} devices is created") diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py index 6301284023c5..192babe1f3e9 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py @@ -52,6 +52,7 @@ def torch_model(): @pytest.mark.parametrize( "device_mesh", [ + 1, 2, ], indirect=True, diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py index c8b7386279dc..501840cfe5ce 100644 --- a/tests/ttnn/unit_tests/test_multi_device.py +++ b/tests/ttnn/unit_tests/test_multi_device.py @@ -159,6 +159,8 @@ def test_multi_device_replicate(device_mesh, shape, layout, memory_config): def test_ttnn_multi_device_all_gather(pcie_device_mesh): """Multidevice API test for ttnn.all_gather CCL operation""" + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("Requires multiple devices to run") full_tensor = torch.rand((1, 1, 32, 32 * pcie_device_mesh.get_num_devices()), dtype=torch.bfloat16) ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=3)) diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py index 2f5cc0e82529..35a3bf71a5bd 100644 --- a/tests/ttnn/unit_tests/test_multi_device_async.py +++ b/tests/ttnn/unit_tests/test_multi_device_async.py @@ -278,8 +278,8 @@ def test_multi_device_explicit_dealloc(pcie_device_mesh): """Multidevice API: Ensure that deallocating multi-device tensors works as expected""" from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh - for device in pcie_device_mesh.get_device_ids(): - pcie_device_mesh.get_device(device).enable_async(True) + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("Requires multiple devices to run") # Create input tensors that cause OOM during op execution # Explictly deallocate buffers after each op to ensure we don't run OOM. 
@@ -311,9 +311,6 @@ def test_multi_device_explicit_dealloc(pcie_device_mesh): ttnn_output_tensor, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0) ) - for device in pcie_device_mesh.get_device_ids(): - pcie_device_mesh.get_device(device).enable_async(False) - @pytest.mark.parametrize("scalar", [3]) @pytest.mark.parametrize("size", [64]) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py index e75279713485..aa350b6d1e76 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace.py @@ -16,6 +16,9 @@ @pytest.mark.parametrize("use_all_gather", [True, False]) @pytest.mark.parametrize("enable_async", [True, False]) def test_multi_device_single_trace(pcie_device_mesh, shape, use_all_gather, enable_async): + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("This test requires multiple devices") + # Trace requires program cache to be enabled for device_id in pcie_device_mesh.get_device_ids(): pcie_device_mesh.get_device(device_id).enable_async(enable_async) @@ -103,6 +106,9 @@ def test_multi_device_multi_trace(pcie_device_mesh, shape, use_all_gather, enabl if shape == (1, 1, 32, 32) or shape == (1, 3, 512, 512) or shape == (1, 3, 32, 32): pytest.skip("This configuration is not working with all-gather") + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("This test requires multiple devices") + # Trace requires program cache to be enabled for device_id in pcie_device_mesh.get_device_ids(): pcie_device_mesh.get_device(device_id).enable_async(enable_async) diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index c59e12608b51..694138fe1f87 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -604,6 +604,8 @@ Tensor Tensor::to(Layout target_layout, DeviceMesh* device_mesh) const { auto& worker = workers[worker_index]; worker->push_work([*this, tensor_modified_layout, target_layout, worker, worker_index]() mutable { TT_ASSERT( + this->storage_type() == StorageType::OWNED || + this->storage_type() == StorageType::BORROWED|| this->storage_type() == StorageType::MULTI_DEVICE_HOST && "to(layout) must be called on host tensors with MULTI_DEVICE_HOST_STORAGE when multiple workers " "are specified"); diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp index c9d96d91cd6c..f6cd958d7911 100644 --- a/tt_eager/tensor/tensor_utils.cpp +++ b/tt_eager/tensor/tensor_utils.cpp @@ -363,16 +363,20 @@ const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volu bool is_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::DEVICE; } Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) { - const auto& tensor_storage = std::get(multi_device_tensor.get_storage()); - if (tensor_storage.has_buffer_for_device_id(device_id)) { - return Tensor{ - DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, - multi_device_tensor.get_legacy_shape(), - multi_device_tensor.get_dtype(), - multi_device_tensor.get_layout() - }; + if (std::holds_alternative(multi_device_tensor.get_storage())) { + const auto& tensor_storage = std::get(multi_device_tensor.get_storage()); + if (tensor_storage.has_buffer_for_device_id(device_id)) { + return Tensor{ + DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, + multi_device_tensor.get_legacy_shape(), + multi_device_tensor.get_dtype(), + multi_device_tensor.get_layout()}; + } + } else if 
(std::holds_alternative(multi_device_tensor.get_storage())) { + return multi_device_tensor; } - TT_THROW("Device not found in multi-device tensor"); + + TT_THROW("User is trying to access a device tensor that is not on device."); } Tensor get_device_tensor(const Tensor& multi_device_tensor, const Device* device) { diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/multi_device.hpp index 41943189363b..1a36bad30869 100644 --- a/ttnn/cpp/ttnn/multi_device.hpp +++ b/ttnn/cpp/ttnn/multi_device.hpp @@ -46,6 +46,8 @@ std::vector get_device_tensors(const ttnn::Tensor& tensor) { tensors.push_back(shard); } return tensors; + } else { + return {tensor}; } TT_THROW("Expected tensor to be on MultiDeviceHostStorage type!"); } From 2e8e3600d13b359e09a9a75c69c5a459f15bd113 Mon Sep 17 00:00:00 2001 From: Paul Keller Date: Tue, 4 Jun 2024 17:54:12 +0000 Subject: [PATCH 34/53] #9026: Fix FD dispatcher wait on wrapped value EnqueueProgram needs to emit a barrier w/o a wait It was waiting on a stale semaphore value causing an issue at semaphore wrap time Now waiting is optional --- .../perf_microbenchmark/dispatch/test_prefetcher.cpp | 1 + tt_metal/impl/dispatch/command_queue.cpp | 6 ++---- tt_metal/impl/dispatch/cq_commands.hpp | 3 +++ tt_metal/impl/dispatch/device_command.hpp | 7 ++++--- tt_metal/impl/dispatch/kernels/cq_dispatch.cpp | 9 ++++++--- 5 files changed, 16 insertions(+), 10 deletions(-) diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp index 02d3a367e4f6..bc5e958996ff 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp @@ -480,6 +480,7 @@ void gen_wait_and_stall_cmd(Device *device, wait.base.cmd_id = CQ_DISPATCH_CMD_WAIT; wait.wait.barrier = true; wait.wait.notify_prefetch = true; + wait.wait.wait = true; wait.wait.addr = dispatch_wait_addr_g; wait.wait.count = 0; add_bare_dispatcher_cmd(dispatch_cmds, wait); diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index e0325cdddf30..7a84851109f3 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -875,11 +875,9 @@ void EnqueueProgramCommand::assemble_device_commands() { } } - // Wait Noc Write Barrier, wait for binaries to be written to worker cores + // Wait Noc Write Barrier, wait for binaries/configs to be written to worker cores if (program.program_transfer_info.num_active_cores > 0) { - // Wait Noc Write Barrier, wait for binaries to be written to worker cores - // TODO: any way to not have dispatcher poll the addr here? 
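// --- Editor's note: illustrative sketch, not part of this patch. ---
// Patch 34 makes the semaphore wait optional on the dispatch wait command, so
// EnqueueProgram can request only a NOC write barrier and no longer stalls on a
// stale count at semaphore wrap time. Assuming `program_command_sequence` is a
// DeviceCommand with the add_dispatch_wait() signature introduced in
// device_command.hpp below (DeviceCommand and DISPATCH_MESSAGE_ADDR are taken
// as shown in this patch and used here only as stand-ins):
void emit_barrier_only(DeviceCommand& cmd_seq) {
    // Barrier, but skip the wait-on-count loop in the dispatch kernel.
    cmd_seq.add_dispatch_wait(
        /*barrier=*/true, DISPATCH_MESSAGE_ADDR, /*count=*/0,
        /*clear_count=*/0, /*notify_prefetch=*/false, /*do_wait=*/false);
}
void emit_barrier_and_wait(DeviceCommand& cmd_seq, uint32_t count) {
    // Default behaviour: barrier, then stall until the semaphore reaches count.
    cmd_seq.add_dispatch_wait(/*barrier=*/true, DISPATCH_MESSAGE_ADDR, count);
}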
- program_command_sequence.add_dispatch_wait(true, DISPATCH_MESSAGE_ADDR, 0); + program_command_sequence.add_dispatch_wait(true, DISPATCH_MESSAGE_ADDR, 0, 0, false, false); } // Go Signals diff --git a/tt_metal/impl/dispatch/cq_commands.hpp b/tt_metal/impl/dispatch/cq_commands.hpp index f4a4ddb0a446..db16fa618211 100644 --- a/tt_metal/impl/dispatch/cq_commands.hpp +++ b/tt_metal/impl/dispatch/cq_commands.hpp @@ -162,6 +162,9 @@ struct CQDispatchWaitCmd { uint8_t barrier; // if true, issue write barrier uint8_t notify_prefetch; // if true, inc prefetch sem uint8_t clear_count; // if true, reset count to 0 + uint8_t wait; // if true, wait on count value below + uint8_t pad1; + uint16_t pad2; uint32_t addr; // address to read uint32_t count; // wait while address is < count } __attribute__((packed)); diff --git a/tt_metal/impl/dispatch/device_command.hpp b/tt_metal/impl/dispatch/device_command.hpp index 67977c63797e..e8c1255a8b52 100644 --- a/tt_metal/impl/dispatch/device_command.hpp +++ b/tt_metal/impl/dispatch/device_command.hpp @@ -73,7 +73,7 @@ class DeviceCommand { vector_memcpy_aligned cmd_vector() const { return this->cmd_region_vector; } void add_dispatch_wait( - uint8_t barrier, uint32_t address, uint32_t count, uint8_t clear_count = 0, bool notify_prefetch = false) { + uint8_t barrier, uint32_t address, uint32_t count, uint8_t clear_count = 0, bool notify_prefetch = false, bool do_wait = true) { auto initialize_wait_cmds = [&](CQPrefetchCmd *relay_wait, CQDispatchCmd *wait_cmd) { relay_wait->base.cmd_id = CQ_PREFETCH_CMD_RELAY_INLINE; relay_wait->relay_inline.length = sizeof(CQDispatchCmd); @@ -82,6 +82,7 @@ class DeviceCommand { wait_cmd->base.cmd_id = CQ_DISPATCH_CMD_WAIT; wait_cmd->wait.barrier = barrier; wait_cmd->wait.notify_prefetch = notify_prefetch; + wait_cmd->wait.wait = do_wait; wait_cmd->wait.addr = address; wait_cmd->wait.count = count; wait_cmd->wait.clear_count = clear_count; @@ -101,8 +102,8 @@ class DeviceCommand { } void add_dispatch_wait_with_prefetch_stall( - uint8_t barrier, uint32_t address, uint32_t count, uint8_t clear_count = 0) { - this->add_dispatch_wait(barrier, address, count, clear_count, true); + uint8_t barrier, uint32_t address, uint32_t count, uint8_t clear_count = 0, bool do_wait = true) { + this->add_dispatch_wait(barrier, address, count, clear_count, true, do_wait); uint32_t increment_sizeB = align(sizeof(CQPrefetchCmd), PCIE_ALIGNMENT); auto initialize_stall_cmd = [&](CQPrefetchCmd *stall_cmd) { *stall_cmd = {}; diff --git a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp index ea04faf8d4cd..07bf38efdb20 100644 --- a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp @@ -663,9 +663,10 @@ static void process_wait() { uint32_t barrier = cmd->wait.barrier; uint32_t notify_prefetch = cmd->wait.notify_prefetch; + uint32_t clear_count = cmd->wait.clear_count; + uint32_t wait = cmd->wait.wait; uint32_t addr = cmd->wait.addr; uint32_t count = cmd->wait.count; - uint32_t clear_count = cmd->wait.clear_count; if (barrier) { noc_async_write_barrier(); @@ -677,10 +678,12 @@ static void process_wait() { #if defined(COMPILE_FOR_IDLE_ERISC) uint32_t heartbeat = 0; #endif - while (!wrap_ge(*sem_addr, count)) { + if (wait) { + while (!wrap_ge(*sem_addr, count)) { #if defined(COMPILE_FOR_IDLE_ERISC) - RISC_POST_HEARTBEAT(heartbeat); + RISC_POST_HEARTBEAT(heartbeat); #endif + } } DEBUG_STATUS("PWD"); From 7f0bbbecf7036f89f7df85d3bf4b3347753784cb Mon Sep 17 00:00:00 2001 
From: asaigal Date: Tue, 4 Jun 2024 22:19:55 +0000 Subject: [PATCH 35/53] #0: Add back Async Mode optimizations - Remove NUMA node based thread affinity policy, since it was causing a slowdown on CI --- CMakeLists.txt | 6 +- .../tensors/test_async_tensor_apis.cpp | 215 +++---- tt_eager/tensor/tensor.cpp | 129 ++--- tt_eager/tensor/tensor.hpp | 55 +- tt_eager/tensor/tensor_impl.hpp | 5 +- tt_eager/tensor/tensor_utils.cpp | 529 ++++++++++-------- tt_eager/tensor/types.hpp | 41 +- tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp | 8 +- .../eltwise_binary/eltwise_binary_op.cpp | 8 +- .../eltwise_unary/eltwise_unary_op.cpp | 6 +- tt_eager/tt_dnn/op_library/run_operation.cpp | 355 +++++++----- .../tt_dnn/op_library/softmax/softmax_op.cpp | 8 +- .../transformer_tms/transformer_tms.cpp | 24 +- .../op_library/transpose/transpose_op.cpp | 4 +- tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp | 10 +- tt_metal/CMakeLists.txt | 2 +- tt_metal/detail/tt_metal.hpp | 12 + tt_metal/impl/device/device.cpp | 4 +- tt_metal/impl/device/device.hpp | 4 +- tt_metal/impl/dispatch/command_queue.cpp | 23 +- tt_metal/impl/dispatch/work_executor.hpp | 16 +- tt_metal/tt_metal.cpp | 91 ++- ttnn/cpp/ttnn/op_library/binary/binary_op.cpp | 8 +- 23 files changed, 893 insertions(+), 670 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b85f073c3f19..4bd35a6d78d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,10 @@ CHECK_COMPILERS() find_package(Boost REQUIRED COMPONENTS thread filesystem system regex) find_package(GTest REQUIRED) find_package (Python3 COMPONENTS Interpreter Development) +find_library(NUMA_LIBRARY NAMES numa) +if (NOT NUMA_LIBRARY) + message(FATAL_ERROR "NUMA library not found") +endif() ############################################################################################################################ # Setting build type flags @@ -84,7 +88,7 @@ set(CMAKE_INSTALL_DATAROOTDIR "${CMAKE_BINARY_DIR}/tmp/share") ############################################################################################################################ add_library(metal_common_libs INTERFACE) target_link_libraries(metal_common_libs INTERFACE - dl z pthread atomic stdc++ # system libraries + dl z pthread atomic stdc++ numa # system libraries Boost::thread Boost::filesystem Boost::system Boost::regex hwloc # hwloc has no cmake support, find_package won't find it ) diff --git a/tests/tt_eager/tensors/test_async_tensor_apis.cpp b/tests/tt_eager/tensors/test_async_tensor_apis.cpp index 3f3c8b430106..3c7d689e57fb 100644 --- a/tests/tt_eager/tensors/test_async_tensor_apis.cpp +++ b/tests/tt_eager/tensors/test_async_tensor_apis.cpp @@ -33,19 +33,21 @@ TEST_F(CommonFixture, TestTensorOwnershipSanity) { auto func = [device, host_tensor, readback_tensor]() mutable { // Ensure that both the lambda and global scope have ownership to this tensor EXPECT_EQ(host_tensor.tensor_attributes.use_count(), 2); - std::visit([](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 1); - } - }, - storage.buffer); - } - }, host_tensor.get_storage()); + std::visit( + [](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 1); + } + }, + storage.buffer); + } + }, + host_tensor.get_storage()); // Send tensor to device, read it 
back and copy it to empty tensor initialized by main thread Tensor reshaped_tensor = host_tensor.reshape(1, 1, 32, 128); auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device); @@ -54,41 +56,45 @@ TEST_F(CommonFixture, TestTensorOwnershipSanity) { readback_tensor.set_shape(thread_local_tensor.get_shape()); readback_tensor.set_dtype(thread_local_tensor.get_dtype()); readback_tensor.set_layout(thread_local_tensor.get_layout()); - readback_tensor.set_populated(); + readback_tensor.tensor_attributes->metadata_populated = true; + readback_tensor.tensor_attributes->num_workers_completed++; // Ensure that the readback buffer is owned inside and outside the lambda - std::visit([](auto&& storage) { + std::visit( + [](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 2); + } + }, + storage.buffer); + } + }, + readback_tensor.get_storage()); + }; + + func(); + std::visit( + [](auto&& storage) { using T = std::decay_t; if constexpr (std::is_same_v) { std::visit( [](auto&& buf) { using buf_type = std::decay_t; if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 2); + EXPECT_EQ(buf.use_count(), 1); + for (int i = 0; i < 128 * 32; i++) { + EXPECT_EQ(buf[i], i); + } } }, - storage.buffer); + storage.buffer); } - }, readback_tensor.get_storage()); - }; - - func(); - std::visit([](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 1); - for (int i = 0; i < 128 * 32; i++) { - EXPECT_EQ(buf[i], i); - } - } - }, - storage.buffer); - } - }, - readback_tensor.get_storage()); + }, + readback_tensor.get_storage()); EXPECT_EQ(readback_tensor.get_dtype(), DataType::FLOAT32); EXPECT_EQ(readback_tensor.get_layout(), Layout::ROW_MAJOR); EXPECT_EQ(readback_tensor.get_shape(), ttnn::Shape(Shape({1, 1, 32, 128}))); @@ -126,8 +132,7 @@ TEST_F(CommonFixture, TestAsyncEltwiseBinary) { input_c_addr = std::get(input_tensor_c.get_storage()).buffer->address(); output_1_addr = std::get(output_tensor_device.get_storage()).buffer->address(); output_2_addr = std::get(output_tensor_device_2.get_storage()).buffer->address(); - } - else { + } else { EXPECT_EQ(std::get(input_tensor_a.get_storage()).buffer->address(), input_a_addr); EXPECT_EQ(std::get(input_tensor_b.get_storage()).buffer->address(), input_b_addr); EXPECT_EQ(std::get(input_tensor_c.get_storage()).buffer->address(), input_c_addr); @@ -140,7 +145,8 @@ TEST_F(CommonFixture, TestAsyncEltwiseBinary) { output_tensor_device.deallocate(); output_tensor_device_2.deallocate(); // Verify output data - auto& buf = std::get>(std::get(output_tensor_host.get_storage()).buffer); + auto& buf = + std::get>(std::get(output_tensor_host.get_storage()).buffer); EXPECT_EQ(buf.use_count(), 1); for (int j = 0; j < 1024 * 1024; j++) { EXPECT_EQ(bfloat16(buf[j]), bfloat16(static_cast(i - 2 * i * i))); @@ -159,21 +165,27 @@ TEST_F(CommonFixture, TestAsyncRefCountManager) { for (int i = 0; i < 5; i++) { // Run for multiple loops to ensure deterministic behaviour with device addresses // Initialize 2 tensors on device - Tensor tensor1 = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); - Tensor tensor2 = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + Tensor tensor1 = + 
tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + Tensor tensor2 = + tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); uint32_t tensor2_device_buf_addr = tensor2.device_buffer()->address(); - // Assign tensor1 to tensor2 and ensure that ref counts are appropriately updated with the buffer for tensor2 deallocated + // Assign tensor1 to tensor2 and ensure that ref counts are appropriately updated with the buffer for tensor2 + // deallocated tensor2 = tensor1; EXPECT_EQ(tensor2.tensor_attributes->main_thread_ref_count, 2); EXPECT_EQ(tensor1.tensor_attributes->main_thread_ref_count, 2); - // To check if tensor2 is deallocated, create a third tensor on device and ensure that its address matches the prev addr for tensor2 - Tensor tensor3 = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + // To check if tensor2 is deallocated, create a third tensor on device and ensure that its address matches the + // prev addr for tensor2 + Tensor tensor3 = + tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); EXPECT_EQ(tensor3.device_buffer()->address(), tensor2_device_buf_addr); EXPECT_EQ(tensor1.device_buffer()->address(), tensor2.device_buffer()->address()); } log_info(LogTest, "Testing Device tensor self-assignment through function"); for (int i = 0; i < 5; i++) { - Tensor device_tensor = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + Tensor device_tensor = + tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); uint32_t device_tensor_address = device_tensor.device_buffer()->address(); // This step will copy the tensor to a temp rval and std::move it back to the caller's instance of device_tensor // Ensure ref count and address remain unchanged @@ -184,14 +196,16 @@ TEST_F(CommonFixture, TestAsyncRefCountManager) { log_info(LogTest, "Testing Device tensor move assignment"); for (int i = 0; i < 5; i++) { - Tensor tensor1 = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); + Tensor tensor1 = + tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(i), DataType::BFLOAT16).to(device); Tensor tensor2 = std::move(tensor1); EXPECT_EQ(tensor2.tensor_attributes->main_thread_ref_count, 1); EXPECT_EQ(tensor1.tensor_attributes, nullptr); } log_info(LogTest, "Testing Device tensor self-assignment"); - Tensor tensor_to_self_assign = tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(0), DataType::BFLOAT16).to(device); + Tensor tensor_to_self_assign = + tt::numpy::full(Shape({1, 1, 1024, 1024}), static_cast(0), DataType::BFLOAT16).to(device); uint32_t tensor_to_self_assign_address = tensor_to_self_assign.device_buffer()->address(); tensor_to_self_assign = tensor_to_self_assign; EXPECT_EQ(tensor_to_self_assign.tensor_attributes->main_thread_ref_count, 1); @@ -219,7 +233,6 @@ TEST_F(CommonFixture, TestAsyncRefCountManager) { // Tensor output_tensor_device = mul(add(input_tensor_a, input_tensor_b), input_tensor_c); // Tensor output_tensor_device_2 = neg(sub(output_tensor_device, input_tensor_c)); - // EXPECT_EQ(output_tensor_device.get_shape(), ttnn::Shape(Shape({1, 1, 1023, 1023}))); // EXPECT_EQ(output_tensor_device.get_dtype(), DataType::BFLOAT16); @@ -234,45 +247,50 @@ TEST_F(CommonFixture, TestAsyncRefCountManager) { // device->set_worker_mode(WorkExecutorMode::SYNCHRONOUS); // } - TEST_F(CommonFixture, 
TestTensorAsyncDataMovement) { // Test 2 data paths here (resembles async mode): - // 1. Main -> Worker: Create a tensor in the main thread. Ensure that it is accessible in the worker thread even after its destroyed + // 1. Main -> Worker: Create a tensor in the main thread. Ensure that it is accessible in the worker thread even + // after its destroyed // by the main thread. This resembles host -> device data movement - // 2. Worker -> Main: Create an empty tensor in the mainb thread. Populate it in the worker thread. Ensure that the tensor is correctly + // 2. Worker -> Main: Create an empty tensor in the mainb thread. Populate it in the worker thread. Ensure that the + // tensor is correctly // populated in the main thread once the worker is done. Device* device = this->devices_[0]; uint32_t tensor_start = 0; uint32_t num_tiles = 128; uint32_t tensor_stop = TILE_HEIGHT * TILE_WIDTH * num_tiles; - Tensor readback_tensor({}, 1);; + Tensor readback_tensor({}, 1); + ; std::thread worker; { // host_tensor only lives in this scope Tensor host_tensor = tt::numpy::arange(tensor_start, tensor_stop, 1); log_info(LogTest, "Spawning worker thread"); - worker = std::thread([tensor_stop, host_tensor, readback_tensor, device] () mutable { + worker = std::thread([tensor_stop, host_tensor, readback_tensor, device]() mutable { // Sleep for 3 seconds to ensure that main thread deallocates host_tensor std::this_thread::sleep_for(std::chrono::milliseconds(3000)); log_info(LogTest, "Worker started"); // Main thread should have deallocated host_tensor by this point EXPECT_EQ(host_tensor.tensor_attributes.use_count(), 1); // Ensure that the buffer inside host_buffer is owned by a single tensor_attr object - // This buffer will not go out of scope until the last object owning it is destroyed (i.e. until the thread is done) - std::visit([](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 1); - } - }, - storage.buffer); - } - }, host_tensor.get_storage()); + // This buffer will not go out of scope until the last object owning it is destroyed (i.e. 
until the thread + // is done) + std::visit( + [](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 1); + } + }, + storage.buffer); + } + }, + host_tensor.get_storage()); Tensor reshaped_tensor = host_tensor.reshape(1, 1, 32, tensor_stop / 32); auto device_tensor = reshaped_tensor.to(Layout::TILE).to(device); @@ -282,22 +300,25 @@ TEST_F(CommonFixture, TestTensorAsyncDataMovement) { readback_tensor.set_shape(thread_local_tensor.get_shape()); readback_tensor.set_dtype(thread_local_tensor.get_dtype()); readback_tensor.set_layout(thread_local_tensor.get_layout()); - readback_tensor.set_populated(); + readback_tensor.tensor_attributes->metadata_populated = true; + readback_tensor.tensor_attributes->num_workers_completed++; // Ensure that this buffer is currently owned by both the thread_local and read_back tensors // This is because we explictly pass in the buffer to a new tensor_attr object - std::visit([](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 2); - } - }, - storage.buffer); - } - }, readback_tensor.get_storage()); + std::visit( + [](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 2); + } + }, + storage.buffer); + } + }, + readback_tensor.get_storage()); log_info(LogTest, "Worker Done"); }); // Call deallocate on the tensor in the main thread to ensure that this call is safe @@ -308,22 +329,22 @@ TEST_F(CommonFixture, TestTensorAsyncDataMovement) { worker.join(); log_info(LogTest, "Verifying populated tensor in main thread"); std::visit( - [tensor_start, tensor_stop](auto&& storage) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - std::visit( - [tensor_start, tensor_stop](auto&& buf) { - using buf_type = std::decay_t; - if constexpr (std::is_same_v>) { - EXPECT_EQ(buf.use_count(), 1); - for (int i = tensor_start; i < tensor_stop; i++) { - EXPECT_EQ(buf[i], i); - } + [tensor_start, tensor_stop](auto&& storage) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + std::visit( + [tensor_start, tensor_stop](auto&& buf) { + using buf_type = std::decay_t; + if constexpr (std::is_same_v>) { + EXPECT_EQ(buf.use_count(), 1); + for (int i = tensor_start; i < tensor_stop; i++) { + EXPECT_EQ(buf[i], i); } - }, + } + }, storage.buffer); - } - }, + } + }, readback_tensor.get_storage()); EXPECT_EQ(readback_tensor.get_dtype(), DataType::FLOAT32); EXPECT_EQ(readback_tensor.get_layout(), Layout::ROW_MAJOR); diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index 694138fe1f87..cdcb1b2e93e0 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -35,7 +35,7 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L [&](auto&& storage) { using StorageType = std::decay_t; if constexpr (std::is_same_v) { - this->tensor_attributes->tensor_populated = {true}; + this->tensor_attributes->num_shards_to_be_populated = 1; } else if constexpr (std::is_same_v) { TT_ASSERT(storage.buffer->device() != nullptr); workers = {storage.buffer->device()}; @@ -48,9 +48,9 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L if (not 
this->workers.at(0)->in_main_thread()) { this->tensor_attributes->main_thread_tensor = false; } - this->tensor_attributes->tensor_populated = {true}; + this->tensor_attributes->num_shards_to_be_populated = 1; } else if constexpr (std::is_same_v) { - this->tensor_attributes->tensor_populated = {true}; + this->tensor_attributes->num_shards_to_be_populated = 1; } else if constexpr (std::is_same_v) { workers.reserve(storage.num_buffers()); for (int i = 0; i < storage.ordered_device_ids.size(); i++) { @@ -68,14 +68,16 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L if (not this->workers.at(0)->in_main_thread()) { this->tensor_attributes->main_thread_tensor = false; } - this->tensor_attributes->tensor_populated = std::vector(storage.num_buffers(), true); + this->tensor_attributes->num_shards_to_be_populated = storage.num_buffers(); } else if constexpr (std::is_same_v) { - this->tensor_attributes->tensor_populated = std::vector(storage.num_buffers(), true); + this->tensor_attributes->num_shards_to_be_populated = storage.num_buffers(); } else { raise_unsupported_storage(); } }, storage); + this->tensor_attributes->num_workers_completed = this->tensor_attributes->num_shards_to_be_populated; + this->tensor_attributes->metadata_populated = true; } Tensor::Tensor(const Storage storage, const Shape shape, DataType dtype, Layout layout) : @@ -239,45 +241,6 @@ void Tensor::perform_cleanup_for_async_mode() { } } -// Main Thread - Wait for all workers in this tensor to populate the entire tensor -void Tensor::wait_for_tensor_data_populated() const { - ZoneScoped; - // Stall until all the workers for this tensor - // have populated the full tensor - for (int i = 0; i < this->tensor_attributes->tensor_populated.size(); i++) { - while (true) { - std::scoped_lock lock(this->tensor_attributes->populated_mutex); - if (this->tensor_attributes->tensor_populated.at(i)) - break; - } - } -} - -// Main Thread - Wait for the first worker in this tensor to populate the global metadata fields -void Tensor::wait_for_tensor_metadata_populated() const { - ZoneScoped; - // First worker is responsible for updating all metadata fields - // Stall until this worker is done - while (true) { - std::scoped_lock lock(this->tensor_attributes->populated_mutex); - if (this->tensor_attributes->tensor_populated.at(0)) - break; - }; -} - -// Worker Thread - Set populated flag to true, once worker has completed it's task for this tensor -void Tensor::set_populated(Device* worker) { - // If worker is not specified, set entry for all workers to true - std::scoped_lock lock(this->tensor_attributes->populated_mutex); - if (not worker) { - for (int i = 0; i < this->tensor_attributes->tensor_populated.size(); i++) { - this->tensor_attributes->tensor_populated.at(i) = true; - } - } else { - this->tensor_attributes->tensor_populated.at(worker->id()) = true; - } -} - void Tensor::deepcopy(const Tensor& other) { ZoneScoped; // Wait until the tensor being copied is populated @@ -288,7 +251,8 @@ void Tensor::deepcopy(const Tensor& other) { this->set_dtype(other.get_dtype()); this->set_layout(other.get_layout()); // Set metadata populated flag for getters - this->set_populated(); + this->tensor_attributes->metadata_populated = true; + this->tensor_attributes->num_workers_completed++; } void Tensor::populate_buffers_and_metadata(const Tensor& other) { @@ -304,17 +268,17 @@ void Tensor::populate_buffers_and_metadata(const Tensor& other) { using StorageType = std::decay_t; if constexpr (std::is_same_v or std::is_same_v) 
{ std::get(this->tensor_attributes->storage).insert_buffer(storage.get_buffer()); - this->tensor_attributes->tensor_populated = {true}; } else if constexpr ( std::is_same_v or std::is_same_v) { std::get(this->tensor_attributes->storage).buffers = storage.buffers; std::get(this->tensor_attributes->storage).shapes = storage.shapes; - this->tensor_attributes->tensor_populated = std::vector(storage.buffers.size(), true); } }, other.get_storage()); // Non blocking storage query, since this is done for tensors that get created inside the // worker thread + this->tensor_attributes->metadata_populated = true; + this->tensor_attributes->num_workers_completed++; } std::vector Tensor::get_workers(bool blocking) const { @@ -484,21 +448,20 @@ Tensor Tensor::to(const std::vector& workers, const MemoryConfig& mem_c uint32_t num_workers = workers_to_use.size(); for (int worker_index = 0; worker_index < workers_to_use.size(); ++worker_index) { auto& worker = workers_to_use[worker_index]; - worker->push_work([worker, *this, device_tensor, mem_config, num_workers, worker_index]() mutable { - auto shard = get_shard_for_device(*this, worker, worker_index); - if (shard.storage_type() == StorageType::OWNED) { - shard = tensor_impl::to_device_wrapper(shard, worker, mem_config, std::nullopt); - } - insert_buffer_and_shape_for_device(worker, shard, device_tensor, worker_index); - if (not worker->id()) { - device_tensor.set_shape(this->get_shape()); - device_tensor.set_dtype(this->get_dtype()); - device_tensor.set_layout(this->get_layout()); - } - if (num_workers > 1) - device_tensor.set_populated(worker); - else - device_tensor.set_populated(); + worker->push_work( + [worker, *this, device_tensor, mem_config, num_workers, worker_index] () mutable { + auto shard = get_shard_for_device(*this, worker, worker_index); + if (shard.storage_type() == StorageType::OWNED) { + shard = tensor_impl::to_device_wrapper(shard, worker, mem_config, std::nullopt); + } + insert_buffer_and_shape_for_device(worker, shard, device_tensor, worker_index); + uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; + if (not num_workers_completed) { + device_tensor.set_shape(this->get_shape()); + device_tensor.set_dtype(this->get_dtype()); + device_tensor.set_layout(this->get_layout()); + device_tensor.tensor_attributes->metadata_populated = true; + } }); } device_tensor.tensor_attributes->update_main_thread_ref_count(workers.at(0), device_tensor_ref_count); @@ -528,22 +491,18 @@ Tensor Tensor::cpu(bool blocking) const { auto shard = get_shard_for_device(*this, target_device); shard = tensor_impl::to_host_wrapper(shard, blocking); insert_buffer_and_shape_for_device(target_device, shard, host_tensor, worker_index); - if (not target_device->id() or workers.size() == 1) { + uint32_t num_workers_completed = (host_tensor.tensor_attributes->num_workers_completed)++; + if (not num_workers_completed) { host_tensor.set_shape(this->get_shape()); host_tensor.set_dtype(this->get_dtype()); host_tensor.set_layout(this->get_layout()); - } - if (workers.size() == 1) { - host_tensor.set_populated(); - } else { - host_tensor.set_populated(target_device); + host_tensor.tensor_attributes->metadata_populated = true; } }); } + if (blocking) { - for (auto target_device : workers) { - target_device->synchronize(); - } + detail::SynchronizeWorkerThreads(workers); } // Update main_thread_ref_count for tensor after pushing to queue. 
this->tensor_attributes->update_main_thread_ref_count(workers.at(0), original_tensor_ref_count); @@ -613,12 +572,13 @@ Tensor Tensor::to(Layout target_layout, DeviceMesh* device_mesh) const { auto shard = get_shard_for_device(*this, worker, worker_index); shard = tensor_impl::to_layout_wrapper(shard, target_layout); insert_buffer_and_shape_for_device(worker, shard, tensor_modified_layout, worker_index); - if (not(worker->id())) { + uint32_t num_workers_completed = (tensor_modified_layout.tensor_attributes->num_workers_completed)++; + if (not num_workers_completed) { tensor_modified_layout.set_shape(this->get_shape()); tensor_modified_layout.set_dtype(this->get_dtype()); tensor_modified_layout.set_layout(target_layout); - } - tensor_modified_layout.set_populated(worker); + tensor_modified_layout.tensor_attributes->metadata_populated = true; + }; }); } return tensor_modified_layout; @@ -987,15 +947,18 @@ Tensor allocate_tensor_on_device( for (int worker_index = 0; worker_index < num_workers; ++worker_index) { auto& worker = workers[worker_index]; - worker->push_work([shape, data_type, layout, worker, memory_config, device_tensor, worker_index]() mutable { - auto local_tensor = create_device_tensor(shape.value(), data_type, layout, worker, memory_config); - insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); - if (not worker->id()) { - device_tensor.set_shape(ttnn::Shape(shape)); - device_tensor.set_dtype(data_type); - device_tensor.set_layout(layout); - } - device_tensor.set_populated(worker); + worker->push_work( + [shape, data_type, layout, worker, memory_config, device_tensor, worker_index] () mutable { + auto local_tensor = create_device_tensor(shape.value(), data_type, layout, worker, memory_config); + insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); + + uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; + if (not num_workers_completed) { + device_tensor.set_shape(ttnn::Shape(shape)); + device_tensor.set_dtype(data_type); + device_tensor.set_layout(layout); + device_tensor.tensor_attributes->metadata_populated = true; + } }); } device_tensor.tensor_attributes->update_main_thread_ref_count(workers.at(0), device_tensor_ref_count); diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index 16c9665d2c35..fedbf54cb42e 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -32,10 +32,12 @@ struct Tensor { DataType dtype; Layout layout; std::mutex populated_mutex; - std::vector tensor_populated = {}; + uint32_t num_shards_to_be_populated = 0; uint32_t main_thread_ref_count = 0; std::atomic num_sibling_workers_sharing_tensor = 0; std::atomic main_thread_tensor = true; + std::atomic metadata_populated = false; + std::atomic num_workers_completed = 0; bool deallocated = false; // Set to true if device side storage was deallocated bool dynamic_storage = false; // Storage type can change, depending on op behaviour bool track_ref_count = false; @@ -155,7 +157,7 @@ struct Tensor { std::get(this->tensor_attributes->storage).ordered_device_ids), [](const Device *worker) { return worker->id(); }); } - this->tensor_attributes->tensor_populated = std::vector(workers.size(), false); + this->tensor_attributes->num_shards_to_be_populated = workers.size(); } else if (num_buffers) { if (num_buffers == 1) { this->tensor_attributes->storage = OwnedStorage(); @@ -167,7 +169,7 @@ struct Tensor { std::get(this->tensor_attributes->storage).shapes = 
std::vector(num_buffers, this->tensor_attributes->shape.value()); } - this->tensor_attributes->tensor_populated = std::vector(num_buffers, false); + this->tensor_attributes->num_shards_to_be_populated = num_buffers; } } @@ -286,19 +288,26 @@ struct Tensor { const ttnn::Shape &get_shape() const; const DataType &get_dtype() const; const Layout &get_layout() const; + + // ====================================================================================== + // Non-Blocking Getters. Query attributes directly, without waiting for worker completion + // ====================================================================================== + inline const Storage &storage() const { return this->tensor_attributes->storage; }; + inline const Shape &legacy_shape() const { return this->tensor_attributes->shape.value(); }; + inline const ttnn::Shape &shape() const { return this->tensor_attributes->shape; }; + inline const DataType &dtype() const { return this->tensor_attributes->dtype; }; + inline const Layout &layout() const { return this->tensor_attributes->layout; }; + // ====================================================================================== // Setters // ====================================================================================== - void set_storage(const Storage &storage) { this->tensor_attributes->storage = storage; } - void set_shape(const ttnn::Shape &shape) { this->tensor_attributes->shape = shape; } - void set_dtype(const DataType &dtype) { this->tensor_attributes->dtype = dtype; } - void set_layout(const Layout &layout) { this->tensor_attributes->layout = layout; } - void set_populated(Device *worker = nullptr); + inline void set_storage(const Storage &storage) { this->tensor_attributes->storage = storage; } + inline void set_shape(const ttnn::Shape &shape) { this->tensor_attributes->shape = shape; } + inline void set_dtype(const DataType &dtype) { this->tensor_attributes->dtype = dtype; } + inline void set_layout(const Layout &layout) { this->tensor_attributes->layout = layout; } // ====================================================================================== // Extra Helper Functions // ====================================================================================== - void wait_for_tensor_data_populated() const; - void wait_for_tensor_metadata_populated() const; StorageType storage_type() const; const Shape strides() const; uint32_t volume() const; @@ -355,13 +364,31 @@ struct Tensor { static constexpr auto attribute_names = std::make_tuple("storage", "shape", "dtype", "layout"); const auto attribute_values() const { return std::make_tuple( - std::cref(this->get_storage()), - std::cref(this->get_shape()), - std::cref(this->get_dtype()), - std::cref(this->get_layout())); + std::cref(this->tensor_attributes->storage), + std::cref(this->tensor_attributes->shape), + std::cref(this->tensor_attributes->dtype), + std::cref(this->tensor_attributes->layout)); } std::vector host_page_ordering(); + + // Main Thread - Wait for all workers in this tensor to populate the entire tensor + inline void wait_for_tensor_data_populated() const { + ZoneScoped; + // Stall until all the workers for this tensor + // have populated the full tensor + while (this->tensor_attributes->num_workers_completed < this->tensor_attributes->num_shards_to_be_populated) { + } + } + + // Main Thread - Wait for the first worker in this tensor to populate the global metadata fields + inline void wait_for_tensor_metadata_populated() const { + ZoneScoped; + // First worker is responsible for 
updating all metadata fields + // Stall until this worker is done + while (not this->tensor_attributes->metadata_populated) { + } + } }; Tensor create_device_tensor( diff --git a/tt_eager/tensor/tensor_impl.hpp b/tt_eager/tensor/tensor_impl.hpp index a16047e02b01..2bf7bbdbcb53 100644 --- a/tt_eager/tensor/tensor_impl.hpp +++ b/tt_eager/tensor/tensor_impl.hpp @@ -392,7 +392,6 @@ inline Tensor to_host(const Tensor& tensor, bool blocking = true) { host_tensor.set_dtype(tensor.get_dtype()); host_tensor.set_layout(tensor.get_layout()); insert_buffer_and_shape_for_device(device, shard, host_tensor, device_index); - host_tensor.set_populated(device); } return host_tensor; } else { @@ -942,7 +941,7 @@ inline std::string to_string(const Tensor& tensor, std::optional origi } if (is_tensor_on_device(tensor)) { - return to_string(to_host(tensor)); + return to_string(tensor.cpu()); } return std::visit( @@ -985,7 +984,7 @@ inline std::string to_string(const Tensor& tensor, std::optional origi TT_THROW("Cannot print a device tensor!"); } else if constexpr (std::is_same_v) { auto devices = get_devices(tensor); - auto host_tensor = to_host(tensor); + auto host_tensor = tensor.cpu(); auto device_index = 0; std::stringstream ss; apply(host_tensor, [&](const Tensor& device_tensor) { diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp index f6cd958d7911..a5bf1dba55f8 100644 --- a/tt_eager/tensor/tensor_utils.cpp +++ b/tt_eager/tensor/tensor_utils.cpp @@ -11,189 +11,214 @@ namespace tt { namespace tt_metal { - - template - Tensor to_weight_special_padding_tile_layout(const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { - auto w_shape = conv_weight_tensor.get_legacy_shape(); - auto compute = - [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { - uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; - uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; - auto weight_matrix_cols = w_shape[0]; - // width padding - if (weight_matrix_cols % in1_block_w_datums != 0) { - weight_matrix_cols = (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * - in1_block_w_datums; - } - // height padding - assert(in1_block_h_datums >= w_shape[1] * w_shape[3]); - uint32_t block_height_padding = in1_block_h_datums - (w_shape[1] * w_shape[3]); - auto weight_matrix_rows = ((w_shape[1] * w_shape[3]) + block_height_padding) * w_shape[2]; - Shape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; - auto output_buffer = owned_buffer::create(compute_volume(output_shape)); - for (auto r = 0; r < w_shape[2]; r++) { - for (auto s = 0; s < w_shape[3]; s++) { - for (auto c = 0; c < w_shape[1]; c++) { - for (auto k = 0; k < w_shape[0]; k++) { - auto matrix_idx = - k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + - r * ((w_shape[3] * w_shape[1]) + block_height_padding) * weight_matrix_cols; - auto idx = k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + - r * w_shape[3] + s; - output_buffer[matrix_idx] = input_buffer[idx]; - } - } - } - } - if constexpr (std::is_same::value) { - if (output_dtype == DataType::BFLOAT8_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = pack_fp32_vec_as_bfp8_tiles( - output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor( - 
std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - } - if (output_dtype == DataType::BFLOAT4_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = pack_fp32_vec_as_bfp4_tiles( - output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_uint32_buffer)}), - output_shape, - output_dtype, - Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); +template +Tensor to_weight_special_padding_tile_layout( + const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { + auto w_shape = conv_weight_tensor.get_legacy_shape(); + auto compute = [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { + uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; + uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; + auto weight_matrix_cols = w_shape[0]; + // width padding + if (weight_matrix_cols % in1_block_w_datums != 0) { + weight_matrix_cols = + (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * in1_block_w_datums; + } + // height padding + assert(in1_block_h_datums >= w_shape[1] * w_shape[3]); + uint32_t block_height_padding = in1_block_h_datums - (w_shape[1] * w_shape[3]); + auto weight_matrix_rows = ((w_shape[1] * w_shape[3]) + block_height_padding) * w_shape[2]; + Shape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = owned_buffer::create(compute_volume(output_shape)); + for (auto r = 0; r < w_shape[2]; r++) { + for (auto s = 0; s < w_shape[3]; s++) { + for (auto c = 0; c < w_shape[1]; c++) { + for (auto k = 0; k < w_shape[0]; k++) { + auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + + r * ((w_shape[3] * w_shape[1]) + block_height_padding) * weight_matrix_cols; + auto idx = + k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; + output_buffer[matrix_idx] = input_buffer[idx]; } - } else { - TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); } + } + } + if constexpr (std::is_same::value) { + if (output_dtype == DataType::BFLOAT8_B) { + auto output_float_data = output_buffer.get(); + auto output_packed_data = + pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); auto rm_tensor = Tensor( - std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + output_shape, + output_dtype, + Layout::ROW_MAJOR); return rm_tensor.to(Layout::TILE); - }; - return std::visit( - [&compute](auto&& storage) -> Tensor { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return compute(owned_buffer::get_as(storage.buffer)); - } else if constexpr (std::is_same_v) { - return compute(borrowed_buffer::get_as(storage.buffer)); - } else { - TT_THROW("Unsupported storage type"); - } - }, - conv_weight_tensor.get_storage()); - } - - - template - Tensor to_weight_tile_layout(const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { - auto w_shape = conv_weight_tensor.get_legacy_shape(); - 
auto compute = - [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { - auto weight_matrix_cols = w_shape[0]; - // width padding - uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; - if(weight_matrix_cols%in1_block_w_datums != 0) { - weight_matrix_cols = (uint32_t) std::ceil( (double) weight_matrix_cols / (double) in1_block_w_datums ) * in1_block_w_datums; } - // height padding - auto weight_matrix_rows = w_shape[1]*w_shape[2]*w_shape[3]; - uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; - if (weight_matrix_rows % in1_block_h_datums != 0) { - weight_matrix_rows = (uint32_t) std::ceil( (double) weight_matrix_rows / (double) in1_block_h_datums ) * in1_block_h_datums; + if (output_dtype == DataType::BFLOAT4_B) { + auto output_float_data = output_buffer.get(); + auto output_packed_data = + pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + auto rm_tensor = Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + output_shape, + output_dtype, + Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + } + } else { + TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); + } + auto rm_tensor = + Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + }; + return std::visit( + [&compute](auto&& storage) -> Tensor { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return compute(owned_buffer::get_as(storage.buffer)); + } else if constexpr (std::is_same_v) { + return compute(borrowed_buffer::get_as(storage.buffer)); + } else { + TT_THROW("Unsupported storage type"); } - Shape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; - auto output_buffer = owned_buffer::create(compute_volume(output_shape)); - for(auto r = 0; r < w_shape[2]; r++) { - for(auto s = 0; s < w_shape[3]; s++) { - for(auto c = 0; c < w_shape[1]; c++) { - for(auto k = 0; k < w_shape[0]; k++) { - auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + r * w_shape[3] * w_shape[1] * weight_matrix_cols; - auto idx = k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; - output_buffer[matrix_idx] = input_buffer[idx]; - } + }, + conv_weight_tensor.get_storage()); +} + +template +Tensor to_weight_tile_layout( + const Tensor& conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, DataType output_dtype) { + auto w_shape = conv_weight_tensor.get_legacy_shape(); + auto compute = [&w_shape, &in1_block_h, &in1_block_w, &output_dtype](const auto& input_buffer) { + auto weight_matrix_cols = w_shape[0]; + // width padding + uint32_t in1_block_w_datums = in1_block_w * constants::TILE_WIDTH; + if (weight_matrix_cols % in1_block_w_datums != 0) { + weight_matrix_cols = + (uint32_t)std::ceil((double)weight_matrix_cols / (double)in1_block_w_datums) * in1_block_w_datums; + } + // height padding + auto weight_matrix_rows = w_shape[1] * w_shape[2] * w_shape[3]; + uint32_t in1_block_h_datums = in1_block_h * constants::TILE_HEIGHT; + if (weight_matrix_rows % in1_block_h_datums != 0) { + weight_matrix_rows = + (uint32_t)std::ceil((double)weight_matrix_rows / (double)in1_block_h_datums) * in1_block_h_datums; + } + Shape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = 
owned_buffer::create(compute_volume(output_shape)); + for (auto r = 0; r < w_shape[2]; r++) { + for (auto s = 0; s < w_shape[3]; s++) { + for (auto c = 0; c < w_shape[1]; c++) { + for (auto k = 0; k < w_shape[0]; k++) { + auto matrix_idx = k + c * weight_matrix_cols + s * w_shape[1] * weight_matrix_cols + + r * w_shape[3] * w_shape[1] * weight_matrix_cols; + auto idx = + k * w_shape[1] * w_shape[2] * w_shape[3] + c * w_shape[2] * w_shape[3] + r * w_shape[3] + s; + output_buffer[matrix_idx] = input_buffer[idx]; } } } - if constexpr (std::is_same::value) { - if (output_dtype == DataType::BFLOAT8_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - } - if (output_dtype == DataType::BFLOAT4_B) { - auto output_float_data = output_buffer.get(); - auto output_packed_data = pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); - auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); - auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(output_uint32_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - } + } + if constexpr (std::is_same::value) { + if (output_dtype == DataType::BFLOAT8_B) { + auto output_float_data = output_buffer.get(); + auto output_packed_data = + pack_fp32_vec_as_bfp8_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + auto rm_tensor = Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + output_shape, + output_dtype, + Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + } + if (output_dtype == DataType::BFLOAT4_B) { + auto output_float_data = output_buffer.get(); + auto output_packed_data = + pack_fp32_vec_as_bfp4_tiles(output_float_data, /*row_major_input=*/false, /*is_exp_a=*/false); + auto output_uint32_buffer = owned_buffer::create(std::move(output_packed_data)); + auto rm_tensor = Tensor( + std::move(OwnedStorage{std::move(output_uint32_buffer)}), + output_shape, + output_dtype, + Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + } + } else { + TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); + } + auto rm_tensor = + Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); + return rm_tensor.to(Layout::TILE); + }; + return std::visit( + [&compute](auto&& storage) -> Tensor { + using StorageType = std::decay_t; + if constexpr (std::is_same_v) { + return compute(owned_buffer::get_as(storage.buffer)); + } else if constexpr (std::is_same_v) { + return compute(borrowed_buffer::get_as(storage.buffer)); } else { - TT_ASSERT((output_dtype != DataType::BFLOAT8_B) || (output_dtype != DataType::BFLOAT4_B)); + TT_THROW("Unsupported storage type"); } - auto rm_tensor = Tensor(std::move(OwnedStorage{std::move(output_buffer)}), output_shape, output_dtype, Layout::ROW_MAJOR); - return rm_tensor.to(Layout::TILE); - }; - return std::visit( - [&compute](auto&& storage) -> Tensor { - using StorageType = std::decay_t; - if constexpr (std::is_same_v) { - return 
compute(owned_buffer::get_as(storage.buffer)); - } else if constexpr (std::is_same_v) { - return compute(borrowed_buffer::get_as(storage.buffer)); - } else { - TT_THROW("Unsupported storage type"); - } - }, - conv_weight_tensor.get_storage()); - } + }, + conv_weight_tensor.get_storage()); +} - // Converts convolution weights to tilized 2d matrix layout. - // Returns a new tensor with layout=Tile - Tensor convert_conv_weight_tensor_to_tiled_layout(Tensor conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional output_dtype) { - TT_ASSERT(conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && "Convolution weights should be in row major layout for conversion to tilized layout."); - const static std::map> to_w_tile_layout_map = { +// Converts convolution weights to tilized 2d matrix layout. +// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_tiled_layout( + Tensor conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional output_dtype) { + TT_ASSERT( + conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && + "Convolution weights should be in row major layout for conversion to tilized layout."); + const static std::map< + DataType, + std::function> + to_w_tile_layout_map = { {DataType::BFLOAT16, &to_weight_tile_layout}, {DataType::FLOAT32, &to_weight_tile_layout}, {DataType::UINT32, &to_weight_tile_layout}, }; - if (output_dtype.has_value()) { - if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { - TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32); - } else { - TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype()); - } + if (output_dtype.has_value()) { + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32); + } else { + TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype()); } - return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype())); } + return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())( + conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype())); +} - // Converts convolution weights to tilized 2d matrix layout. - // Returns a new tensor with layout=Tile - Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout(Tensor conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional output_dtype) { - TT_ASSERT(conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && "Convolution weights should be in row major layout for conversion to tilized layout."); - const static std::map> to_w_tile_layout_map = { +// Converts convolution weights to tilized 2d matrix layout. 
+// Returns a new tensor with layout=Tile +Tensor convert_conv_weight_tensor_to_special_padding_tiled_layout( + Tensor conv_weight_tensor, uint32_t in1_block_h, uint32_t in1_block_w, std::optional output_dtype) { + TT_ASSERT( + conv_weight_tensor.get_layout() == Layout::ROW_MAJOR && + "Convolution weights should be in row major layout for conversion to tilized layout."); + const static std::map< + DataType, + std::function> + to_w_tile_layout_map = { {DataType::BFLOAT16, &to_weight_special_padding_tile_layout}, {DataType::FLOAT32, &to_weight_special_padding_tile_layout}, - {DataType::UINT32, &to_weight_special_padding_tile_layout} - }; - if (output_dtype.has_value()) { - if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { - TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32); - } else { - TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype()); - } + {DataType::UINT32, &to_weight_special_padding_tile_layout}}; + if (output_dtype.has_value()) { + if (output_dtype == DataType::BFLOAT8_B || output_dtype == DataType::BFLOAT4_B) { + TT_ASSERT(conv_weight_tensor.get_dtype() == DataType::FLOAT32); + } else { + TT_ASSERT(conv_weight_tensor.get_dtype() == conv_weight_tensor.get_dtype()); } - return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())(conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype())); } + return to_w_tile_layout_map.at(conv_weight_tensor.get_dtype())( + conv_weight_tensor, in1_block_h, in1_block_w, output_dtype.value_or(conv_weight_tensor.get_dtype())); +} /* Helper function to aid in converting grouped weight tensor to ungrouped weight tensor with padded zero channels @@ -323,44 +348,39 @@ const Shape infer_dims_for_reshape(int N, int C, int H, int W, uint32_t old_volu switch (neg_idx) { case 0: - TT_ASSERT(old_volume % C*H*W == 0); - N = old_volume/(C*H*W); + TT_ASSERT(old_volume % C * H * W == 0); + N = old_volume / (C * H * W); break; case 1: - TT_ASSERT(old_volume % N*H*W == 0); - C = old_volume/(N*H*W); + TT_ASSERT(old_volume % N * H * W == 0); + C = old_volume / (N * H * W); break; case 2: - TT_ASSERT(old_volume % N*C*W == 0); - H = old_volume/(N*C*W); + TT_ASSERT(old_volume % N * C * W == 0); + H = old_volume / (N * C * W); break; case 3: - TT_ASSERT(old_volume % N*C*H == 0); - W = old_volume/(N*C*H); + TT_ASSERT(old_volume % N * C * H == 0); + W = old_volume / (N * C * H); break; - case -1: // In case where there is no negative value in ns - TT_ASSERT(N*C*H*W == old_volume); + case -1: // In case where there is no negative value in ns + TT_ASSERT(N * C * H * W == old_volume); break; - default: - TT_ASSERT(false && "Unexpected neg_idx in reshape!"); + default: TT_ASSERT(false && "Unexpected neg_idx in reshape!"); } return {(uint32_t)N, (uint32_t)C, (uint32_t)H, (uint32_t)W}; } - bool is_arch_gs(const tt::ARCH& arch) { - return arch == tt::ARCH::GRAYSKULL; - } +bool is_arch_gs(const tt::ARCH& arch) { return arch == tt::ARCH::GRAYSKULL; } - bool is_arch_whb0(const tt::ARCH& arch) { - return arch == tt::ARCH::WORMHOLE_B0; - } +bool is_arch_whb0(const tt::ARCH& arch) { return arch == tt::ARCH::WORMHOLE_B0; } - bool is_cpu_tensor(const Tensor& tensor) { - return tensor.storage_type() == StorageType::OWNED || tensor.storage_type() == StorageType::BORROWED; - } +bool is_cpu_tensor(const Tensor& tensor) { + return tensor.storage_type() == StorageType::OWNED || tensor.storage_type() == StorageType::BORROWED; +} - bool is_device_tensor(const Tensor& tensor) { return 
tensor.storage_type() == StorageType::DEVICE; } +bool is_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::DEVICE; } Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) { if (std::holds_alternative(multi_device_tensor.get_storage())) { @@ -384,10 +404,10 @@ Tensor get_device_tensor(const Tensor& multi_device_tensor, const Device* device } bool is_multi_device_tensor(const Tensor& tensor) { - return tensor.storage_type() == StorageType::MULTI_DEVICE or tensor.storage_type() == StorageType::MULTI_DEVICE_HOST; + return tensor.storage_type() == StorageType::MULTI_DEVICE or + tensor.storage_type() == StorageType::MULTI_DEVICE_HOST; } - std::vector get_tensors_from_multi_device_storage(const Tensor& multi_device_tensor) { std::vector tensors; if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE) { @@ -399,8 +419,7 @@ std::vector get_tensors_from_multi_device_storage(const Tensor& multi_de DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, tensor_storage.shapes.at(device_id), multi_device_tensor.get_dtype(), - multi_device_tensor.get_layout() - }; + multi_device_tensor.get_layout()}; } return tensors; } else if (multi_device_tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { @@ -410,11 +429,9 @@ std::vector get_tensors_from_multi_device_storage(const Tensor& multi_de OwnedStorage{tensor_storage.get_buffer(i)}, tensor_storage.shapes[i], multi_device_tensor.get_dtype(), - multi_device_tensor.get_layout() - }); + multi_device_tensor.get_layout()}); } - } - else { + } else { TT_FATAL(false, "get_tensors_from_multi_device_storage only support multi device tensors"); } return tensors; @@ -424,15 +441,15 @@ DistributedTensorConfig get_distributed_tensor_config_from_tensor(const Tensor& if (tensor.storage_type() == StorageType::MULTI_DEVICE) { const auto& tensor_storage = std::get(tensor.get_storage()); return tensor_storage.strategy; - } - else if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { + } else if (tensor.storage_type() == StorageType::MULTI_DEVICE_HOST) { const auto& tensor_storage = std::get(tensor.get_storage()); return tensor_storage.strategy; } TT_THROW("Tensor is not a multi-device tensor"); } -Tensor create_multi_device_tensor(const std::vector& tensors, StorageType storage_type, const DistributedTensorConfig& strategy) { +Tensor create_multi_device_tensor( + const std::vector& tensors, StorageType storage_type, const DistributedTensorConfig& strategy) { if (tensors.empty()) { TT_THROW("Cannot create multi-device tensor with empty tensor list"); } @@ -452,8 +469,7 @@ Tensor create_multi_device_tensor(const std::vector& tensors, StorageTyp MultiDeviceStorage{strategy, ordered_device_ids, device_buffers, shapes}, tensors.at(0).get_legacy_shape(), tensors.at(0).get_dtype(), - tensors.at(0).get_layout() - }; + tensors.at(0).get_layout()}; } else if (storage_type == StorageType::MULTI_DEVICE_HOST) { std::vector owned_buffers; std::vector shapes; @@ -465,8 +481,7 @@ Tensor create_multi_device_tensor(const std::vector& tensors, StorageTyp MultiDeviceHostStorage{strategy, owned_buffers, shapes}, tensors.at(0).get_legacy_shape(), tensors.at(0).get_dtype(), - tensors.at(0).get_layout() - }; + tensors.at(0).get_layout()}; } else { TT_THROW("Invalid storage type for multi-device tensor"); } @@ -475,9 +490,11 @@ Tensor create_multi_device_tensor(const std::vector& tensors, StorageTyp Tensor transform(const Tensor& tensor, std::function transform_func) { auto input_tensors = 
get_tensors_from_multi_device_storage(tensor); std::vector output_tensors(input_tensors.size()); - std::transform(input_tensors.begin(), input_tensors.end(), output_tensors.begin(), - [&](const auto& device_tensor) { return transform_func(device_tensor); }); - return create_multi_device_tensor(output_tensors, tensor.storage_type(), get_distributed_tensor_config_from_tensor(tensor)); + std::transform(input_tensors.begin(), input_tensors.end(), output_tensors.begin(), [&](const auto& device_tensor) { + return transform_func(device_tensor); + }); + return create_multi_device_tensor( + output_tensors, tensor.storage_type(), get_distributed_tensor_config_from_tensor(tensor)); } void apply(const Tensor& tensor, std::function callable) { @@ -487,7 +504,6 @@ void apply(const Tensor& tensor, std::function callable) { } } - std::vector get_devices(const Tensor& tensor) { std::vector devices; if (tensor.storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) { @@ -509,7 +525,10 @@ uint32_t num_buffers_in_tensor(const Tensor& tensor) { } else if (std::holds_alternative(tensor.get_storage())) { auto host_storage = std::get(tensor.get_storage()); return host_storage.num_buffers(); - } else if (std::holds_alternative(tensor.get_storage()) || std::holds_alternative(tensor.get_storage()) || std::holds_alternative(tensor.get_storage())) { + } else if ( + std::holds_alternative(tensor.get_storage()) || + std::holds_alternative(tensor.get_storage()) || + std::holds_alternative(tensor.get_storage())) { return 1; } else { TT_FATAL(false, "num_buffers_in_tensor only supports multi-device or device tensors"); @@ -519,45 +538,64 @@ uint32_t num_buffers_in_tensor(const Tensor& tensor) { Tensor get_shard_for_device(const Tensor& tensor, Device* target_device, std::optional buffer_index) { ZoneScopedN("GetShardForDevice"); Tensor shard = Tensor(); - auto& storage = tensor.get_storage(); - std::visit([target_device, buffer_index, &tensor, &shard] (auto&& s) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - auto shard_shape = s.get_tensor_shape_for_device(target_device); - auto shard_buffer = s.get_buffer_for_device(target_device); - shard = Tensor{DeviceStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()}; - } else if constexpr (std::is_same_v) { - auto shard_shape = s.get_tensor_shape(buffer_index.value()); - auto shard_buffer = s.get_buffer(buffer_index.value()); - shard = Tensor{OwnedStorage{shard_buffer}, shard_shape, tensor.get_dtype(), tensor.get_layout()}; - } else if constexpr (std::is_same_v || std::is_same_v || std::is_same_v) { - shard = tensor; - } else { - TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors"); - } - }, storage); + auto& storage = tensor.tensor_attributes->storage; + std::visit( + [target_device, buffer_index, &tensor, &shard](auto&& s) { + using T = std::decay_t; + // Stalling reads for tensor data-type and layout are needed here + // since some worker might have raced ahead to these lookups, while + // another worker is populating this metadata. 
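// The shard lookup below uses the standard std::visit-over-variant idiom: one
// generic lambda, with if constexpr / std::is_same_v selecting the behaviour for
// each storage alternative. A self-contained sketch of that idiom with simplified
// stand-in types (HostShard/DeviceShard are illustrative names, not the real
// tt-metal storage classes):
#include <type_traits>
#include <variant>

struct HostShard   { int buffer; };
struct DeviceShard { int buffer; };
using ShardStorage = std::variant<HostShard, DeviceShard>;

int shard_buffer(const ShardStorage& storage) {
    return std::visit(
        [](auto&& s) {
            using T = std::decay_t<decltype(s)>;
            if constexpr (std::is_same_v<T, HostShard>) {
                return s.buffer;  // host-resident shard
            } else {
                return s.buffer;  // device-resident shard
            }
        },
        storage);
}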
+ if constexpr (std::is_same_v) { + shard = Tensor{ + DeviceStorage{s.get_buffer_for_device(target_device)}, + s.get_tensor_shape_for_device(target_device), + tensor.get_dtype(), + tensor.get_layout()}; + } else if constexpr (std::is_same_v) { + shard = Tensor{ + OwnedStorage{s.get_buffer(buffer_index.value())}, + s.get_tensor_shape(buffer_index.value()), + tensor.get_dtype(), + tensor.get_layout()}; + } else if constexpr ( + std::is_same_v || std::is_same_v || + std::is_same_v) { + shard = tensor; + } else { + TT_FATAL(false, "get_shard_for_device only supports multi-device or device tensors"); + } + }, + storage); return shard; } -void insert_buffer_and_shape_for_device(Device* target_device, const Tensor& shard, Tensor& tensor_to_modify, std::optional buffer_index) { +void insert_buffer_and_shape_for_device( + Device* target_device, const Tensor& shard, Tensor& tensor_to_modify, std::optional buffer_index) { ZoneScopedN("InsertBufferAndShapeForDevice"); - std::visit([target_device, &shard, &tensor_to_modify, buffer_index] (auto&& s) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - s.insert_buffer_and_shape_for_device(buffer_index.value(), std::get(shard.get_storage()).get_buffer(), shard.get_legacy_shape()); - } else if constexpr (std::is_same_v) { - s.insert_buffer_and_shape_for_device(target_device, std::get(shard.get_storage()).get_buffer(), shard.get_legacy_shape()); - } else if constexpr (std::is_same_v) { - s.insert_buffer(std::get(shard.get_storage()).get_buffer()); - } else if constexpr (std::is_same_v) { - s.insert_buffer(std::get(shard.get_storage()).get_buffer()); - } else { - TT_FATAL(false, "Unsupported storage in insert_buffer_and_shape_for_device"); - } - }, tensor_to_modify.tensor_attributes->storage); + std::visit( + [target_device, &shard, &tensor_to_modify, buffer_index](auto&& s) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + s.insert_buffer_and_shape_for_device( + buffer_index.value(), + std::get(shard.tensor_attributes->storage).get_buffer(), + shard.tensor_attributes->shape.value()); + } else if constexpr (std::is_same_v) { + s.insert_buffer_and_shape_for_device( + target_device, + std::get(shard.tensor_attributes->storage).get_buffer(), + shard.tensor_attributes->shape.value()); + } else if constexpr (std::is_same_v) { + s.insert_buffer(std::get(shard.tensor_attributes->storage).get_buffer()); + } else if constexpr (std::is_same_v) { + s.insert_buffer(std::get(shard.tensor_attributes->storage).get_buffer()); + } else { + TT_FATAL(false, "Unsupported storage in insert_buffer_and_shape_for_device"); + } + }, + tensor_to_modify.tensor_attributes->storage); } - Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor) { // When using async mode, tensors with borrowed storage cannot be passed to workers. // They need to be copied to owned storage before being passed to the worker. @@ -565,23 +603,26 @@ Tensor copy_borrowed_tensor_in_async_mode(Device* worker, const Tensor& tensor) // Tensor has workers (on device) or runtime mode is synchronous or tensor has multiple buffers. // No need to check for borrowed storage. 
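// The borrowed-to-owned conversion further down in this function is a plain
// element-wise copy from the borrowed span into a buffer the tensor owns, which
// detaches the tensor's lifetime from the caller's allocation. A minimal
// stand-alone sketch of that idea, using std::vector in place of the real
// owned_buffer type (illustrative only):
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<std::uint16_t> copy_to_owned(const std::uint16_t* borrowed, std::size_t n) {
    // After the copy, a worker thread can keep using the data even if the
    // caller's buffer goes away.
    return std::vector<std::uint16_t>(borrowed, borrowed + n);
}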
if (worker->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS or - tensor.get_workers().size() or - tensor.tensor_attributes->tensor_populated.size() > 1) return tensor; + tensor.tensor_attributes->num_shards_to_be_populated > 1) + return tensor; if (tensor.storage_type() == StorageType::BORROWED) { ZoneScopedN("CopyBorrowedStorage"); auto borrowed_buffer = std::get(tensor.get_storage()).buffer; Tensor owned_tensor; - std::visit([&owned_tensor, &tensor] (auto&& buffer) { - using BorrowedStorageType = std::vector>; - auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end())); - owned_tensor = Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout()); - }, borrowed_buffer); + std::visit( + [&owned_tensor, &tensor](auto&& buffer) { + using BorrowedStorageType = std::vector>; + auto owned_buf = owned_buffer::create(BorrowedStorageType(buffer.begin(), buffer.end())); + owned_tensor = + Tensor(OwnedStorage{owned_buf}, tensor.get_shape(), tensor.get_dtype(), tensor.get_layout()); + }, + borrowed_buffer); return owned_tensor; } return tensor; } -} +} // namespace tt_metal -} +} // namespace tt diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index 81247e39c871..dc0a421c6f1a 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -455,7 +455,8 @@ struct MultiDeviceHostStorage { std::vector ordered_device_ids; std::unordered_map buffers; std::unordered_map shapes; - mutable std::mutex mtx; + mutable std::mutex buffer_mtx; + mutable std::mutex shape_mtx; MultiDeviceStorage() = default; MultiDeviceStorage( @@ -465,14 +466,14 @@ struct MultiDeviceHostStorage { std::unordered_map shapes_) : strategy(strategy_), ordered_device_ids(ordered_device_ids_), buffers(buffers_), shapes(shapes_) {} MultiDeviceStorage(MultiDeviceStorage &&other) { - std::lock_guard lock(mtx); + std::scoped_lock buf_lock(buffer_mtx, shape_mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; shapes = other.shapes; } MultiDeviceStorage(const MultiDeviceStorage &other) { - std::lock_guard lock(other.mtx); + std::scoped_lock buf_lock(buffer_mtx, shape_mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; @@ -480,7 +481,7 @@ struct MultiDeviceHostStorage { } MultiDeviceStorage &operator=(const MultiDeviceStorage &other) { - std::lock_guard lock(other.mtx); + std::scoped_lock buf_lock(buffer_mtx, shape_mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; @@ -489,7 +490,7 @@ struct MultiDeviceHostStorage { } MultiDeviceStorage &operator=( MultiDeviceStorage &&other) { - std::lock_guard lock(mtx); + std::scoped_lock buf_lock(buffer_mtx, shape_mtx); ordered_device_ids = other.ordered_device_ids; strategy = other.strategy; buffers = other.buffers; @@ -501,8 +502,8 @@ struct MultiDeviceHostStorage { return this->ordered_device_ids == other.ordered_device_ids and this->strategy == other.strategy and this->buffers == other.buffers and this->shapes == other.shapes; } - const MemoryConfig memory_config() const { - std::lock_guard lock(mtx); + inline const MemoryConfig memory_config() const { + std::lock_guard lock(buffer_mtx); auto first_device_id = this->ordered_device_ids.at(0); if (this->buffers.at(first_device_id).get() == nullptr) { TT_THROW("MemoryConfig can only be obtained if the buffer is not null"); @@ -523,50 +524,54 @@ struct MultiDeviceHostStorage { // Helper Functions - Getters 
and setters to get/modify storage attributes. These are needed to // preinitialize empty tensor handles and use/populate them in the worker threads. - void insert_buffer_and_shape_for_device(Device* device, const DeviceBuffer buffer, const Shape shape) { + + inline void insert_buffer_and_shape_for_device(Device* device, const DeviceBuffer buffer, const Shape shape) { TT_ASSERT(device == buffer->device(), "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); - std::lock_guard lock(mtx); - buffers.insert({device->id(), buffer}); + { + std::lock_guard lock(buffer_mtx); + buffers.insert({device->id(), buffer}); + } + std::lock_guard lock(shape_mtx); shapes.insert({device->id(), shape}); } inline DeviceBuffer get_buffer_for_device(Device* device) const { - std::lock_guard lock(mtx); + std::lock_guard lock(buffer_mtx); TT_ASSERT(buffers.find(device->id()) != buffers.end(), "Buffer not found for device " + std::to_string(device->id())); TT_ASSERT(buffers.at(device->id())->device() == device, "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); return buffers.at(device->id()); } inline DeviceBuffer& get_buffer_for_device(Device* device) { - std::lock_guard lock(mtx); + std::lock_guard lock(buffer_mtx); TT_ASSERT(buffers.find(device->id()) != buffers.end(), "Buffer not found for device " + std::to_string(device->id())); TT_ASSERT(buffers.at(device->id())->device() == device, "Mismatch between device derived from buffer and device derived from MultiDeviceStorage."); return buffers.at(device->id()); } inline DeviceBuffer get_buffer_for_device_id(uint32_t device_id) const { - std::lock_guard lock(mtx); + std::lock_guard lock(buffer_mtx); return buffers.at(device_id); } inline Shape get_tensor_shape_for_device(Device* device) const { - std::lock_guard lock(mtx); + std::lock_guard lock(shape_mtx); TT_ASSERT(shapes.find(device->id()) != shapes.end(), "Shape not found for device " + std::to_string(device->id())); return shapes.at(device->id()); } - uint32_t num_buffers() const { - std::lock_guard lock(mtx); + inline uint32_t num_buffers() const { + std::lock_guard lock(buffer_mtx); return buffers.size(); } inline bool has_buffer_for_device(Device* device) const { - std::lock_guard lock(mtx); + std::lock_guard lock(buffer_mtx); return buffers.find(device->id()) != buffers.end(); } inline bool has_buffer_for_device_id(uint32_t device_id) const { - std::lock_guard lock(mtx); + std::lock_guard lock(buffer_mtx); return buffers.find(device_id) != buffers.end(); } }; diff --git a/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp b/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp index 9ecc86c31052..cb6db5e822d7 100644 --- a/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp +++ b/tt_eager/tt_dnn/op_library/bcast/bcast_op.cpp @@ -166,10 +166,10 @@ const operation::Hash EltwiseBinaryBroadcast::compute_program_hash( return operation::hash_operation( *this, parallelization_strategy, - input_tensors.at(0).memory_config(), - input_tensors.at(0).get_dtype(), - input_tensors.at(1).memory_config(), - input_tensors.at(1).get_dtype(), + std::get(input_tensors.at(0).storage()).memory_config(), + input_tensors.at(0).dtype(), + std::get(input_tensors.at(1).storage()).memory_config(), + input_tensors.at(1).dtype(), bcast_scalar, this->in_place); } diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp index 6fdc8edfa8de..ea091ce92695 100644 --- 
a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp @@ -267,10 +267,10 @@ const operation::Hash EltwiseBinary::compute_program_hash(const std::vectorop_type, parallelization_strategy, - input_tensor_a.get_dtype(), - input_tensor_a.memory_config(), - input_tensor_b.get_dtype(), - input_tensor_b.memory_config(), + input_tensor_a.dtype(), + std::get(input_tensor_a.storage()).memory_config(), + input_tensor_b.dtype(), + std::get(input_tensor_b.storage()).memory_config(), this->output_dtype, this->output_mem_config, this->in_place); diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index 65b89afee03c..d958fc0c1f0b 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -380,13 +380,13 @@ UnaryOpParallelizationStrategy EltwiseUnary::get_parallelization_strategy( const operation::Hash EltwiseUnary::compute_program_hash(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); - const auto& input_shape = input_tensor.get_legacy_shape(); + const auto& input_shape = input_tensor.legacy_shape(); operation::Hash hash = tt::stl::hash::hash_objects_with_default_seed( typeid(*this).hash_code(), compute_volume(input_shape), - input_tensor.get_dtype(), - input_tensor.memory_config(), + input_tensor.dtype(), + std::get(input_tensor.storage()).memory_config(), this->output_mem_config); for (const auto& unary_with_param_op : this->op_chain) { diff --git a/tt_eager/tt_dnn/op_library/run_operation.cpp b/tt_eager/tt_dnn/op_library/run_operation.cpp index 788cc30adf6f..4d53c4f4ebce 100644 --- a/tt_eager/tt_dnn/op_library/run_operation.cpp +++ b/tt_eager/tt_dnn/op_library/run_operation.cpp @@ -14,26 +14,29 @@ #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/third_party/tracy/public/tracy/Tracy.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" -#include "tt_numpy/functions.hpp" #include "tt_metal/tt_stl/reflection.hpp" +#include "tt_numpy/functions.hpp" namespace tt::tt_metal::operation { namespace detail { inline bool any_tensor_on_multi_device(const Tensors& tensors) { - return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& tensor) { return tensor.storage_type() == StorageType::MULTI_DEVICE; }); + return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& tensor) { + return tensor.storage_type() == StorageType::MULTI_DEVICE; + }); } Device* get_device(const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors) { for (auto& input_tensor : input_tensors) { - if (input_tensor.storage_type() == StorageType::DEVICE) { - return input_tensor.device(); + if (std::holds_alternative(input_tensor.tensor_attributes->storage)) { + return input_tensor.workers.at(0); } } for (auto& optional_input_tensor : optional_input_tensors) { - if (optional_input_tensor.has_value() and optional_input_tensor.value().storage_type() == StorageType::DEVICE) { - return optional_input_tensor.value().device(); + if (optional_input_tensor.has_value() and + std::holds_alternative(optional_input_tensor.value().tensor_attributes->storage)) { + return optional_input_tensor.value().workers.at(0); } } auto device = AutoFormat::GetDefaultDevice(); @@ -43,18 +46,19 @@ Device* get_device(const Tensors& input_tensors, const OptionalConstTensors& opt void validate_op_launch(Device* worker) { if 
(worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS) { - TT_FATAL(not worker->in_main_thread(), "launch_op or launch_with_autoformat must be used when running in async mode."); + TT_FATAL( + not worker->in_main_thread(), + "launch_op or launch_with_autoformat must be used when running in async mode."); } } -template +template void override_addresses( const OverrideAddressesCallback& override_addresses_callback, - const Program &program, + const Program& program, const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, - const OutputTensors& output_tensors -) { + const OutputTensors& output_tensors) { std::vector input_buffers; for (auto& tensor : input_tensors) { input_buffers.push_back(tensor.buffer()); @@ -66,11 +70,10 @@ void override_addresses( std::vector output_buffers; for (auto& tensor : output_tensors) { - if constexpr(std::is_same_v){ + if constexpr (std::is_same_v) { auto buffer = tensor.has_value() ? tensor.value().buffer() : nullptr; output_buffers.push_back(buffer); - } - else{ + } else { output_buffers.push_back(tensor.buffer()); } } @@ -80,19 +83,18 @@ void override_addresses( template void override_addresses( const OverrideAddressesCallback& override_addresses_callback, - const Program &program, + const Program& program, const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, const Tensors& output_tensors); template void override_addresses( const OverrideAddressesCallback& override_addresses_callback, - const Program &program, + const Program& program, const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, const OptionalTensors& output_tensors); - template constexpr auto decorate_host_operation(const Function& function) { return [function](const Operation& operation, Args&&... 
args) { @@ -114,7 +116,7 @@ constexpr auto decorate_device_operation(const Function& function) { }; } -template +template OutputTensors run_host_operation(const HostOperation& operation, const Tensors& input_tensors) { ZoneScopedN("TT_DNN_HOST_OP"); uint32_t op_id = assign_id(); @@ -128,11 +130,12 @@ OutputTensors run_host_operation(const HostOperation& operation, } template Tensors run_host_operation(const HostOperation& operation, const Tensors& input_tensors); -template OptionalTensors run_host_operation(const HostOperation& operation, const Tensors& input_tensors); +template OptionalTensors run_host_operation( + const HostOperation& operation, const Tensors& input_tensors); inline const auto USE_FAST_DISPATCH = std::getenv("TT_METAL_SLOW_DISPATCH_MODE") == nullptr; -template +template OutputTensors run_device_operation( std::reference_wrapper queue, const DeviceOperation& operation, @@ -171,10 +174,12 @@ OutputTensors run_device_operation( } if (not cache_hit) { - program_ptr = std::make_shared>(operation.create_program(input_tensors, optional_input_tensors, output_tensors)); + program_ptr = std::make_shared>( + operation.create_program(input_tensors, optional_input_tensors, output_tensors)); program_cache.insert(program_hash, program_ptr.value()); } - auto& program_with_callbacks = *(reinterpret_cast*>(program_ptr.value().get())); + auto& program_with_callbacks = + *(reinterpret_cast*>(program_ptr.value().get())); TT_ASSERT(program_with_callbacks.supports_program_cache()); if (cache_hit) { @@ -183,7 +188,11 @@ OutputTensors run_device_operation( auto override_addresses_callback = program_with_callbacks.override_addresses_callback.value(); // Deprecated override_addresses( - override_addresses_callback, program_with_callbacks.program, input_tensors, optional_input_tensors, output_tensors); + override_addresses_callback, + program_with_callbacks.program, + input_tensors, + optional_input_tensors, + output_tensors); } if (program_with_callbacks.override_runtime_arguments_callback.has_value()) { @@ -222,18 +231,20 @@ OutputTensors run_device_operation( [&operation, &input_tensors, &optional_input_tensors, &output_tensors, queue](auto&& program) { auto device = detail::get_device(input_tensors, optional_input_tensors); using T = std::decay_t; - if constexpr (std::is_same_v> || std::is_same_v> ) { + if constexpr ( + std::is_same_v> || std::is_same_v>) { if (USE_FAST_DISPATCH) { - // Program will temporarily own the input buffers. This is required, since with Async command queues, the input - // tensor can preemptively be deallocted on device, unless program maintains explicit ownership. - // This invocation of the program will give up ownership once its enqueued. - for (const auto& input_tensor: input_tensors) { + // Program will temporarily own the input buffers. This is required, since with Async command + // queues, the input tensor can preemptively be deallocted on device, unless program maintains + // explicit ownership. This invocation of the program will give up ownership once its enqueued. 
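// The ownership comment above comes down to shared_ptr reference counting: by
// handing each input buffer to the program (AssignGlobalBufferToProgram below),
// the program holds one extra reference, so an asynchronous deallocate on the
// tensor side cannot release device memory the enqueued program still reads.
// Illustrative sketch only; FakeBuffer/FakeProgram are stand-ins, not tt-metal types:
#include <memory>
#include <vector>

struct FakeBuffer  { /* stands in for a device buffer allocation */ };
struct FakeProgram { std::vector<std::shared_ptr<FakeBuffer>> retained_buffers; };

void retain_until_enqueued(FakeProgram& program, const std::shared_ptr<FakeBuffer>& buffer) {
    // Ref count +1: even if the owning tensor drops its reference early,
    // the allocation stays alive until the program lets go of retained_buffers.
    program.retained_buffers.push_back(buffer);
}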
+ for (const auto& input_tensor : input_tensors) { if (input_tensor.storage_type() == StorageType::DEVICE) { AssignGlobalBufferToProgram(input_tensor.device_buffer(), program); } } for (auto& optional_input_tensor : optional_input_tensors) { - if (optional_input_tensor.has_value() and optional_input_tensor.value().storage_type() == StorageType::DEVICE) { + if (optional_input_tensor.has_value() and + optional_input_tensor.value().storage_type() == StorageType::DEVICE) { AssignGlobalBufferToProgram(optional_input_tensor.value().device_buffer(), program); } } @@ -245,10 +256,20 @@ OutputTensors run_device_operation( }, program); - TracyOpTTNNDevice(op_id, program_hash, program_cache.is_enabled(), device_id, operation, program, input_tensors, optional_input_tensors, output_tensors); + TracyOpTTNNDevice( + op_id, + program_hash, + program_cache.is_enabled(), + device_id, + operation, + program, + input_tensors, + optional_input_tensors, + output_tensors); return output_tensors; } + template Tensors run_device_operation( std::reference_wrapper queue, const DeviceOperation& operation, @@ -263,17 +284,16 @@ template OptionalTensors run_device_operation( const OptionalConstTensors& optional_input_tensors, const OptionalTensors& optional_output_tensors); - } // namespace detail -template +template OutputTensors run(const HostOperation& operation, const Tensors& input_tensors) { return detail::decorate_host_operation(detail::run_host_operation)(operation, input_tensors); } template Tensors run(const HostOperation& operation, const Tensors& input_tensors); template OptionalTensors run(const HostOperation& operation, const Tensors& input_tensors); -template +template OutputTensors run( const DeviceOperation& operation, const Tensors& input_tensors, @@ -283,15 +303,16 @@ OutputTensors run( auto device = detail::get_device(input_tensors, optional_input_tensors); #ifdef DEBUG operation.validate(input_tensors, optional_input_tensors, optional_output_tensors); -#endif detail::validate_op_launch(device); +#endif return detail::decorate_device_operation(detail::run_device_operation)( std::ref(device->command_queue(cq_id)), operation, input_tensors, optional_input_tensors, optional_output_tensors); - } +} + template Tensors run( const DeviceOperation& operation, const Tensors& input_tensors, @@ -306,7 +327,7 @@ template OptionalTensors run( const OptionalTensors& optional_output_tensors, uint8_t cq_id); -template +template OutputTensors run_without_autoformat( const DeviceOperation& operation, const Tensors& input_tensors, @@ -328,7 +349,8 @@ OutputTensors run_without_autoformat( optional_input_tensors_on_dev.reserve(optional_input_tensors.size()); for (auto& optional_input_tensor : optional_input_tensors) { if (optional_input_tensor.has_value() and optional_input_tensor.value().storage_type() != StorageType::DEVICE) { - optional_input_tensors_on_dev.push_back(AutoFormat::move_tensor_to_device(optional_input_tensor.value(), device)); + optional_input_tensors_on_dev.push_back( + AutoFormat::move_tensor_to_device(optional_input_tensor.value(), device)); } else { optional_input_tensors_on_dev.push_back(optional_input_tensor); } @@ -348,7 +370,7 @@ template OptionalTensors run_without_autoformat( const OptionalConstTensors& optional_input_tensors, uint8_t cq_id); -template +template OutputTensors run_without_autoformat( const DeviceOperation& operation, const Tensors& input_tensors, @@ -371,7 +393,8 @@ OutputTensors run_without_autoformat( optional_input_tensors_on_dev.reserve(optional_input_tensors.size()); for 
(auto& optional_input_tensor : optional_input_tensors) { if (optional_input_tensor.has_value() and optional_input_tensor.value().storage_type() != StorageType::DEVICE) { - optional_input_tensors_on_dev.push_back(AutoFormat::move_tensor_to_device(optional_input_tensor.value(), device)); + optional_input_tensors_on_dev.push_back( + AutoFormat::move_tensor_to_device(optional_input_tensor.value(), device)); } else { optional_input_tensors_on_dev.push_back(optional_input_tensor); } @@ -402,9 +425,6 @@ Tensors run_with_autoformat( const bool pad_c, uint8_t cq_id) { ZoneScoped; - if (detail::any_tensor_on_multi_device(input_tensors)) { - return run(operation, input_tensors, optional_input_tensors); - } Device* device = detail::get_device(input_tensors, optional_input_tensors); detail::validate_op_launch(device); auto output_shapes = operation.compute_output_shapes(input_tensors); @@ -415,7 +435,8 @@ Tensors run_with_autoformat( auto padded_input_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape(), pad_c); auto pad_input = not AutoFormat::check_input_tensor_format(input_tensor, padded_input_shape); if (pad_input) { - formatted_input_tensors.push_back(AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, pad_value, Layout::TILE)); + formatted_input_tensors.push_back( + AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, pad_value, Layout::TILE)); } else { formatted_input_tensors.push_back(input_tensor); } @@ -429,7 +450,8 @@ Tensors run_with_autoformat( auto padded_input_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape(), pad_c); auto pad_input = not AutoFormat::check_input_tensor_format(input_tensor, padded_input_shape); if (pad_input) { - formatted_optional_input_tensors.push_back(AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, pad_value, Layout::TILE)); + formatted_optional_input_tensors.push_back( + AutoFormat::format_input_tensor(input_tensor, device, padded_input_shape, pad_value, Layout::TILE)); } else { formatted_optional_input_tensors.push_back(input_tensor); } @@ -460,9 +482,6 @@ Tensors run_with_autoformat( const std::vector>& optional_input_formatting, uint8_t cq_id) { ZoneScoped; - if (detail::any_tensor_on_multi_device(input_tensors)) { - return run(operation, input_tensors, optional_input_tensors); - } Device* device = detail::get_device(input_tensors, optional_input_tensors); detail::validate_op_launch(device); auto output_shapes = operation.compute_output_shapes(input_tensors); @@ -473,7 +492,12 @@ Tensors run_with_autoformat( Tensors formatted_input_tensors; formatted_input_tensors.reserve(input_tensors.size()); for (uint32_t i = 0; i < input_tensors.size(); ++i) { - formatted_input_tensors.push_back(AutoFormat::format_input_tensor(input_tensors[i], device, input_formatting[i].pad_shape, input_formatting[i].pad_value, input_formatting[i].target_layout)); + formatted_input_tensors.push_back(AutoFormat::format_input_tensor( + input_tensors[i], + device, + input_formatting[i].pad_shape, + input_formatting[i].pad_value, + input_formatting[i].target_layout)); } OptionalConstTensors formatted_optional_input_tensors; @@ -483,7 +507,12 @@ Tensors run_with_autoformat( auto& input_tensor = optional_input_tensors[i].value(); TT_ASSERT(optional_input_formatting[i].has_value()); auto& input_formatting = optional_input_formatting[i].value(); - formatted_optional_input_tensors.push_back(AutoFormat::format_input_tensor(input_tensor, device, input_formatting.pad_shape, 
input_formatting.pad_value, input_formatting.target_layout)); + formatted_optional_input_tensors.push_back(AutoFormat::format_input_tensor( + input_tensor, + device, + input_formatting.pad_shape, + input_formatting.pad_value, + input_formatting.target_layout)); } else { formatted_optional_input_tensors.push_back(optional_input_tensors[i]); } @@ -498,7 +527,8 @@ Tensors run_with_autoformat( formatted_optional_input_tensors.clear(); for (auto i = 0; i < output_tensors.size(); ++i) { - output_tensors[i] = AutoFormat::format_output_tensor(output_tensors[i], output_shapes[i], device, output_layouts[i]); + output_tensors[i] = + AutoFormat::format_output_tensor(output_tensors[i], output_shapes[i], device, output_layouts[i]); } return output_tensors; @@ -509,8 +539,7 @@ void launch_with_autoformat( const Tensors input_tensors, Tensors& output_tensors, const OptionalConstTensors optional_input_tensors, - const OptionalTensors optional_output_tensors -) { + const OptionalTensors optional_output_tensors) { // Mark each output tensor as having dynamic storage (can be on host or device, depending // on autoformat behaviour). Multi device tensors do not support dynamic storage. for (auto& output_tensor : output_tensors) { @@ -525,28 +554,33 @@ void launch_op( Tensors& output_tensors, const OptionalConstTensors optional_input_tensors, const OptionalTensors optional_output_tensors, - bool enable_autoformat_device -) { + bool enable_autoformat_device) { // Send host side op compile and run to the worker queue // Assert to ensure that worker threads are specified. ZoneScopedN("LaunchOp"); auto& workers = output_tensors.at(0).workers; std::size_t workers_size = workers.size(); - if (not enable_autoformat_device and workers.empty()) { - // Run on the host + if (not enable_autoformat_device and workers.empty() or not workers.at(0)->in_main_thread()) { + // Run in main thread or immediately in worker thread output_tensors = op_func(input_tensors, optional_input_tensors, optional_output_tensors); return; } for (auto& output_tensor : output_tensors) { - TT_FATAL(output_tensor.workers.size(), "Worker threads must be specified for outputs populated by launch_op. This API can only be used for creating output tensors on device."); - TT_FATAL(output_tensor.workers == workers, "Worker threads must be consistent across all outputs populated by launch_op."); + TT_FATAL( + output_tensor.workers.size(), + "Worker threads must be specified for outputs populated by launch_op. This API can only be used for " + "creating output tensors on device."); + TT_FATAL( + output_tensor.workers == workers, + "Worker threads must be consistent across all outputs populated by launch_op."); } validate_worker_modes(workers); // Record ref counts for all tensors before pushing to worker queue. 
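// The record/update pair used below appears to exist so that the temporary tensor
// copies made while capturing into the worker lambda are not counted as extra
// main-thread owners: the main thread snapshots each ref count before the push and
// writes it back afterwards. A toy version of that record-then-restore shape
// (RefCounted is a stand-in, not the real tensor_attributes type):
#include <cstdint>
#include <utility>

struct RefCounted { std::uint32_t main_thread_refs = 1; };

template <typename PushFn>
void push_with_stable_ref_count(RefCounted& attrs, PushFn&& push_to_worker_queue) {
    std::uint32_t recorded = attrs.main_thread_refs;  // record_main_thread_ref_count()
    std::forward<PushFn>(push_to_worker_queue)();     // capture copies may bump the count
    attrs.main_thread_refs = recorded;                // update_main_thread_ref_count()
}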
std::vector input_tensor_ref_count = std::vector(input_tensors.size()); std::vector optional_input_tensor_ref_count = std::vector(optional_input_tensors.size()); std::vector output_tensor_ref_count = std::vector(output_tensors.size()); - std::vector optional_output_tensor_ref_count = std::vector(optional_output_tensors.size());; + std::vector optional_output_tensor_ref_count = std::vector(optional_output_tensors.size()); + ; std::vector async_safe_input_tensors = std::vector(input_tensors.size()); std::vector> async_safe_optional_input_tensors = {}; @@ -560,10 +594,11 @@ void launch_op( } for (int i = 0; i < optional_input_tensors.size(); i++) { if (optional_input_tensors[i].has_value()) { - async_safe_optional_input_tensors.push_back(copy_borrowed_tensor_in_async_mode(workers.at(0), optional_input_tensors[i].value())); - optional_input_tensor_ref_count[i] = async_safe_optional_input_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); - } - else { + async_safe_optional_input_tensors.push_back( + copy_borrowed_tensor_in_async_mode(workers.at(0), optional_input_tensors[i].value())); + optional_input_tensor_ref_count[i] = + async_safe_optional_input_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); + } else { async_safe_optional_input_tensors.push_back(std::nullopt); optional_input_tensor_ref_count[i] = 0; } @@ -573,9 +608,9 @@ void launch_op( } for (int i = 0; i < optional_output_tensors.size(); i++) { if (optional_output_tensors[i].has_value()) { - optional_output_tensor_ref_count[i] = optional_output_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); - } - else { + optional_output_tensor_ref_count[i] = + optional_output_tensors[i].value().tensor_attributes->record_main_thread_ref_count(); + } else { optional_output_tensor_ref_count[i] = 0; } } @@ -586,14 +621,18 @@ void launch_op( if (workers_size == 1) { // Single worker per tensor and. for (int i = 0; i < async_safe_input_tensors.size(); i++) { - if (async_safe_input_tensors.at(i).get_workers().size() and async_safe_input_tensors.at(i).get_workers().at(0) != workers.at(0)) { - // This input has a worker assigned that doesn't match the worker of the output being created (its shared). + if (async_safe_input_tensors.at(i).get_workers().size() and + async_safe_input_tensors.at(i).get_workers().at(0) != workers.at(0)) { + // This input has a worker assigned that doesn't match the worker of the output being created (its + // shared). 
async_safe_input_tensors.at(i).tensor_attributes->num_sibling_workers_sharing_tensor++; cross_worker_input_tensor_idx.insert(i); } } for (int i = 0; i < async_safe_optional_input_tensors.size(); i++) { - if (async_safe_optional_input_tensors.at(i).has_value() and async_safe_optional_input_tensors.at(i).value().get_workers().size() and async_safe_optional_input_tensors.at(i).value().get_workers().at(0) != workers.at(0)) { + if (async_safe_optional_input_tensors.at(i).has_value() and + async_safe_optional_input_tensors.at(i).value().get_workers().size() and + async_safe_optional_input_tensors.at(i).value().get_workers().at(0) != workers.at(0)) { async_safe_optional_input_tensors.at(i).value().tensor_attributes->num_sibling_workers_sharing_tensor++; cross_worker_optional_input_tensor_idx.insert(i); } @@ -602,89 +641,98 @@ void launch_op( { ZoneScopedN("PushOpToWorkers"); - auto work_lambda = std::make_shared>([workers_size, op_func, optional_output_tensors, async_safe_optional_input_tensors, inputs = async_safe_input_tensors, outputs = output_tensors, shared_input_idx = cross_worker_input_tensor_idx, shared_optional_input_idx = cross_worker_optional_input_tensor_idx] (Device* target_device) mutable { - std::vector input_shards = std::vector(inputs.size(), Tensor()); - std::vector> optional_input_shards = {}; - std::vector> optional_output_shards = {}; - // Initialize all optional_outputs to std::nullopt - optional_output_shards.resize(optional_output_tensors.size()); - - { - ZoneScopedN("CreateShards"); - for (int i = 0; i < input_shards.size(); i++) { - input_shards[i] = get_shard_for_device(inputs[i], target_device); - } - - for (auto& input : async_safe_optional_input_tensors) { - if (input.has_value()) { - optional_input_shards.push_back(get_shard_for_device(input.value(), target_device)); + auto work_lambda = std::make_shared>( + [workers_size, + op_func, + optional_output_tensors, + async_safe_optional_input_tensors, + inputs = async_safe_input_tensors, + outputs = output_tensors, + shared_input_idx = cross_worker_input_tensor_idx, + shared_optional_input_idx = cross_worker_optional_input_tensor_idx](Device* target_device) mutable { + std::vector input_shards = std::vector(inputs.size(), Tensor()); + std::vector> optional_input_shards = {}; + std::vector> optional_output_shards = {}; + // Initialize all optional_outputs to std::nullopt + optional_output_shards.resize(optional_output_tensors.size()); + + { + ZoneScopedN("CreateShards"); + for (int i = 0; i < input_shards.size(); i++) { + input_shards[i] = get_shard_for_device(inputs[i], target_device); } - else { - optional_input_shards.push_back(std::nullopt); + + for (auto& input : async_safe_optional_input_tensors) { + if (input.has_value()) { + optional_input_shards.push_back(get_shard_for_device(input.value(), target_device)); + } else { + optional_input_shards.push_back(std::nullopt); + } } - } - for (std::size_t optional_output_idx = 0; optional_output_idx < optional_output_tensors.size(); optional_output_idx++) { - if (optional_output_tensors[optional_output_idx].has_value()) { - optional_output_shards[optional_output_idx] = get_shard_for_device(optional_output_tensors[optional_output_idx].value(), target_device); + for (std::size_t optional_output_idx = 0; optional_output_idx < optional_output_tensors.size(); + optional_output_idx++) { + if (optional_output_tensors[optional_output_idx].has_value()) { + optional_output_shards[optional_output_idx] = get_shard_for_device( + optional_output_tensors[optional_output_idx].value(), 
target_device); + } } } - } - auto local_tensors = op_func(input_shards, optional_input_shards, optional_output_shards); + auto local_tensors = op_func(input_shards, optional_input_shards, optional_output_shards); - { - ZoneScopedN("OpPostProcess"); - // Release shared ownership of tensors belonging to other workers. - // If the workers for this tensor are stalled to deallocate - for (auto& shared_input : shared_input_idx) { - inputs.at(shared_input).tensor_attributes->num_sibling_workers_sharing_tensor--; - } - - for (auto& shared_optional_input : shared_optional_input_idx) { - async_safe_optional_input_tensors.at(shared_optional_input).value().tensor_attributes->num_sibling_workers_sharing_tensor--; - } - - for (int i = 0; i < local_tensors.size(); i++) { - if (local_tensors.at(i).storage_type() == StorageType::OWNED) { - TT_ASSERT(outputs.at(i).tensor_attributes->dynamic_storage, "launch_with_autoformat must be used if output tensor for op can be placed on host."); - // Make this a host side tensor - Set storage = Owned and clear workers - outputs.at(i).tensor_attributes->storage = OwnedStorage(); - outputs.at(i).workers = {}; - } - else { - outputs.at(i).tensor_attributes->dynamic_storage = false; - } - insert_buffer_and_shape_for_device(target_device, local_tensors.at(i), outputs.at(i)); - if (not target_device->id() or workers_size == 1) { - outputs.at(i).set_shape(local_tensors.at(i).get_shape()); - outputs.at(i).set_dtype(local_tensors.at(i).get_dtype()); - outputs.at(i).set_layout(local_tensors.at(i).get_layout()); + { + ZoneScopedN("OpPostProcess"); + // Release shared ownership of tensors belonging to other workers. + // If the workers for this tensor are stalled to deallocate + for (auto& shared_input : shared_input_idx) { + inputs.at(shared_input).tensor_attributes->num_sibling_workers_sharing_tensor--; } - if (workers_size == 1) { - outputs.at(i).set_populated(); + + for (auto& shared_optional_input : shared_optional_input_idx) { + async_safe_optional_input_tensors.at(shared_optional_input) + .value() + .tensor_attributes->num_sibling_workers_sharing_tensor--; } - else { - outputs.at(i).set_populated(target_device); + + for (int i = 0; i < local_tensors.size(); i++) { + if (std::holds_alternative(local_tensors.at(i).tensor_attributes->storage)) { + TT_ASSERT( + outputs.at(i).tensor_attributes->dynamic_storage, + "launch_with_autoformat must be used if output tensor for op can be placed on host."); + // Make this a host side tensor - Set storage = Owned and clear workers + outputs.at(i).tensor_attributes->storage = OwnedStorage(); + outputs.at(i).workers = {}; + } else { + outputs.at(i).tensor_attributes->dynamic_storage = false; + } + insert_buffer_and_shape_for_device(target_device, local_tensors.at(i), outputs.at(i)); + int num_workers_completed = (outputs.at(i).tensor_attributes->num_workers_completed)++; + if (not num_workers_completed) { + outputs.at(i).tensor_attributes->shape = local_tensors.at(i).tensor_attributes->shape; + outputs.at(i).tensor_attributes->dtype = local_tensors.at(i).tensor_attributes->dtype; + outputs.at(i).tensor_attributes->layout = local_tensors.at(i).tensor_attributes->layout; + outputs.at(i).tensor_attributes->metadata_populated = true; + } } } - } - }); + }); for (auto target_device : workers) { - target_device->push_work(std::make_shared>([target_device, work_lambda] () mutable { - (*work_lambda)(target_device); - })); + target_device->push_work(std::make_shared>( + [target_device, work_lambda]() mutable { (*work_lambda)(target_device); })); } 
} // Update ref counts of all tensors after push was performed (done only in main thread). for (int i = 0; i < async_safe_input_tensors.size(); i++) { - async_safe_input_tensors[i].tensor_attributes->update_main_thread_ref_count(workers.at(0), input_tensor_ref_count[i]); + async_safe_input_tensors[i].tensor_attributes->update_main_thread_ref_count( + workers.at(0), input_tensor_ref_count[i]); } for (int i = 0; i < async_safe_optional_input_tensors.size(); i++) { if (async_safe_optional_input_tensors[i].has_value()) { - async_safe_optional_input_tensors[i].value().tensor_attributes->update_main_thread_ref_count(workers.at(0), optional_input_tensor_ref_count[i]); + async_safe_optional_input_tensors[i].value().tensor_attributes->update_main_thread_ref_count( + workers.at(0), optional_input_tensor_ref_count[i]); } } for (int i = 0; i < output_tensors.size(); i++) { @@ -692,37 +740,53 @@ void launch_op( } for (int i = 0; i < optional_output_tensors.size(); i++) { if (optional_output_tensors[i].has_value()) { - optional_output_tensors[i].value().tensor_attributes->update_main_thread_ref_count(workers.at(0), optional_output_tensor_ref_count[i]); + optional_output_tensors[i].value().tensor_attributes->update_main_thread_ref_count( + workers.at(0), optional_output_tensor_ref_count[i]); } } } -void validate_workers_and_storage(const std::vector& inputs, const std::vector>& optional_inputs, const std::vector& workers) { +void validate_workers_and_storage( + const std::vector& inputs, + const std::vector>& optional_inputs, + const std::vector& workers) { bool single_device_storage = false; bool multi_device_storage = false; - // Verify that storage types are consistent - cannot mix single and multi-device storage. For multi-device tensors, ensure that workers are specified, since they cannot be inferred. - // This means that launch_op/launch_with_autoformat cannot be called with MultiDeviceHostStorage. - for (const auto& input: inputs) { - if (std::holds_alternative(input.tensor_attributes->storage) or std::holds_alternative(input.tensor_attributes->storage)) { + // Verify that storage types are consistent - cannot mix single and multi-device storage. For multi-device tensors, + // ensure that workers are specified, since they cannot be inferred. This means that + // launch_op/launch_with_autoformat cannot be called with MultiDeviceHostStorage. 
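// The consistency check below is a straight std::holds_alternative classification
// over the storage variant. Minimal stand-alone version with two simplified
// alternatives instead of the four checked in the loops (type names here are
// illustrative stand-ins):
#include <variant>

struct SingleDeviceLike {};  // plays the role of DeviceStorage / OwnedStorage
struct MultiDeviceLike  {};  // plays the role of MultiDeviceStorage / MultiDeviceHostStorage
using AnyStorage = std::variant<SingleDeviceLike, MultiDeviceLike>;

bool holds_multi_device_storage(const AnyStorage& storage) {
    return std::holds_alternative<MultiDeviceLike>(storage);
}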
+ for (const auto& input : inputs) { + if (std::holds_alternative(input.tensor_attributes->storage) or + std::holds_alternative(input.tensor_attributes->storage)) { single_device_storage |= true; - } else if (std::holds_alternative(input.tensor_attributes->storage) or std::holds_alternative(input.tensor_attributes->storage)) { + } else if ( + std::holds_alternative(input.tensor_attributes->storage) or + std::holds_alternative(input.tensor_attributes->storage)) { multi_device_storage |= true; } } for (auto& input : optional_inputs) { if (input.has_value()) { - if (std::holds_alternative(input.value().tensor_attributes->storage) or std::holds_alternative(input.value().tensor_attributes->storage)) { + if (std::holds_alternative(input.value().tensor_attributes->storage) or + std::holds_alternative(input.value().tensor_attributes->storage)) { single_device_storage |= true; - } else if (std::holds_alternative(input.value().tensor_attributes->storage) or std::holds_alternative(input.value().tensor_attributes->storage)) { + } else if ( + std::holds_alternative(input.value().tensor_attributes->storage) or + std::holds_alternative(input.value().tensor_attributes->storage)) { multi_device_storage |= true; } } } - TT_FATAL(not (single_device_storage and multi_device_storage), "Cannot mix single and multi-device tensors when calling launch op!"); + TT_FATAL( + not(single_device_storage and multi_device_storage), + "Cannot mix single and multi-device tensors when calling launch op!"); if (multi_device_storage) { - TT_FATAL(workers.size(), "Workers must be specified when calling launch_op with with multi-device tensors. Workers cannot be inferred in this case."); + TT_FATAL( + workers.size(), + "Workers must be specified when calling launch_op with with multi-device tensors. Workers cannot be " + "inferred in this case."); } } @@ -760,10 +824,13 @@ std::vector get_workers_for_op_output( // Workers not specified - inputs are on host and not multi-device. // Use the default device from autoformat. if (not workers_for_op.size()) { - TT_FATAL(AutoFormat::GetDefaultDevice(), "Default device must be specified using AutoFormat::SetDefaultDevice, if workers are not specified for inputs to op."); + TT_FATAL( + AutoFormat::GetDefaultDevice(), + "Default device must be specified using AutoFormat::SetDefaultDevice, if workers are not specified for " + "inputs to op."); workers_for_op = {AutoFormat::GetDefaultDevice()}; } } return workers_for_op; } -} +} // namespace tt::tt_metal::operation diff --git a/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp b/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp index d21e511e99b8..c46675bcc7f3 100644 --- a/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp +++ b/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp @@ -162,11 +162,11 @@ const operation::Hash Softmax::compute_program_hash( const std::vector &input_tensors, const std::vector>& optional_input_tensors) const { return operation::hash_operation( - input_tensors.at(0).memory_config(), - input_tensors.at(0).get_dtype(), - optional_input_tensors.at(0).has_value() ? std::optional{optional_input_tensors.at(0).value().memory_config()} + std::get(input_tensors.at(0).storage()).memory_config(), + input_tensors.at(0).dtype(), + optional_input_tensors.at(0).has_value() ? std::optional{std::get(optional_input_tensors.at(0).value().storage()).memory_config()} : std::nullopt, - optional_input_tensors.at(0).has_value() ? std::optional{optional_input_tensors.at(0).value().get_dtype()} + optional_input_tensors.at(0).has_value() ? 
std::optional{optional_input_tensors.at(0).value().dtype()} : std::nullopt, this->output_mem_config); } diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp index da1fa273b773..0af4c11bf4b2 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp @@ -292,10 +292,10 @@ const operation::Hash AttnMatmul::compute_program_hash(const std::vector this->transpose_hw, this->output_mem_config, this->output_dtype, - input_tensors.at(0).memory_config(), - input_tensors.at(0).get_dtype(), - input_tensors.at(1).memory_config(), - input_tensors.at(1).get_dtype()); + std::get(input_tensors.at(0).storage()).memory_config(), + input_tensors.at(0).dtype(), + std::get(input_tensors.at(1).storage()).memory_config(), + input_tensors.at(1).dtype()); } void GroupAttnMatmul::validate(const std::vector& input_tensors) const { @@ -502,14 +502,14 @@ const operation::Hash GroupAttnMatmul::compute_program_hash(const std::vectoroutput_mem_config.buffer_type, this->output_dtype, this->row_major, - input_tensor_a.memory_config().memory_layout, - input_tensor_a.memory_config().buffer_type, - input_tensor_a.get_dtype(), - input_tensor_a.device()->id(), - input_tensor_b.memory_config().memory_layout, - input_tensor_b.memory_config().buffer_type, - input_tensor_b.get_dtype(), - input_tensor_b.device()->id()); + std::get(input_tensor_a.storage()).memory_config().memory_layout, + std::get(input_tensor_a.storage()).memory_config().buffer_type, + input_tensor_a.dtype(), + std::get(input_tensor_b.storage()).buffer->device()->id(), + std::get(input_tensor_b.storage()).memory_config().memory_layout, + std::get(input_tensor_b.storage()).memory_config().buffer_type, + input_tensor_b.dtype(), + std::get(input_tensor_b.storage()).buffer->device()->id()); } // SSM eltwise mul diff --git a/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp b/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp index 2a06d74f1f0a..1d3a6be8798a 100644 --- a/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp +++ b/tt_eager/tt_dnn/op_library/transpose/transpose_op.cpp @@ -156,9 +156,9 @@ tt::stl::reflection::Attributes Transpose::attributes() const { const operation::Hash Transpose::compute_program_hash( const std::vector &input_tensors) const { auto input_tensor = input_tensors.at(0); - auto input_mem_config = input_tensor.memory_config(); + auto input_mem_config = std::get(input_tensor.storage()).memory_config(); auto output_mem_config = this->output_mem_config; - auto dtype = input_tensor.get_dtype(); + auto dtype = input_tensor.dtype(); return operation::hash_operation( input_mem_config, output_mem_config, dtype, this->dim, get_parallelization_strategy(input_tensors)); } diff --git a/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp b/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp index b2482bffa2a6..b8f437d21387 100644 --- a/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp @@ -147,19 +147,19 @@ tt::stl::reflection::Attributes Unpad::attributes() const { const operation::Hash Unpad::compute_program_hash(const std::vector &input_tensors) const { auto input_tensor = input_tensors.at(0); - auto input_mem_config = input_tensor.memory_config(); + auto input_mem_config = std::get(input_tensor.storage()).memory_config(); auto output_mem_config = this->output_mem_config; - auto dtype = input_tensor.get_dtype(); - auto num_dims = 
input_tensor.get_legacy_shape().rank(); + auto dtype = input_tensor.dtype(); + auto num_dims = input_tensor.shape().rank(); std::string rm_width = "TILE"; if (input_tensor.get_layout() == Layout::ROW_MAJOR) { - rm_width = fmt::format("{}", input_tensor.get_legacy_shape()[3]); + rm_width = fmt::format("{}", input_tensor.legacy_shape()[3]); } auto str = operation::hash_operation( num_dims, - input_tensor.get_layout(), + input_tensor.layout(), input_mem_config.memory_layout, input_mem_config.buffer_type, output_mem_config.memory_layout, diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt index 235f4f7b0921..7345da4c3360 100644 --- a/tt_metal/CMakeLists.txt +++ b/tt_metal/CMakeLists.txt @@ -18,7 +18,7 @@ set(TT_METAL_OBJECTS add_library(tt_metal ${TT_METAL_OBJECTS}) if(BUILD_SHARED_LIBS) - target_link_libraries(tt_metal PUBLIC device) + target_link_libraries(tt_metal PUBLIC device metal_common_libs) add_dependencies(tt_metal umd_device) else() target_link_libraries(tt_metal PUBLIC ${UMD_STATIC_LIB} metal_common_libs) diff --git a/tt_metal/detail/tt_metal.hpp b/tt_metal/detail/tt_metal.hpp index 507a58a3aa29..bcc80005d875 100644 --- a/tt_metal/detail/tt_metal.hpp +++ b/tt_metal/detail/tt_metal.hpp @@ -493,5 +493,17 @@ namespace tt::tt_metal{ specified_core_spec ); } + + inline void SynchronizeWorkerThreads(const std::vector& workers) { + // Push empty work to threads and ensure its been picked up + static auto empty_work = std::make_shared>([](){}); + for (auto target_device : workers) { + target_device->work_executor.push_work(empty_work); + } + // Block until work has been picked up, to flush the queue + for (auto target_device : workers) { + while(not target_device->work_executor.worker_queue.empty()); + } + } } } diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 4d36a99e41d3..6e9892c130c4 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -69,8 +69,8 @@ bool ActiveDevices::is_device_active(chip_id_t id) { } Device::Device( - chip_id_t device_id, const uint8_t num_hw_cqs, size_t l1_small_size, const std::vector &l1_bank_remap, bool minimal) : - id_(device_id), num_hw_cqs_(num_hw_cqs), work_executor(device_id) { + chip_id_t device_id, const uint8_t num_hw_cqs, size_t l1_small_size, const std::vector &l1_bank_remap, bool minimal, uint32_t worker_core) : + id_(device_id), num_hw_cqs_(num_hw_cqs), worker_thread_core(worker_core), work_executor(worker_core, device_id) { ZoneScoped; TT_ASSERT(num_hw_cqs > 0 and num_hw_cqs < 3, "num_hw_cqs can be between 1 and 2"); this->build_key_ = tt::Cluster::instance().get_harvesting_mask(device_id); diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index ade5235ae9f3..12df80a6bee1 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -77,7 +77,8 @@ class Device { const uint8_t num_hw_cqs, std::size_t l1_small_size, const std::vector &l1_bank_remap = {}, - bool minimal = false); + bool minimal = false, + uint32_t worker_core = 0); ~Device(); @@ -277,6 +278,7 @@ class Device { // Work Executor for this device - can asynchronously process host side work for // all tasks scheduled on this device WorkExecutor work_executor; + uint32_t worker_thread_core; std::unique_ptr sysmem_manager_; uint8_t num_hw_cqs_; diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 7a84851109f3..8c061bd40eeb 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ 
b/tt_metal/impl/dispatch/command_queue.cpp @@ -1238,7 +1238,7 @@ HWCommandQueue::HWCommandQueue(Device* device, uint32_t id, NOC noc_index) : std::thread completion_queue_thread = std::thread(&HWCommandQueue::read_completion_queue, this); this->completion_queue_thread = std::move(completion_queue_thread); // Set the affinity of the completion queue reader. - set_device_thread_affinity(this->completion_queue_thread, device->id()); + set_device_thread_affinity(this->completion_queue_thread, device->worker_thread_core); this->expected_num_workers_completed = 0; } @@ -1932,24 +1932,29 @@ void HWCommandQueue::read_completion_queue() { }); } if (this->num_entries_in_completion_q > this->num_completed_completion_q_reads) { + ZoneScopedN("CompletionQueueReader"); uint32_t num_events_to_read = this->num_entries_in_completion_q - this->num_completed_completion_q_reads; for (uint32_t i = 0; i < num_events_to_read; i++) { - std::variant read_descriptor = - *(this->issued_completion_q_reads.pop()); - - this->manager.completion_queue_wait_front( - this->id, this->exit_condition); // CQ DISPATCHER IS NOT HANDSHAKING WITH HOST RN - + ZoneScopedN("CompletionQueuePopulated"); + std::variant read_descriptor = *(this->issued_completion_q_reads.pop()); + { + ZoneScopedN("CompletionQueueWait"); + this->manager.completion_queue_wait_front(this->id, this->exit_condition); // CQ DISPATCHER IS NOT HANDSHAKING WITH HOST RN + } if (this->exit_condition) { // Early exit return; } std::visit( - [&](auto&& read_descriptor) { + [&](auto&& read_descriptor) + { using T = std::decay_t; if constexpr (std::is_same_v) { + ZoneScopedN("CompletionQueueReadData"); this->copy_into_user_space(read_descriptor, mmio_device_id, channel); - } else if constexpr (std::is_same_v) { + } + else if constexpr (std::is_same_v) { + ZoneScopedN("CompletionQueueReadEvent"); uint32_t read_ptr = this->manager.get_completion_queue_read_ptr(this->id); thread_local static std::vector dispatch_cmd_and_event( (sizeof(CQDispatchCmd) + dispatch_constants::EVENT_PADDED_SIZE) / sizeof(uint32_t)); diff --git a/tt_metal/impl/dispatch/work_executor.hpp b/tt_metal/impl/dispatch/work_executor.hpp index 323f5e7f7e29..a164f3a8795d 100644 --- a/tt_metal/impl/dispatch/work_executor.hpp +++ b/tt_metal/impl/dispatch/work_executor.hpp @@ -44,12 +44,11 @@ enum class WorkerState { IDLE = 2, }; -inline void set_device_thread_affinity(std::thread& thread_, int managed_device_id) { +inline void set_device_thread_affinity(std::thread& thread_, int cpu_core_for_worker) { // Bind a device worker/reader thread to a CPU core, determined using round-robin. 
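Context for the affinity change: the previous behaviour pinned each device's worker/reader thread to `device_id % num_online_cores`, while this patch passes in an explicit CPU core chosen NUMA-locally by the new helpers in tt_metal.cpp (get_cpu_cores_per_numa_node / get_cpu_core_for_device_worker_thread) and threads that core id through Device and WorkExecutor. A minimal, self-contained sketch of the same scheme follows. It is a sketch only, assuming Linux/glibc with libnuma installed and linked via -lnuma; the names cpus_by_numa_node, pick_core_for_device and pin_thread_to_core are illustrative stand-ins, not the functions added by this patch.

#include <numa.h>      // numa_available, numa_num_configured_cpus, numa_node_of_cpu
#include <pthread.h>   // pthread_setaffinity_np (GNU extension)
#include <sched.h>     // cpu_set_t, CPU_ZERO, CPU_SET
#include <unistd.h>    // sysconf
#include <cstdio>
#include <map>
#include <thread>
#include <vector>

// Group online CPU ids by the NUMA node they belong to; fall back to one node if NUMA is unavailable.
static std::map<int, std::vector<int>> cpus_by_numa_node() {
    std::map<int, std::vector<int>> nodes;
    if (numa_available() != -1) {
        for (int cpu = 0; cpu < numa_num_configured_cpus(); ++cpu) {
            nodes[numa_node_of_cpu(cpu)].push_back(cpu);
        }
    } else {
        for (int cpu = 0; cpu < sysconf(_SC_NPROCESSORS_ONLN); ++cpu) {
            nodes[0].push_back(cpu);
        }
    }
    return nodes;
}

// Round-robin a device id onto a CPU core taken from the device's NUMA node, if that node exists on the host.
static int pick_core_for_device(int device_id, int device_numa_node,
                                const std::map<int, std::vector<int>>& nodes) {
    auto it = nodes.find(device_numa_node);
    if (it == nodes.end() || it->second.empty()) {
        return device_id % sysconf(_SC_NPROCESSORS_ONLN);  // plain round-robin fallback
    }
    return it->second[device_id % it->second.size()];
}

// Pin an already-running std::thread to a single CPU core.
static void pin_thread_to_core(std::thread& t, int cpu_core) {
    cpu_set_t cpuset;
    CPU_ZERO(&cpuset);
    CPU_SET(cpu_core, &cpuset);
    int rc = pthread_setaffinity_np(t.native_handle(), sizeof(cpu_set_t), &cpuset);
    if (rc != 0) {
        std::fprintf(stderr, "pthread_setaffinity_np failed: %d\n", rc);
    }
}

int main() {
    auto nodes = cpus_by_numa_node();
    int device_id = 0;                       // illustrative device id
    int device_node = nodes.begin()->first;  // stand-in for the NUMA node reported for the device
    std::thread worker([] { /* worker loop would run here */ });
    pin_thread_to_core(worker, pick_core_for_device(device_id, device_node, nodes));
    worker.join();
    return 0;
}

The apparent intent is to keep each device's host worker thread on a CPU core local to that device's NUMA node rather than deriving the core purely from the device id ("Thread affinity == Better Perf." in the worker startup comment).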
- static int num_online_cores = sysconf(_SC_NPROCESSORS_ONLN); cpu_set_t cpuset; CPU_ZERO(&cpuset); - CPU_SET(managed_device_id % num_online_cores, &cpuset); + CPU_SET(cpu_core_for_worker, &cpuset); int rc = pthread_setaffinity_np(thread_.native_handle(), sizeof(cpu_set_t), &cpuset); if (rc) { log_warning( @@ -80,7 +79,7 @@ class WorkExecutor { public: LockFreeQueue> worker_queue; - WorkExecutor(int device_id) : managed_device_id(device_id) { + WorkExecutor(int cpu_core, int device_id) : cpu_core_for_worker(cpu_core), managed_device_id(device_id) { set_process_priority(0); if (this->work_executor_mode == WorkExecutorMode::ASYNCHRONOUS) { this->set_worker_queue_mode(this->worker_queue_mode); @@ -89,14 +88,16 @@ class WorkExecutor { } WorkExecutor(WorkExecutor&& other) { - worker_state = other.worker_state; - managed_device_id = other.managed_device_id; + worker_state = std::move(other.worker_state); + cpu_core_for_worker = std::move(other.managed_device_id); + managed_device_id = std::move(other.managed_device_id); } WorkExecutor& operator=(WorkExecutor &&other) { if (this != &other) { worker_state = std::move(other.worker_state); managed_device_id = std::move(other.managed_device_id); + cpu_core_for_worker = std::move(other.cpu_core_for_worker); } return *this; } @@ -218,6 +219,7 @@ class WorkExecutor { private: std::thread worker_thread; WorkerState worker_state = WorkerState::IDLE; + int cpu_core_for_worker = 0; int managed_device_id = 0; std::condition_variable cv; std::mutex cv_mutex; @@ -228,7 +230,7 @@ class WorkExecutor { this->worker_thread = std::thread(&WorkExecutor::run_worker, this); this->worker_queue.worker_thread_id = std::hash{}(this->worker_thread.get_id()); // Bind a worker tied to a device to a specific CPU core in round robin fashion. Thread affinity == Better Perf. - set_device_thread_affinity(this->worker_thread, this->managed_device_id); + set_device_thread_affinity(this->worker_thread, this->cpu_core_for_worker); } inline void stop_worker() { diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 2038c3b4baea..665de904b462 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -4,6 +4,7 @@ #include "tt_metal/detail/tt_metal.hpp" +#include #include #include #include @@ -171,6 +172,78 @@ std::vector devices; } // namespace device_pool +namespace device_cpu_allocator { +std::unordered_map> get_cpu_cores_per_numa_node(std::unordered_set &free_cores) { + std::unordered_map> cpu_cores_per_numa_node = {}; + if (numa_available() != -1) { + // Host has NUMA enabled. Group CPU IDs by the NUMA nodes they belong to. + for (int cpu = 0; cpu < numa_num_configured_cpus(); ++cpu) { + int node = numa_node_of_cpu(cpu); + if (cpu_cores_per_numa_node.find(node) == cpu_cores_per_numa_node.end()) { + cpu_cores_per_numa_node.insert({node, {}}); + } + free_cores.insert(cpu); + cpu_cores_per_numa_node.at(node).push_back(cpu); + } + } else { + // Host does not have NUMA. Place all CPU Ids under a single node (0). + log_warning(tt::LogMetal, "Host does not use NUMA. 
May see reduced performance."); + for (int cpu = 0; cpu < sysconf(_SC_NPROCESSORS_ONLN); ++cpu) { + free_cores.insert(cpu); + } + } + return cpu_cores_per_numa_node; +} + +int get_cpu_core_for_device_worker_thread( + int mmio_controlled_device_id, + const std::unordered_map> &cpu_cores_per_numa_node, + std::unordered_set &free_cores) { + int core_assigned_to_device = 0; + if (numa_available() != -1) { + // Get NUMA node that the current device is mapped to through UMD + int numa_node_for_device = tt::Cluster::instance().get_numa_node_for_device(mmio_controlled_device_id); + if (cpu_cores_per_numa_node.find(numa_node_for_device) != cpu_cores_per_numa_node.end()) { + // NUMA node reported by UMD exists on host. Choose a core on this numa-node using round robin policy + int num_cores_in_numa_node = cpu_cores_per_numa_node.at(numa_node_for_device).size(); + core_assigned_to_device = + cpu_cores_per_numa_node.at(numa_node_for_device).at(mmio_controlled_device_id % num_cores_in_numa_node); + } else { + // NUMA node reported by UMD does not exist on host. Use round-robin binding policy for this worker thread. + log_warning( + tt::LogMetal, + "NUMA node {} for device {} does not exist on host.", + numa_node_for_device, + mmio_controlled_device_id); + core_assigned_to_device = mmio_controlled_device_id % sysconf(_SC_NPROCESSORS_ONLN); + } + } else { + // System does not use NUMA. Use-round robin binding strategy. + core_assigned_to_device = mmio_controlled_device_id % sysconf(_SC_NPROCESSORS_ONLN); + } + free_cores.erase(core_assigned_to_device); + return core_assigned_to_device; +} + +void bind_current_thread_to_free_cores(const std::unordered_set &free_cores) { + cpu_set_t cpuset; + pthread_t current_thread = pthread_self(); + CPU_ZERO(&cpuset); + + for (const auto &free_core : free_cores) { + CPU_SET(free_core, &cpuset); + } + int rc = pthread_setaffinity_np(current_thread, sizeof(cpu_set_t), &cpuset); + if (rc) { + log_warning( + tt::LogMetal, + "Unable to bind main thread to free CPU cores. May see performance degradation. 
Error Code: {}", + rc); + } +} + +} // namespace device_cpu_allocator + namespace detail { std::map CreateDevices( @@ -185,10 +258,14 @@ std::map CreateDevices( if (active_devices.find(mmio_device_id) == active_devices.end()) { for (const auto &mmio_controlled_device_id : tt::Cluster::instance().get_devices_controlled_by_mmio_device(mmio_device_id)) { - // if (mmio_controlled_device_id != mmio_device_id) { - // continue; - // } - Device *dev = new Device(mmio_controlled_device_id, num_hw_cqs, l1_small_size, l1_bank_remap); + int core_assigned_to_device = mmio_controlled_device_id % sysconf(_SC_NPROCESSORS_ONLN); + Device *dev = new Device( + mmio_controlled_device_id, + num_hw_cqs, + l1_small_size, + l1_bank_remap, + false, + core_assigned_to_device); active_devices.insert({mmio_controlled_device_id, dev}); detail::InitDeviceProfiler(dev); } @@ -666,12 +743,10 @@ void CompileProgram(Device *device, Program &program) { } void AllocateBuffer(Buffer *buffer, bool bottom_up) { - detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); EnqueueAllocateBuffer(buffer->device()->command_queue(), buffer, bottom_up, false); } void DeallocateBuffer(Buffer *buffer) { - detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); EnqueueDeallocateBuffer( buffer->device()->command_queue(), *(buffer->device()->allocator_), @@ -681,7 +756,6 @@ void DeallocateBuffer(Buffer *buffer) { } void GetBufferAddress(const Buffer *buffer, uint32_t *address_on_host) { - detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); EnqueueGetBufferAddr(buffer->device()->command_queue(), address_on_host, buffer, false); } @@ -720,7 +794,8 @@ Device *CreateDevice( const size_t l1_small_size, const std::vector &l1_bank_remap) { ZoneScoped; - Device *dev = new Device(device_id, num_hw_cqs, l1_small_size, l1_bank_remap); + int core_assigned_to_device = device_id % sysconf(_SC_NPROCESSORS_ONLN); + Device *dev = new Device(device_id, num_hw_cqs, l1_small_size, l1_bank_remap, false, core_assigned_to_device); tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true); detail::InitDeviceProfiler(dev); return dev; diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp index 5569bd65ab4e..243b6ef4808a 100644 --- a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp +++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp @@ -296,10 +296,10 @@ const operation::Hash Binary::compute_program_hash(const std::vector& in typeid(*this).hash_code(), this->program_config, program_type, - input_tensor_a.get_dtype(), - input_tensor_a.memory_config(), - input_tensor_b.get_dtype(), - input_tensor_b.memory_config()); + input_tensor_a.dtype(), + std::get(input_tensor_a.storage()).memory_config(), + input_tensor_b.dtype(), + std::get(input_tensor_b.storage()).memory_config()); return hash; } From 9c6bf9f9a1021be63d92155025ca990862a9d7bf Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Tue, 4 Jun 2024 23:53:17 +0000 Subject: [PATCH 36/53] #0: Relax input data type constraints `ssm_eltwise_mul` --- .../tt_dnn/op_library/transformer_tms/transformer_tms.cpp | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp index 0af4c11bf4b2..a742be885e29 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp @@ -544,11 +544,8 @@ void 
SSMEltwiseMul::validate(const std::vector& input_tensors) const { "Unsupported data format for input a!"); TT_FATAL( input_tensor_b.get_dtype() == tt::tt_metal::DataType::BFLOAT16 || - input_tensor_a.get_dtype() == tt::tt_metal::DataType::BFLOAT8_B, + input_tensor_b.get_dtype() == tt::tt_metal::DataType::BFLOAT8_B, "Unsupported data format for input b!"); - TT_FATAL( - input_tensor_a.get_dtype() == input_tensor_b.get_dtype(), - "Input a and input b must have the same data format!"); TT_FATAL( this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, From 550c6905f16abedb88f7b7344affbd7e6852f4c3 Mon Sep 17 00:00:00 2001 From: Evan Smal Date: Thu, 23 May 2024 20:12:59 +0000 Subject: [PATCH 37/53] #0: Add support for bfloat8 activations in Mamba --- models/demos/mamba/tests/test_full_model.py | 2 +- models/demos/mamba/tt/full_model.py | 18 ++++++++-- models/demos/mamba/tt/mamba_block.py | 37 +++++++++++++++------ models/demos/mamba/tt/mamba_one_step_ssm.py | 28 +++++++++++++--- models/demos/mamba/tt/model_config.py | 1 + models/demos/mamba/tt/residual_block.py | 2 +- 6 files changed, 68 insertions(+), 20 deletions(-) diff --git a/models/demos/mamba/tests/test_full_model.py b/models/demos/mamba/tests/test_full_model.py index 6790dd186524..c0a5fac3c6c9 100644 --- a/models/demos/mamba/tests/test_full_model.py +++ b/models/demos/mamba/tests/test_full_model.py @@ -87,7 +87,7 @@ def run_inference( ( "state-spaces/mamba-2.8b", 32, - 0.984, + 0.98, 64, 1, ), diff --git a/models/demos/mamba/tt/full_model.py b/models/demos/mamba/tt/full_model.py index a06ad6b9f800..509eb6ff6d32 100644 --- a/models/demos/mamba/tt/full_model.py +++ b/models/demos/mamba/tt/full_model.py @@ -4,6 +4,7 @@ import torch import ttnn +import tt_lib as ttl from loguru import logger @@ -63,7 +64,12 @@ def load_tt_tensor( class MambaTT(torch.nn.Module): def __init__( - self, reference_model, device: ttnn.Device, configs, tt_cache_path: Optional[str] = None, num_layers=None + self, + reference_model, + device: ttnn.Device, + configs, + tt_cache_path: Optional[str] = None, + num_layers=None, ): super().__init__() self.args = reference_model.args @@ -95,6 +101,11 @@ def __init__( lambda x: x.transpose(-1, -2), tt_dtype=ttnn.bfloat16, ) + self.compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( + math_fidelity=ttl.tensor.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + ) def forward(self, x): assert len(x.shape) == 2, f"Mamba expects inputs to be rank 2 (was {len(x.shape)})" @@ -109,7 +120,7 @@ def forward(self, x): device=self.device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG, - dtype=ttnn.bfloat16, + dtype=self.configs["dtype"]["activations"], ) for layer in self.layers: @@ -129,7 +140,8 @@ def forward(self, x): self.lm_head_weights, memory_config=ttnn.L1_MEMORY_CONFIG, use_1d_systolic_array=True, - core_grid=ttnn.CoreGrid(y=7, x=8), + compute_kernel_config=self.compute_kernel_config, + dtype=self.configs["dtype"]["activations"], ) x = ttnn.to_torch(x).to(torch.float32) # (1, 1, B, E) diff --git a/models/demos/mamba/tt/mamba_block.py b/models/demos/mamba/tt/mamba_block.py index d5fe4adffde9..c2fd778f8ea9 100644 --- a/models/demos/mamba/tt/mamba_block.py +++ b/models/demos/mamba/tt/mamba_block.py @@ -82,8 +82,6 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable): math_approx_mode=False, fp32_dest_acc_en=True, ) - self.core_grid_row = 4 - self.core_grid_col = 8 def forward(self, x): assert len(x.shape) == 4, "Mamba block expects inputs to be 
rank 4" @@ -96,7 +94,7 @@ def forward(self, x): memory_config=ttnn.L1_MEMORY_CONFIG, compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, - core_grid=ttnn.CoreGrid(y=4, x=8), + dtype=self.configs["dtype"]["activations"], ) # shift the states leftward @@ -111,24 +109,38 @@ def forward(self, x): # do the convolution conv1d_wt = ttnn.to_memory_config(self.conv1d_weights[0], memory_config=self.configs["sharded_d"]) conv_state = ttnn.to_memory_config(self.conv_states[0], memory_config=self.configs["sharded_d"]) - conv_accumulator = ttnn.mul(conv_state, conv1d_wt, memory_config=self.configs["sharded_d"]) + conv_accumulator = ttnn.mul( + conv_state, conv1d_wt, memory_config=self.configs["sharded_d"], dtype=self.configs["dtype"]["activations"] + ) ttnn.deallocate(conv1d_wt) ttnn.deallocate(conv_state) for i in range(1, 4): conv1d_wt = ttnn.to_memory_config(self.conv1d_weights[i], memory_config=self.configs["sharded_d"]) conv_state = ttnn.to_memory_config(self.conv_states[i], memory_config=self.configs["sharded_d"]) - prod = ttnn.mul(conv_state, conv1d_wt, memory_config=self.configs["sharded_d"]) + prod = ttnn.mul( + conv_state, + conv1d_wt, + memory_config=self.configs["sharded_d"], + dtype=self.configs["dtype"]["activations"], + ) ttnn.deallocate(conv1d_wt) ttnn.deallocate(conv_state) - conv_out = ttnn.add(conv_accumulator, prod, memory_config=self.configs["sharded_d"]) + conv_out = ttnn.add( + conv_accumulator, + prod, + memory_config=self.configs["sharded_d"], + dtype=self.configs["dtype"]["activations"], + ) ttnn.deallocate(conv_accumulator) ttnn.deallocate(prod) conv_accumulator = conv_out conv1d_bias = ttnn.to_memory_config(self.conv1d_bias, memory_config=self.configs["sharded_d"]) - conv_out_with_bias = ttnn.add(conv_out, conv1d_bias, memory_config=self.configs["sharded_d"]) + conv_out_with_bias = ttnn.add( + conv_out, conv1d_bias, memory_config=self.configs["sharded_d"], dtype=self.configs["dtype"]["activations"] + ) ttnn.deallocate(conv_out) ttnn.deallocate(conv1d_bias) @@ -142,16 +154,21 @@ def forward(self, x): residual_connection, self.mlp_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG, - core_grid=ttnn.CoreGrid(y=4, x=8), compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, + dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(residual_connection) residual_with_silu = ttnn.silu(residual, memory_config=ttnn.L1_MEMORY_CONFIG) ttnn.deallocate(residual) - out = ttnn.mul(ssm_output, residual_with_silu, memory_config=ttnn.L1_MEMORY_CONFIG) + out = ttnn.mul( + ssm_output, + residual_with_silu, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=self.configs["dtype"]["activations"], + ) ttnn.deallocate(residual_with_silu) ttnn.deallocate(ssm_output) @@ -159,9 +176,9 @@ def forward(self, x): out, self.out_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG, - core_grid=ttnn.CoreGrid(y=4, x=8), compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, + dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(out) diff --git a/models/demos/mamba/tt/mamba_one_step_ssm.py b/models/demos/mamba/tt/mamba_one_step_ssm.py index f5d07996c783..833af10c2696 100644 --- a/models/demos/mamba/tt/mamba_one_step_ssm.py +++ b/models/demos/mamba/tt/mamba_one_step_ssm.py @@ -113,6 +113,7 @@ def forward(self, x): compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col), + dtype=self.configs["dtype"]["activations"], ) delta_t1 = ttnn.linear( @@ 
-123,10 +124,16 @@ def forward(self, x): compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col), + dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(delta_t0) - delta_t2 = ttnn.softplus(delta_t1, beta=1.0, threshold=20.0, memory_config=ttnn.L1_MEMORY_CONFIG) + delta_t2 = ttnn.softplus( + delta_t1, + beta=1.0, + threshold=20.0, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) ttnn.deallocate(delta_t1) # calculate abar @@ -137,6 +144,7 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(abar0) @@ -151,7 +159,9 @@ def forward(self, x): # multiply abar and hidden_state hidden_state0 = ttnn.to_memory_config(self.tt_hidden_state, memory_config=ttnn.L1_MEMORY_CONFIG) - amulh0 = ttnn.mul(abar2, hidden_state0, memory_config=ttnn.L1_MEMORY_CONFIG) + amulh0 = ttnn.mul( + abar2, hidden_state0, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.configs["dtype"]["activations"] + ) ttnn.deallocate(abar2) ttnn.deallocate(hidden_state0) @@ -163,6 +173,7 @@ def forward(self, x): compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col), + dtype=self.configs["dtype"]["activations"], ) # bbar @@ -172,6 +183,7 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(delta_t2) ttnn.deallocate(B0) @@ -183,13 +195,16 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) # deallocate bbar ttnn.deallocate(bbar0) # add amulh and bmulx - hidden_state1 = ttnn.add(amulh0, bmulx0, memory_config=ttnn.L1_MEMORY_CONFIG) + hidden_state1 = ttnn.add( + amulh0, bmulx0, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.configs["dtype"]["activations"] + ) ttnn.deallocate(self.tt_hidden_state) self.tt_hidden_state = ttnn.to_memory_config(hidden_state1, memory_config=ttnn.DRAM_MEMORY_CONFIG) ttnn.deallocate(amulh0) @@ -203,6 +218,7 @@ def forward(self, x): compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col), + dtype=self.configs["dtype"]["activations"], ) # b,n # c * hidden_state @@ -212,6 +228,7 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(hidden_state1) ttnn.deallocate(C0) @@ -222,16 +239,17 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(C1) # x * D D = ttnn.to_memory_config(self.D, memory_config=ttnn.L1_MEMORY_CONFIG) - xD = ttnn.mul(x, D, memory_config=ttnn.L1_MEMORY_CONFIG) + xD = ttnn.mul(x, D, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.configs["dtype"]["activations"]) ttnn.deallocate(x) # add xD and x - output = ttnn.add(xD, C2, memory_config=ttnn.L1_MEMORY_CONFIG) + output = ttnn.add(xD, C2, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.configs["dtype"]["activations"]) 
ttnn.deallocate(C2) ttnn.deallocate(xD) diff --git a/models/demos/mamba/tt/model_config.py b/models/demos/mamba/tt/model_config.py index ac6e30a9c50a..3823034d3131 100644 --- a/models/demos/mamba/tt/model_config.py +++ b/models/demos/mamba/tt/model_config.py @@ -34,6 +34,7 @@ def create_model_config(batch_size, hidden_size): block_w=(hidden_size // (col * row)) // 32, inplace=False, ) + configs["dtype"] = {"activations": ttnn.bfloat8_b} return configs diff --git a/models/demos/mamba/tt/residual_block.py b/models/demos/mamba/tt/residual_block.py index dbe3ff1236a4..ff80dc199ef8 100644 --- a/models/demos/mamba/tt/residual_block.py +++ b/models/demos/mamba/tt/residual_block.py @@ -42,4 +42,4 @@ def forward(self, x): ttnn.deallocate(rms_norm_weights) mamba_x = self.tt_mamba_block(mamba_x) - return ttnn.add(residual, mamba_x) + return ttnn.add(residual, mamba_x, dtype=self.configs["dtype"]["activations"]) From 47227830f3367b14d89e5b3442d14f73533b333d Mon Sep 17 00:00:00 2001 From: hschoi Date: Tue, 4 Jun 2024 22:22:34 +0000 Subject: [PATCH 38/53] #9118: fix moreh_getitem validation --- .../unit_testing/misc/test_moreh_getitem.py | 7 ------- .../tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp | 4 +++- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py index 345dc51fe2bb..989c0430d544 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_getitem.py @@ -20,7 +20,6 @@ def to_output_4d_shape(shape, index_dims, index_size): return output_4d_shape -@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dim", ( @@ -80,7 +79,6 @@ def test_getitem_RAW_MJOR_one_index(shape_index_dim, dtype, index_size, device): assert passing -@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dims", ( @@ -139,7 +137,6 @@ def test_getitem_RAW_MAJOR_two_indices(shape_index_dims, dtype, index_size, devi assert passing -@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dims", (((10, 15, 7, 80), (0, 1, 2)),), @@ -193,7 +190,6 @@ def test_getitem_RAW_MAJOR_three_indices(shape_index_dims, dtype, index_size, de assert passing -@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dim", ( @@ -286,7 +282,6 @@ def test_getitem_tilized_one_index(shape_index_dim, dtype, index_size, row_major assert passing -@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dims", ( @@ -372,7 +367,6 @@ def test_getitem_tilized_two_indices(shape_index_dims, dtype, index_size, row_ma assert passing -@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dims", ( @@ -455,7 +449,6 @@ def test_getitem_tilized_three_indices(shape_index_dims, dtype, index_size, row_ assert passing -@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape_index_dims", (((10, 15, 7, 80), (0, 1, 2, 3)),), diff --git a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp index 
4aedc2407ce9..aa203229ec66 100644 --- a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp @@ -211,7 +211,9 @@ Tensor moreh_getitem( optional_output_tensors); }, new_input_tensors, - output_tensors); + output_tensors, + {}, + {output_tensor}); return output_tensors.at(0); } From 1aec13dfcf674d0fc85a26350644d5e2613a719c Mon Sep 17 00:00:00 2001 From: hschoi Date: Tue, 4 Jun 2024 22:23:24 +0000 Subject: [PATCH 39/53] #9118: fix moreh_nllloss validation --- .../unit_testing/misc/test_moreh_nll_loss.py | 2 -- .../moreh_nll_loss_backward_op.cpp | 14 ++++---------- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py index af6d27c8e717..7bd8b21160e7 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_nll_loss.py @@ -207,7 +207,6 @@ def test_moreh_nll_loss_callback(shape, reduction, none_weight, device, use_prog assert passing -@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape", [ @@ -291,7 +290,6 @@ def test_moreh_nll_loss_backward( assert passing -@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @pytest.mark.parametrize( "shape", [ diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp index 9fffeb7de044..32c79199bac8 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp @@ -24,25 +24,19 @@ void MorehNllLossBackward::validate_with_output_tensors( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& output_tensors) const { - TT_ASSERT(input_tensors.size() == 3, "Must have 3 input tensors"); + TT_ASSERT(input_tensors.size() == 2, "Must have 2 input tensors"); TT_ASSERT(optional_input_tensors.size() == 2, "Must have 2 optional input tensors"); - auto& input_tensor = input_tensors.at(0); - auto& target_tensor = input_tensors.at(1); - auto& output_grad_tensor = input_tensors.at(2); + auto& target_tensor = input_tensors.at(0); + auto& output_grad_tensor = input_tensors.at(1); auto& weight_tensor = optional_input_tensors.at(0); auto& divisor_tensor = optional_input_tensors.at(1); auto& input_grad_tensor = output_tensors.at(0); - TT_ASSERT(input_tensor.storage_type() == StorageType::DEVICE, "Operands to nll_loss need to be on device!"); - TT_ASSERT(input_tensor.buffer() != nullptr, "Operands to nll_loss need to be allocated in buffers on device!"); - TT_ASSERT((input_tensor.get_layout() == Layout::TILE), "intput_tensor to nll_loss must be tilized"); - TT_ASSERT(input_tensor.get_dtype() == DataType::BFLOAT16); - TT_ASSERT(target_tensor.storage_type() == StorageType::DEVICE, "Operands to nll_loss need to be on device!"); TT_ASSERT(target_tensor.buffer() != nullptr, "Operands to nll_loss need to be allocated in buffers on device!"); TT_ASSERT((target_tensor.get_layout() == Layout::TILE), "target_tensor to nll_loss must be tilized"); - TT_ASSERT(target_tensor.get_dtype() == DataType::UINT32); + TT_ASSERT(target_tensor.get_dtype() == DataType::INT32); 
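For clarity on why the forward activations can be dropped from the backward op's required inputs: with log-probability inputs, the NLL-loss gradient depends only on the target indices (now expected as INT32), the optional class weights, the divisor and the upstream output gradient. Below is a small host-side reference of that gradient, given as a sketch; the [N, C] layout, the names and the omission of ignore_index handling are simplifications for illustration, not the kernel's actual interface.

#include <cstdint>
#include <vector>

// Reference NLL-loss backward for log-probability inputs of shape [N, C].
// d(loss)/d(input[n][c]) = -weight[target[n]] * output_grad / divisor when c == target[n], else 0.
// The forward input never appears, which is why the op only needs target and output_grad tensors.
std::vector<float> nll_loss_backward_ref(
    const std::vector<int32_t>& target,   // N class indices (int32, matching the new validation)
    const std::vector<float>& weight,     // C per-class weights (all 1.0f when no weight is given)
    float output_grad,                    // upstream gradient of the reduced loss
    float divisor,                        // sum of selected weights for mean reduction (1.0f for sum)
    int64_t N, int64_t C) {
    std::vector<float> input_grad(N * C, 0.0f);
    for (int64_t n = 0; n < N; ++n) {
        int32_t t = target[n];
        input_grad[n * C + t] = -weight[t] * output_grad / divisor;
    }
    return input_grad;
}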
TT_ASSERT(output_grad_tensor.storage_type() == StorageType::DEVICE, "Operands to nll_loss need to be on device!"); TT_ASSERT( From b679434ce8cd33afcc2b035bac162f093b8aba08 Mon Sep 17 00:00:00 2001 From: Mohamed Bahnas <116673264+mbahnasTT@users.noreply.github.com> Date: Tue, 4 Jun 2024 18:41:43 -0700 Subject: [PATCH 40/53] Update ViT E2E number in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index bcee552db2bb..ca1b108b91e0 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ | [ResNet-50](./models/demos/resnet) (fps) | 20 | 4,400 | 7,700 | 10,000 | | [BERT-Large](./models/demos/bert) (sen/s) | 12 | 362 | 406 | 410 | | [Falcon7B-decode](./models/demos/ttnn_falcon7b) (t/s) | 32 | 135 | 135 | 140 | -| [ViT](./models/demos/grayskull/vit) (fps) | 8 | 480 | 1570 | 2000 | +| [ViT](./models/demos/grayskull/vit) (fps) | 8 | 860 | 1570 | 2000 | | [T5 small](.models/demos/grayskull/t5) (sen/s) | | 140 | | | | [Bloom](.models/demos/grayskull/functional_bloom) (sen/s) | | 70 | | | | U-Net | coming soon | | | | From 3aff6a4db297195550dafc0ad5ff1c4f6c8338e9 Mon Sep 17 00:00:00 2001 From: Radomir Djogo Date: Tue, 4 Jun 2024 23:57:30 +0000 Subject: [PATCH 41/53] #4858: enable typecast fp32 to uint16 --- .../sweep_tests/pytorch_ops.py | 9 +++++++-- .../sweep_tests/tt_lib_ops.py | 3 ++- .../eltwise_unary/eltwise_unary_op.cpp | 8 +++++--- .../eltwise_unary/eltwise_unary_op.hpp | 9 +++++---- .../csrc/tt_lib_bindings_tensor_xary_ops.cpp | 8 +++++++- .../llk_api/llk_sfpu/ckernel_sfpu_typecast.h | 16 ++++++++++++++++ .../llk_math_eltwise_unary_sfpu_typecast.h | 18 +++++++++++++----- .../eltwise_unary/typecast.h | 9 +++++++-- 8 files changed, 62 insertions(+), 18 deletions(-) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index 8a588493e48a..6a8047855130 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -1331,8 +1331,13 @@ def eltwise_identity(x, *args, **kwargs): return x -def eltwise_typecast(x, *args, **kwargs): - return torch.relu(x.to(torch.int32)) # due to no uint32 support +def eltwise_typecast(x, *args, tt_output_dtype, **kwargs): + if tt_output_dtype[0] == ttl.tensor.DataType.UINT16: + return torch.clamp(x.to(torch.int32), min=0, max=65535) # due to no uint16 support + elif tt_output_dtype[0] == ttl.tensor.DataType.UINT32: + return torch.relu(x.to(torch.int32)) # due to no uint32 support + else: + return x def eltwise_rdiv(x, *args, **kwargs): diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index d7c116b794b9..e22d65583297 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -2192,13 +2192,14 @@ def eltwise_typecast( *args, device, dtype, + tt_output_dtype, layout, input_mem_config, output_mem_config, **kwargs, ): t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) - t1 = ttl.tensor.eltwise_typecast(t0, output_mem_config=output_mem_config) + t1 = ttl.tensor.eltwise_typecast(t0, tt_output_dtype[0], output_mem_config=output_mem_config) return tt2torch_tensor(t1) diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index d958fc0c1f0b..73b14e2b1127 100644 --- 
a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -177,6 +177,11 @@ std::pair get_op_init_and_func_parameterized( Converter::to_hex(param1))}; break; } + case UnaryOpType::TYPECAST: + op_init_and_name = { + "typecast_tile_init();", + fmt::format("typecast_tile<{1}u>({0});", idst, std::to_string((uint32_t)datatype_to_dataformat_converter((DataType)param0)))}; + break; default: TT_ASSERT(false && "unexpected parameterized type"); }; return op_init_and_name; @@ -258,9 +263,6 @@ std::pair get_op_init_and_func_default(UnaryOpType op_type, stri case UnaryOpType::NEG: op_init_and_name = {"negative_tile_init();", fmt::format("negative_tile({});", idst)}; break; - case UnaryOpType::TYPECAST: - op_init_and_name = {"typecast_tile_init();", fmt::format("typecast_tile({});", idst)}; - break; default: TT_ASSERT(false && "Undefined non-parametrized op type"); } return op_init_and_name; diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index f9f8a2521c07..2a26f0f4c5c3 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -104,7 +104,8 @@ bool is_parametrized_type(T val) { case UnaryOpType::DIV_UNARY_SFPU: case UnaryOpType::UNARY_NE: case UnaryOpType::UNARY_GT: - case UnaryOpType::UNARY_LT: return true; + case UnaryOpType::UNARY_LT: + case UnaryOpType::TYPECAST: return true; default: return false; } return false; @@ -195,7 +196,7 @@ inline Tensor run_eltwise_unary( const std::vector& ops_chain, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified"); - DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? DataType::UINT32 : input_tensor.get_dtype(); + DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? (DataType)ops_chain[0].params[0] : input_tensor.get_dtype(); bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_tensor.get_dtype() == DataType::UINT32 or @@ -241,7 +242,7 @@ inline Tensor run_eltwise_unary( const std::vector& ops_chain, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified"); - DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? DataType::UINT32 : input_tensor.get_dtype(); + DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? 
(DataType)ops_chain[0].params[0] : input_tensor.get_dtype(); bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_tensor.get_dtype() == DataType::UINT32 or @@ -369,7 +370,7 @@ constexpr auto rsub = make_eltwise_unary_with_param{}; constexpr auto silu = make_eltwise_unary{}; constexpr auto identity = make_eltwise_unary{}; constexpr auto identity_uint32 = make_eltwise_unary{}; -constexpr auto eltwise_typecast = make_eltwise_unary{}; +constexpr auto eltwise_typecast = make_eltwise_unary_with_param{}; constexpr auto add_unary_sfpu = make_eltwise_symmetric_binop_unary_with_param{}; constexpr auto mul_unary_sfpu = make_eltwise_symmetric_binop_unary_with_param{}; constexpr auto unary_gt = make_eltwise_unary_with_param{}; diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp index 6b9b80896470..1ffffc67ea3d 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp @@ -85,7 +85,13 @@ namespace tt::tt_metal::detail { detail::bind_unary_op(m_tensor, "i0", i0, R"doc(Computes the zeroth order modified Bessel function of the first kind applied on the elements of the input tensor ``{0}``, for the input range -10 to 10.)doc"); detail::bind_unary_op(m_tensor, "silu", silu, R"doc(Returns tensor with the silu all of elements of the input tensor ``{0}``.)doc"); detail::bind_unary_op(m_tensor, "neg", neg, R"doc(Returns tensor with the negate all of elements of the input tensor ``{0}``.)doc"); - detail::bind_unary_op(m_tensor, "eltwise_typecast", eltwise_typecast, R"doc(Returns tensor with all of the elements of the input tensor ``{0}`` typecasted from fp32 to uint32.)doc"); + + detail::bind_unary_op_with_param( + m_tensor, "eltwise_typecast", eltwise_typecast, + py::arg("tt_output_dtype"), + R"doc(Returns tensor with all of the elements of the input tensor ``{0}`` typecasted from fp32 to uint32.)doc", + R"doc("Indicates output dtype of typecast", "ttl.tensor.DataType", "")doc" + ); detail::bind_unary_op_with_param( m_tensor, "exp", py::overload_cast(&exp), diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h index b3fdd91a568b..0d2a43ff7ace 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h @@ -51,5 +51,21 @@ inline void calculate_typecast_fp16b_to_uint32() } } +template +inline void calculate_typecast_fp16b_to_uint16() +{ + #pragma GCC unroll 0 + for (int d = 0; d < ITERATIONS; d++) { + TTI_SFPENCC(0,0,0,0); + TTI_SFPLOAD(0,0,3,0); + TTI_SFPSETCC(0,0,0,0); + TTI_SFPLOADI(0,0,0); + TTI_SFPENCC(0,0,0,0); + TTI_SFP_STOCH_RND(0,0,2,0,1,14); + TTI_SFPSTORE(1,6,3,0); + dst_reg++; + } +} + } // namespace sfpu } // namespace ckernel diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h index b4ac44225b69..b5a9a6bf0c3d 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h @@ -12,12 +12,20 @@ namespace ckernel { // New LLK SFPU APIs -template +template inline void llk_math_eltwise_unary_sfpu_typecast(uint 
dst_index, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_unary_sfpu_0_param - (ckernel::sfpu::calculate_typecast_fp16b_to_uint32, - ckernel::sfpu::calculate_typecast_fp16b_to_uint32, - dst_index, vector_mode); + if constexpr (OUT_DTYPE == (uint32_t)DataFormat::UInt16) { + llk_math_eltwise_unary_sfpu_0_param + (ckernel::sfpu::calculate_typecast_fp16b_to_uint16, + ckernel::sfpu::calculate_typecast_fp16b_to_uint16, + dst_index, vector_mode); + } + else if constexpr (OUT_DTYPE == (uint32_t)DataFormat::UInt32) { + llk_math_eltwise_unary_sfpu_0_param + (ckernel::sfpu::calculate_typecast_fp16b_to_uint32, + ckernel::sfpu::calculate_typecast_fp16b_to_uint32, + dst_index, vector_mode); + } } template diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h b/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h index e29d02434594..22ebaba89e54 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h @@ -20,16 +20,21 @@ namespace ckernel { /** * Performs an elementwise typecast operation on the input. - * Supports typecast from fp32 to uint32. + * Supports following typecasts: + * fp32/fp16b -> uint32 + * fp32/fp16b -> uint16 + * For output to be uint32, Dest must be in 32 bit mode. * * Return value: None * * | Argument | Description | Type | Valid Range | Required | * |----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| * | tile_index | The index of the tile in DST register buffer to perform typecast operation | uint32_t | Must be less than the size of the DST register buffer | True | + * | OUT_DTYPE | Desired output data format | uint32_t | Must be valid tt::DataFormat | True | */ +template ALWI void typecast_tile(uint32_t idst) { - MATH(( llk_math_eltwise_unary_sfpu_typecast(idst) )); + MATH(( llk_math_eltwise_unary_sfpu_typecast(idst) )); } /** From 337800f02b1dca39ec130c79c2e85fd980b307b4 Mon Sep 17 00:00:00 2001 From: Radomir Djogo Date: Wed, 5 Jun 2024 00:27:25 +0000 Subject: [PATCH 42/53] #4858: update typecast description to include uint16 --- tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp | 2 +- ttnn/cpp/ttnn/operations/unary.hpp | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp index 1ffffc67ea3d..c4693d83a559 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp @@ -89,7 +89,7 @@ namespace tt::tt_metal::detail { detail::bind_unary_op_with_param( m_tensor, "eltwise_typecast", eltwise_typecast, py::arg("tt_output_dtype"), - R"doc(Returns tensor with all of the elements of the input tensor ``{0}`` typecasted from fp32 to uint32.)doc", + R"doc(Returns tensor with all of the elements of the input tensor ``{0}`` typecasted from fp32 to uint32 or uint16.)doc", R"doc("Indicates output dtype of typecast", "ttl.tensor.DataType", "")doc" ); diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp index 2ab4686b5f48..88e41cf5766f 100644 --- a/ttnn/cpp/ttnn/operations/unary.hpp +++ b/ttnn/cpp/ttnn/operations/unary.hpp @@ -42,8 +42,9 @@ inline Tensor execute_on_worker_thread( const Tensor& input_tensor, const std::vector& op_chain, const std::optional& memory_config = std::nullopt) { - DataType output_dtype = 
(op_chain[0].op_type == UnaryOpType::TYPECAST) ? DataType::UINT32 : input_tensor.get_dtype(); - bool fp32_dest_acc_en = input_tensor.get_dtype() == DataType::UINT32 or + DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? (DataType)op_chain[0].params[0] : input_tensor.get_dtype(); + bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or + input_tensor.get_dtype() == DataType::UINT32 or input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to // DST directly, fp32 is converted to fp16b return operation::run( From c236d94b61e1dfe7dcebc14bcbaea6b8e070bf34 Mon Sep 17 00:00:00 2001 From: Radomir Djogo Date: Wed, 5 Jun 2024 00:54:15 +0000 Subject: [PATCH 43/53] #4858: use static_cast and update llk_math_eltwise_unary_sfpu_params --- .../op_library/eltwise_unary/eltwise_unary_op.hpp | 4 ++-- .../llk_math_eltwise_unary_sfpu_typecast.h | 14 +++++++------- ttnn/cpp/ttnn/operations/unary.hpp | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index 2a26f0f4c5c3..6dece1630529 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -196,7 +196,7 @@ inline Tensor run_eltwise_unary( const std::vector& ops_chain, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified"); - DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? (DataType)ops_chain[0].params[0] : input_tensor.get_dtype(); + DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast(ops_chain[0].params[0]) : input_tensor.get_dtype(); bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_tensor.get_dtype() == DataType::UINT32 or @@ -242,7 +242,7 @@ inline Tensor run_eltwise_unary( const std::vector& ops_chain, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified"); - DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? (DataType)ops_chain[0].params[0] : input_tensor.get_dtype(); + DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? 
static_cast(ops_chain[0].params[0]) : input_tensor.get_dtype(); bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_tensor.get_dtype() == DataType::UINT32 or diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h index b5a9a6bf0c3d..8a7f9d95a531 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h @@ -5,7 +5,7 @@ #pragma once #include "llk_math_eltwise_unary_sfpu_init.h" -#include "llk_math_eltwise_unary_sfpu_0_param.h" +#include "llk_math_eltwise_unary_sfpu_params.h" #include "ckernel_sfpu_typecast.h" namespace ckernel { @@ -15,16 +15,16 @@ namespace ckernel { template inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode = (int)VectorMode::RC) { if constexpr (OUT_DTYPE == (uint32_t)DataFormat::UInt16) { - llk_math_eltwise_unary_sfpu_0_param - (ckernel::sfpu::calculate_typecast_fp16b_to_uint16, + llk_math_eltwise_unary_sfpu_params( ckernel::sfpu::calculate_typecast_fp16b_to_uint16, - dst_index, vector_mode); + dst_index, + vector_mode); } else if constexpr (OUT_DTYPE == (uint32_t)DataFormat::UInt32) { - llk_math_eltwise_unary_sfpu_0_param - (ckernel::sfpu::calculate_typecast_fp16b_to_uint32, + llk_math_eltwise_unary_sfpu_params( ckernel::sfpu::calculate_typecast_fp16b_to_uint32, - dst_index, vector_mode); + dst_index, + vector_mode); } } diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp index 88e41cf5766f..2b95096d2fc7 100644 --- a/ttnn/cpp/ttnn/operations/unary.hpp +++ b/ttnn/cpp/ttnn/operations/unary.hpp @@ -42,7 +42,7 @@ inline Tensor execute_on_worker_thread( const Tensor& input_tensor, const std::vector& op_chain, const std::optional& memory_config = std::nullopt) { - DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? (DataType)op_chain[0].params[0] : input_tensor.get_dtype(); + DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? 
static_cast(op_chain[0].params[0]) : input_tensor.get_dtype(); bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_tensor.get_dtype() == DataType::UINT32 or input_tensor.get_dtype() == DataType::INT32; // MT: Currently only uint32/int32 is moved to From 56049f39feb6a5218060a20eba7aaaacb7716431 Mon Sep 17 00:00:00 2001 From: Artem Yerofieiev <169092593+ayerofieiev-tt@users.noreply.github.com> Date: Tue, 4 Jun 2024 20:32:02 -0700 Subject: [PATCH 44/53] #8540: Upgrade eltwise binary ops to support queue_id /output_tensor / uint output dtype (#9071) * Support output_tensor in ttnn eltwise binary ops * Respect output_dtype or output_tensor dtype and call typecast to uint16/32 if required * Added queue_id support * Updated eq test to make sure all cases work as expected --- tests/ttnn/unit_tests/operations/test_math.py | 56 +++++++++++++ .../host/reduce_scatter_full_worker_grid.cpp | 2 +- .../eltwise_binary/eltwise_binary_op.cpp | 18 +++-- .../eltwise_binary/eltwise_binary_op.hpp | 23 +++--- .../eltwise_binary_op_multi_core.cpp | 6 +- ttnn/cpp/pybind11/operations/binary.hpp | 33 +++++--- ttnn/cpp/ttnn/op_library/binary/binary_op.cpp | 16 +++- ttnn/cpp/ttnn/op_library/binary/binary_op.hpp | 79 +++++++++++++++---- 8 files changed, 188 insertions(+), 45 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_math.py b/tests/ttnn/unit_tests/operations/test_math.py index c1cf8198b436..1fc6f66619e6 100644 --- a/tests/ttnn/unit_tests/operations/test_math.py +++ b/tests/ttnn/unit_tests/operations/test_math.py @@ -7,6 +7,8 @@ import torch import ttnn +import tt_lib +from models.utility_functions import is_grayskull from tests.ttnn.utils_for_testing import assert_with_pcc from models.utility_functions import torch_random @@ -69,6 +71,60 @@ def test_lgamma(device, h, w): run_math_unary_test(device, h, w, ttnn.lgamma, torch.lgamma, pcc=0.999) +@pytest.mark.parametrize("h", [32]) +@pytest.mark.parametrize("w", [32]) +@pytest.mark.parametrize("output_dtype", [ttnn.DataType.BFLOAT16, ttnn.DataType.UINT16, ttnn.DataType.UINT32]) +def test_eq(device, h, w, output_dtype): + if is_grayskull() and output_dtype in (ttnn.DataType.UINT32, ttnn.DataType.UINT16): + pytest.skip("GS does not support fp32/uint32/uint16 data types") + + torch.manual_seed(0) + + same = 50 + torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16) + torch_input_tensor_a[0, 0] = same + torch_input_tensor_a[0, 1] = same + torch_input_tensor_a[0, 2] = same + + torch_input_tensor_b = torch.rand((h, w), dtype=torch.bfloat16) + torch_input_tensor_b[0, 0] = same + torch_input_tensor_b[0, 1] = same + torch_input_tensor_b[0, 2] = same + + torch_output_tensor = torch.eq(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + pages_before = ttnn._ttnn.reports.get_buffer_pages() + output_tensor = ttnn.eq(input_tensor_a, input_tensor_b, dtype=output_dtype) + assert output_tensor.get_dtype() == output_dtype + assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages()) - 1 + output_tensor = ttnn.to_torch(output_tensor) + assert_with_pcc(torch_output_tensor, output_tensor, 0.999) + + # EQ with a preallocated output tensor + output_tensor_preallocated_bfloat16 = ttnn.ones( + [h, w], ttnn.DataType.BFLOAT16, ttnn.TILE_LAYOUT, device, 
ttnn.L1_MEMORY_CONFIG + ) + output_tensor_preallocated = output_tensor_preallocated_bfloat16 + # There is no good way to create uint16 tensor in ttnn/torch, so we create bfloat16 and typecast to target + if output_dtype != ttnn.DataType.BFLOAT16: + output_tensor_preallocated = tt_lib.tensor.typecast( + output_tensor_preallocated_bfloat16, output_dtype, ttnn.L1_MEMORY_CONFIG + ) + + pages_before = ttnn._ttnn.reports.get_buffer_pages() + ttnn.eq(input_tensor_a, input_tensor_b, dtype=output_dtype, output_tensor=output_tensor_preallocated) + assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages()) + torch_output_tensor_preallocated = ttnn.to_torch(output_tensor_preallocated) + assert_with_pcc(torch_output_tensor, torch_output_tensor_preallocated, 0.999) + + @pytest.mark.parametrize("h", [64]) @pytest.mark.parametrize("w", [128]) def test_log10(device, h, w): diff --git a/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp index 73fde5957023..4982a1e2110d 100644 --- a/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp +++ b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp @@ -472,7 +472,7 @@ static std::tuple build_reduce_scatter_worker( vector compute_kernel_args = {}; constexpr bool fp32_dest_acc_en = false; constexpr bool math_approx_mode = false; - std::map eltwise_defines = eltwise_binary_op_utils::get_defines(binary_math_op, std::nullopt); + std::map eltwise_defines = eltwise_binary_op_utils::get_defines(binary_math_op); KernelHandle worker_reduce_kernel_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/eltwise_binary/kernels/compute/eltwise_binary.cpp", diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp index ea091ce92695..bdd11b215dae 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp @@ -13,11 +13,13 @@ using namespace tt::constants; +namespace tt { +namespace tt_metal { namespace eltwise_binary_op_utils { using namespace tt::tt_metal; std::map get_defines( - BinaryOpType op_type, const std::optional> fused_activations) { + BinaryOpType op_type, const std::optional output_dtype, const std::optional> fused_activations) { std::map defines; string op_name = "sub_tiles"; string op_binary_type = "EltwiseBinaryType::ELWSUB"; @@ -104,6 +106,15 @@ std::map get_defines( default: TT_ASSERT(false && "Undefined op type"); } + if(output_dtype.has_value() && output_dtype.value() == DataType::UINT32){ + TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined"); + + auto dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(output_dtype.value())); + defines.insert({"SFPU_OP_CHAIN_0", + fmt::format("typecast_tile_init(); typecast_tile<{0}u>(i);", dataformat)}); + defines.insert({"SFPU_OP_TYPECAST_INCLUDE", "1"}); + } + defines["ELTWISE_OP"] = op_name.c_str(); defines["ELTWISE_OP_TYPE"] = op_binary_type.c_str(); if (fused_activations.has_value()) { @@ -120,11 +131,6 @@ std::map get_defines( } // namespace eltwise_binary_op_utils -namespace tt { - -namespace tt_metal { - - void EltwiseBinary::validate_with_output_tensors(const std::vector& input_tensors, const std::vector>& output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); const 
auto& input_tensor_b = input_tensors.at(1); diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp index a774520904f8..d69e84c3265c 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp @@ -12,6 +12,7 @@ #include "tt_dnn/op_library/repeat/repeat_op.hpp" #include "tt_dnn/op_library/run_operation.hpp" #include "tt_metal/host_api.hpp" +#include "tt_metal/common/logger.hpp" namespace tt { @@ -38,6 +39,14 @@ enum class BinaryOpType { DIV_FAST }; +namespace eltwise_binary_op_utils { + +std::map get_defines(BinaryOpType op_type, const std::optional out_dtype = std::nullopt, + const std::optional> fused_activations = std::nullopt); + +} // namespace eltwise_binary_op_utils + + enum class BinaryOpParallelizationStrategy { MULTI_CORE }; operation::ProgramWithCallbacks eltwise_binary_multi_core( @@ -132,14 +141,16 @@ struct make_eltwise_binary { (in_a.get_legacy_shape() == in_b.get_legacy_shape()) or (in_a.get_legacy_shape().without_padding() == in_b.get_legacy_shape().without_padding()), "Input shapes must be the same!"); - return operation::run( + + auto output_tensors = operation::run( EltwiseBinary{ binary_op_type, fused_activations, output_mem_config, output_dtype.value_or(in_a.get_dtype()), - false}, + false /*in place*/}, {in_a, in_b}, {}, {output_tensor}); + return output_tensors; }, {input_tensor_a, input_tensor_b}, output_tensors, {}, {output_tensor}); return output_tensors.at(0); @@ -231,11 +242,3 @@ inline Tensor add( } // namespace operations } // namespace tt - -namespace eltwise_binary_op_utils { -using namespace tt::tt_metal; - -std::map get_defines( - BinaryOpType op_typee, const std::optional> fused_activations); - -} // namespace eltwise_binary_op_utils diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp index 37a772afb208..f9bf11ef33e3 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp @@ -312,7 +312,7 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const Tensor &a, const } auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src1_config); - std::map eltwise_defines = eltwise_binary_op_utils::get_defines(op_type, fused_activations); + std::map eltwise_defines = eltwise_binary_op_utils::get_defines(op_type, output.get_dtype(), fused_activations); if (eltwise_defines.find("SFPU_OP_INIT_PRE_IN0_0") != eltwise_defines.end()) { tt_metal::CircularBufferConfig cb_interm_config = tt_metal::CircularBufferConfig(1 * src0_single_tile_size, {{CB::c_intermed0, src0_cb_data_format}}) @@ -371,12 +371,12 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const Tensor &a, const all_device_cores, tt_metal::WriterDataMovementConfig(writer_compile_time_args, writer_defines)); + bool fp32_dest_acc_en = dst_cb_data_format == tt::DataFormat::UInt32 || dst_cb_data_format == tt::DataFormat::Int32 || dst_cb_data_format == tt::DataFormat::Float32; auto eltwise_binary_kernel_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/eltwise_binary/kernels/compute/eltwise_binary.cpp", all_device_cores, - tt_metal::ComputeConfig{.defines = eltwise_defines} - ); + 
tt_metal::ComputeConfig{.fp32_dest_acc_en=fp32_dest_acc_en, .defines = eltwise_defines}); set_eltwise_binary_runtime_args( diff --git a/ttnn/cpp/pybind11/operations/binary.hpp b/ttnn/cpp/pybind11/operations/binary.hpp index 4c9f2104b58a..7bbf43ff2a14 100644 --- a/ttnn/cpp/pybind11/operations/binary.hpp +++ b/ttnn/cpp/pybind11/operations/binary.hpp @@ -33,9 +33,11 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati * :attr:`input_tensor_b` (ttnn.Tensor or Number): the tensor or number to add to :attr:`input_tensor_a`. Keyword args: - * :attr:`memory_config` (ttnn.MemoryConfig): memory config for the output tensor - * :attr:`dtype` (ttnn.DataType): data type for the output tensor - * :attr:`activations` (List[str]): list of activation functions to apply to the output tensor + * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor + * :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor + * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor + * :attr:`activations` (Optional[List[str]]): list of activation functions to apply to the output tensor + * :attr:`queue_id` (Optional[uint8]): command queue id Example:: @@ -51,34 +53,47 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operati module, operation, doc, + // tensor and scalar ttnn::pybind_overload_t{ [](const binary_operation_t& self, const ttnn::Tensor& input_tensor_a, const float scalar, const std::optional& memory_config, const std::optional& dtype, - const std::optional>& activations) -> ttnn::Tensor { - return self(input_tensor_a, scalar, memory_config, dtype, activations); + const std::optional& output_tensor, + const std::optional>& activations, + const uint8_t& queue_id) -> ttnn::Tensor { + return self(queue_id, input_tensor_a, scalar, memory_config, dtype, output_tensor, activations); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), + py::kw_only(), py::arg("memory_config") = std::nullopt, py::arg("dtype") = std::nullopt, - py::arg("activations") = std::nullopt}, + py::arg("output_tensor") = std::nullopt, + py::arg("activations") = std::nullopt, + py::arg("queue_id") = 0}, + + // tensor and tensor ttnn::pybind_overload_t{ [](const binary_operation_t& self, const ttnn::Tensor& input_tensor_a, const ttnn::Tensor& input_tensor_b, const std::optional& memory_config, const std::optional& dtype, - const std::optional>& activations) -> ttnn::Tensor { - return self(input_tensor_a, input_tensor_b, memory_config, dtype, activations); + const std::optional& output_tensor, + const std::optional>& activations, + const uint8_t& queue_id) -> ttnn::Tensor { + return self(queue_id, input_tensor_a, input_tensor_b, memory_config, dtype, output_tensor, activations); }, py::arg("input_tensor_a"), py::arg("input_tensor_b"), + py::kw_only(), py::arg("memory_config") = std::nullopt, py::arg("dtype") = std::nullopt, - py::arg("activations") = std::nullopt}); + py::arg("output_tensor") = std::nullopt, + py::arg("activations") = std::nullopt, + py::arg("queue_id") = 0}); } } // namespace detail diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp index 243b6ef4808a..b73c63d819f2 100644 --- a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp +++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp @@ -80,7 +80,7 @@ inline BinaryProgramType get_program_type(const Binary& operation, const std::ve TT_THROW("ttnn::operations::binary::Binary: unsupported broadcast"); } -void 
Binary::validate(const std::vector& input_tensors) const { +void Binary::validate_with_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const { auto program_type = get_program_type(*this, input_tensors); const auto& input_tensor_a = input_tensors.at(0); @@ -170,6 +170,14 @@ void Binary::validate(const std::vector& input_tensors) const { if (program_type != BinaryProgramType::ElementWiseMultiCore) { TT_FATAL(not this->program_config.activations.has_value()); } + + if (!output_tensors.empty()) { + TT_FATAL(output_tensors.size() == 1, "Must have 1 output tensors"); + + if(output_tensors.at(0).has_value()) { + TT_FATAL(!this->program_config.in_place, "Operation is configured as in_place. First input is used as output. Provided output tensor is ignored"); + } + } } std::vector Binary::compute_output_shapes(const std::vector& input_tensors) const { @@ -181,12 +189,16 @@ std::vector Binary::compute_output_shapes(const std::vector return {input_tensor_b.get_legacy_shape()}; } -std::vector Binary::create_output_tensors(const std::vector& input_tensors) const { +std::vector Binary::create_output_tensors(const std::vector& input_tensors, const std::vector>& output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); const auto& input_tensor_b = input_tensors.at(1); if (this->program_config.in_place) { return {input_tensor_a}; } else { + if (!output_tensors.empty() && output_tensors.at(0).has_value()) { + return {output_tensors.at(0).value()}; + } + auto program_type = get_program_type(*this, input_tensors); if (program_type == BinaryProgramType::ElementWiseMultiCore) { diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp index 006097c36331..6ae72c5a9822 100644 --- a/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp +++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp @@ -29,6 +29,8 @@ namespace binary { using BinaryOpType = tt::tt_metal::BinaryOpType; +constexpr uint8_t DefaultQueueId = 0; + struct BinaryProgramConfig { BinaryOpType binary_op_type; bool in_place; @@ -48,9 +50,9 @@ struct Binary { const BinaryProgramConfig program_config; std::optional compute_kernel_config; - void validate(const std::vector &input_tensors) const; + void validate_with_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors) const; + std::vector create_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &output_tensors) const; @@ -92,16 +94,23 @@ struct ExecuteBinary { } template - static auto input_tensors_to_validate(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) { + static auto input_tensors_to_validate(uint8_t queue_id, const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) { return std::forward_as_tuple(input_tensor_a, input_tensor_b); } static Tensor execute_on_worker_thread( + uint8_t queue_id, const Tensor &input_tensor_a_arg, const Tensor &input_tensor_b_arg, const std::optional &memory_config = std::nullopt, - const std::optional &dtype = std::nullopt, + const std::optional &output_dtype = std::nullopt, + std::optional optional_output_tensor = std::nullopt, std::optional> activations = std::nullopt) { + + if(output_dtype.has_value() && 
optional_output_tensor.has_value()){ + TT_FATAL(output_dtype.value() == optional_output_tensor.value().get_dtype(), "If both output dtype and output tensor provided dtype should match"); + } + auto &&[input_tensor_a, input_tensor_b] = [](const auto &input_tensor_a_arg, const auto &input_tensor_b_arg) { const auto input_shape_a = input_tensor_a_arg.get_shape(); const auto input_shape_b = input_tensor_b_arg.get_shape(); @@ -111,6 +120,7 @@ struct ExecuteBinary { } return std::make_tuple(input_tensor_a_arg, input_tensor_b_arg); }(input_tensor_a_arg, input_tensor_b_arg); + auto output_memory_config = memory_config.value_or(input_tensor_a.memory_config()); // TODO(arakhmati): #7731 - remove this! @@ -124,15 +134,38 @@ struct ExecuteBinary { input_tensor_b = tt::tt_metal::repeat(input_tensor_b, repeats.value(), output_memory_config); } - return operation::run( - Binary{BinaryProgramConfig{ - binary_op_type, - in_place, - activations, - output_memory_config, - dtype.value_or(input_tensor_a.get_dtype())}}, - {input_tensor_a, input_tensor_b}) - .at(0); + DataType dtype = output_dtype.value_or(input_tensor_a.get_dtype()); + if(optional_output_tensor.has_value()) { + dtype = optional_output_tensor.value().get_dtype(); + } + + auto output_tensors = operation::run(Binary{BinaryProgramConfig{binary_op_type, + in_place, + activations, + output_memory_config, + dtype}}, + {input_tensor_a, input_tensor_b}, + {}, + {optional_output_tensor}, + queue_id); + + return output_tensors.at(0); + } + + template + static auto input_tensors_to_validate(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) { + return std::forward_as_tuple(input_tensor_a, input_tensor_b); + } + + static Tensor execute_on_worker_thread( + const Tensor &input_tensor_a_arg, + const Tensor &input_tensor_b_arg, + const std::optional &memory_config = std::nullopt, + const std::optional &output_dtype = std::nullopt, + std::optional optional_output_tensor = std::nullopt, + std::optional> activations = std::nullopt) + { + return execute_on_worker_thread(DefaultQueueId, input_tensor_a_arg, input_tensor_b_arg, memory_config, output_dtype, optional_output_tensor, activations); } template @@ -147,6 +180,24 @@ struct ExecuteBinary { const float scalar, const std::optional &memory_config = std::nullopt, const std::optional &dtype = std::nullopt, + const std::optional &optional_output_tensor = std::nullopt, + std::optional> activations = std::nullopt) { + + return ExecuteBinary::execute_on_worker_thread(DefaultQueueId, input_tensor_a, scalar, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, dtype, optional_output_tensor, activations); + } + + template + static auto input_tensors_to_validate(uint8_t queue_id, const Tensor &input_tensor_a, const float input_tensor_b, Args &&...args) { + return std::forward_as_tuple(input_tensor_a, input_tensor_b); + } + + static Tensor execute_on_worker_thread( + uint8_t queue_id, + const ttnn::Tensor &input_tensor_a, + const float scalar, + const std::optional &memory_config = std::nullopt, + const std::optional &dtype = std::nullopt, + const std::optional &optional_output_tensor = std::nullopt, std::optional> activations = std::nullopt) { // Cast Float Scalar to a device tensor auto host_buffer = owned_buffer::create<::bfloat16>(static_cast(TILE_HEIGHT * TILE_WIDTH)); @@ -159,7 +210,7 @@ struct ExecuteBinary { Tensor scalar_tensor_device = scalar_tensor_host.to(input_tensor_a.device()); // TODO(arakhmati): #7637 pass in memory_config instead of operation::DEFAULT_OUTPUT_MEMORY_CONFIG return 
ExecuteBinary::execute_on_worker_thread( - input_tensor_a, scalar_tensor_device, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, dtype, activations); + input_tensor_a, scalar_tensor_device, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, dtype, optional_output_tensor, activations); } }; From 5ecec99882966671f9119a64b65618dc6b4f5098 Mon Sep 17 00:00:00 2001 From: hschoi Date: Tue, 4 Jun 2024 07:57:20 +0000 Subject: [PATCH 45/53] #9095: implement callback helper function --- .../op_library/moreh_helper_functions.hpp | 132 ++++++++++++++++++ 1 file changed, 132 insertions(+) diff --git a/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp b/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp index bcbe4abf5d5c..dd882937a0d7 100644 --- a/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp +++ b/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp @@ -126,6 +126,138 @@ struct CircularBufferArg { tt::DataFormat data_format, CircularBufferArg arg); + +struct CallbackArgMap { + std::map input; + std::map optional_input; + std::map output; +}; + +using Tensors = std::vector; +using OptionalConstTensors = std::vector>; + +// To use this function, the arguments in the reader kernel must always be sorted in the order of input followed by +// optional_input. Furthermore, input and output tensors must always start from the 0th argument. +template +const std::function +create_override_runtime_arguments_callback( + KernelHandle reader_kernel_id, KernelHandle writer_kernel_id, uint32_t num_cores, uint32_t core_h) { + return [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( + const void *operation, + Program &program, + const Tensors &input_tensors, + const OptionalConstTensors &optional_input_tensors, + const OutputTensors &output_tensors) -> void { + for (uint32_t icore = 0; icore < num_cores; icore++) { + CoreCoord core = {icore / core_h, icore % core_h}; + + // readers + { + uint32_t rt_idx = 0; + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + for (uint32_t idx = 0; idx < input_tensors.size(); idx++) { + runtime_args[rt_idx++] = input_tensors.at(idx).buffer()->address(); + } + for (uint32_t idx = 0; idx < optional_input_tensors.size(); idx++) { + auto optional_input_tensor = optional_input_tensors.at(idx); + runtime_args[rt_idx++] = + optional_input_tensor.has_value() ? optional_input_tensor.value().buffer()->address() : 0; + } + } + + // writer + { + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + for (uint32_t idx = 0; idx < output_tensors.size(); idx++) { + runtime_args[idx] = output_tensors.at(idx).buffer()->address(); + } + } + } + }; +} + +// Using this structure is not recommended because directly setting the callback argument map doesn't significantly +// reduce the amount of code. 
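(Aside, not part of the patch.) The commits that follow consume this helper by returning it directly from each op's program factory, replacing the hand-written lambda each op carried before. A minimal sketch of that call site, under the contract stated above (the reader kernel's runtime arguments begin with the input addresses followed by the optional-input addresses, and the writer kernel's begin with the output addresses at index 0); the kernel handles and core counts are placeholders for whatever the factory computed earlier:

    // Sketch only: reader_kernel_id, writer_kernel_id, num_cores and core_h come
    // from the kernel-creation and work-split code earlier in the program factory.
    return {
        .program = std::move(program),
        .override_runtime_arguments_callback =
            create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)};
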
+template +const std::function +create_override_runtime_arguments_callback( + KernelHandle reader_kernel_id, + KernelHandle writer_kernel_id, + uint32_t num_cores, + uint32_t core_h, + CallbackArgMap &arg_map) { + return [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, arg_map, num_cores, core_h]( + const void *operation, + Program &program, + const Tensors &input_tensors, + const OptionalConstTensors &optional_input_tensors, + const OutputTensors &output_tensors) -> void { + for (uint32_t icore = 0; icore < num_cores; icore++) { + CoreCoord core = {icore / core_h, icore % core_h}; + + // readers + { + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + for (const auto &pair : arg_map.input) { + runtime_args[pair.first] = input_tensors.at(pair.second).buffer()->address(); + } + for (const auto &pair : arg_map.optional_input) { + auto optional_input_tensor = optional_input_tensors.at(pair.second); + runtime_args[pair.first] = + optional_input_tensor.has_value() ? optional_input_tensor.value().buffer()->address() : 0; + } + } + + // writer + { + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + for (const auto &pair : arg_map.output) { + runtime_args[pair.first] = output_tensors.at(pair.second).buffer()->address(); + } + } + } + }; +} + +// To use this function, the arguments in the reader kernel must always be sorted in the order of input followed by +// optional_input. Furthermore, input and output tensors must always start from the 0th argument. +template +const std::function&, const std::vector&)> +create_override_addresses_callback( + KernelHandle reader_kernel_id, KernelHandle writer_kernel_id, uint32_t num_cores, uint32_t core_h) { + return [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) -> void { + for (uint32_t icore = 0; icore < num_cores; icore++) { + CoreCoord core = {icore / core_h, icore % core_h}; + + // readers + { + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + for (uint32_t idx = 0; idx < input_buffers.size(); idx++) { + auto buffer = input_buffers.at(idx); + if (buffer != nullptr) { + runtime_args[idx] = buffer->address(); + } + } + } + + // writer + { + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + for (uint32_t idx = 0; idx < output_buffers.size(); idx++) { + auto buffer = output_buffers.at(idx); + if (buffer != nullptr) { + runtime_args[idx] = buffer->address(); + } + } + } + } + }; +} + + } // namespace primary } // namespace operations } // namespace tt From 2aea68f6c37c623981dc1e43cd9a73af5e7a0965 Mon Sep 17 00:00:00 2001 From: hschoi Date: Tue, 4 Jun 2024 07:58:22 +0000 Subject: [PATCH 46/53] #9095: apply callback helper function to moreh_adamw --- .../unit_testing/misc/test_moreh_adamw.py | 52 ++++--- .../op_library/moreh_adamw/moreh_adamw.cpp | 136 ++++++++---------- 2 files changed, 92 insertions(+), 96 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py index 08f033f23288..6826e3724d98 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py @@ -17,22 +17,7 @@ from loguru import logger -@pytest.mark.parametrize( - "shape", - ( - (1, 1, 32, 32), # single - (12, 6, 64, 64), # multi 
tile - ), -) -@pytest.mark.parametrize("lr", [0.0, 1e-2]) -@pytest.mark.parametrize("betas", ((0.9, 0.999), (0.5, 0.555))) -@pytest.mark.parametrize("eps", [1e-06, 1e-08]) -@pytest.mark.parametrize("weight_decay", [0.0, 0.3]) -@pytest.mark.parametrize("amsgrad", [True, False]) -@pytest.mark.parametrize("step", [1, 2, 8]) -def test_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device): - torch.manual_seed(0) - +def run_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device): N = shape[0] C = shape[1] H = shape[2] @@ -205,3 +190,38 @@ def forward(self, x): whole_passing &= passing assert whole_passing + + +@pytest.mark.parametrize( + "shape", + ( + (1, 1, 32, 32), # single + (12, 6, 64, 64), # multi tile + ), +) +@pytest.mark.parametrize("lr", [0.0, 1e-2]) +@pytest.mark.parametrize("betas", ((0.9, 0.999), (0.5, 0.555))) +@pytest.mark.parametrize("eps", [1e-06, 1e-08]) +@pytest.mark.parametrize("weight_decay", [0.0, 0.3]) +@pytest.mark.parametrize("amsgrad", [True, False]) +@pytest.mark.parametrize("step", [1, 2, 8]) +def test_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device): + torch.manual_seed(0) + + run_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device) + + +@pytest.mark.parametrize( + "shape", + ((1, 1, 32, 32),), # single +) +@pytest.mark.parametrize("lr", [1e-2]) +@pytest.mark.parametrize("betas", [[0.9, 0.999], [0.5, 0.555]]) +@pytest.mark.parametrize("eps", [1e-08]) +@pytest.mark.parametrize("weight_decay", [0.3]) +@pytest.mark.parametrize("amsgrad", [True, False]) +@pytest.mark.parametrize("step", [8]) +def test_moreh_adamw_callback(shape, lr, betas, eps, weight_decay, amsgrad, step, device, use_program_cache): + torch.manual_seed(0) + for _ in range(2): + run_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device) diff --git a/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp b/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp index 3f0d618224c7..8823e9aa5c68 100644 --- a/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp @@ -10,8 +10,8 @@ #include "tt_dnn/op_library/run_operation.hpp" #include "tt_eager/tensor/tensor.hpp" #include "tt_eager/tensor/tensor_impl.hpp" -#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp" #include "tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw_op.hpp" +#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp" #include "tt_eager/tt_dnn/op_library/work_split.hpp" #include "tt_metal/common/math.hpp" #include "tt_metal/detail/util.hpp" @@ -26,9 +26,14 @@ operation::ProgramWithCallbacks moreh_adamw_( const Tensor& grad, const Tensor& exp_avg, const Tensor& exp_avg_sq, - float lr, float beta1, float beta2, float eps, float weight_decay, uint32_t step, bool amsgrad, + float lr, + float beta1, + float beta2, + float eps, + float weight_decay, + uint32_t step, + bool amsgrad, const std::optional> max_exp_avg_sq) { - uint32_t num_tiles = param.volume() / TILE_HW; Program program{}; @@ -36,14 +41,15 @@ operation::ProgramWithCallbacks moreh_adamw_( //////////////////////////////////////////////////////////////////////////// // Device Setup //////////////////////////////////////////////////////////////////////////// - tt_metal::Device *device = param.device(); + tt_metal::Device* device = param.device(); auto grid = device->compute_with_storage_grid_size(); const auto num_cores_y = grid.y; // auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); // 
uint32_t num_cores_x = compute_with_storage_grid_size.x; // uint32_t num_cores_y = compute_with_storage_grid_size.y; - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = tt_metal::split_work_to_cores(grid, num_tiles); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + tt_metal::split_work_to_cores(grid, num_tiles); //////////////////////////////////////////////////////////////////////////// // CircularBuffer Setup @@ -54,27 +60,27 @@ operation::ProgramWithCallbacks moreh_adamw_( all_cores, data_format, { - {CB::c_in0, 1}, // param - {CB::c_in1, 1}, // grad - {CB::c_in2, 1}, // exp_avg - {CB::c_in3, 1}, // exp_avg_sq - {CB::c_in4, 1}, // max_exp_avg_sq (optional) - {CB::c_in5, 5}, // lr, beta1, beta2, eps, weight_decay - {CB::c_in6, 1}, // 1.0f - - {CB::c_intermed0, 1}, // tmp_grad - {CB::c_intermed1, 1}, // tmp_exp_avg - {CB::c_intermed2, 1}, // tmp_exp_avg_sq - {CB::c_intermed3, 1}, // tmp_max_exp_avg_sq - {CB::c_intermed4, 1}, // - {CB::c_intermed5, 1}, // - {CB::c_intermed6, 1}, // tmp1 - {CB::c_intermed7, 1}, // tmp2 - - {CB::c_out0, 1}, // param - {CB::c_out1, 1}, // exp_avg - {CB::c_out2, 1}, // exp_avg_sq - {CB::c_out3, 1}, // max_exp_avg_sq (optional) + {CB::c_in0, 1}, // param + {CB::c_in1, 1}, // grad + {CB::c_in2, 1}, // exp_avg + {CB::c_in3, 1}, // exp_avg_sq + {CB::c_in4, 1}, // max_exp_avg_sq (optional) + {CB::c_in5, 5}, // lr, beta1, beta2, eps, weight_decay + {CB::c_in6, 1}, // 1.0f + + {CB::c_intermed0, 1}, // tmp_grad + {CB::c_intermed1, 1}, // tmp_exp_avg + {CB::c_intermed2, 1}, // tmp_exp_avg_sq + {CB::c_intermed3, 1}, // tmp_max_exp_avg_sq + {CB::c_intermed4, 1}, // + {CB::c_intermed5, 1}, // + {CB::c_intermed6, 1}, // tmp1 + {CB::c_intermed7, 1}, // tmp2 + + {CB::c_out0, 1}, // param + {CB::c_out1, 1}, // exp_avg + {CB::c_out2, 1}, // exp_avg_sq + {CB::c_out3, 1}, // max_exp_avg_sq (optional) }); //////////////////////////////////////////////////////////////////////////// @@ -117,19 +123,20 @@ operation::ProgramWithCallbacks moreh_adamw_( compute_defines["AMSGRAD"] = "1"; } - const std::vector compute_args_group_1{ - num_tiles_per_core_group_1}; + const std::vector compute_args_group_1{num_tiles_per_core_group_1}; const auto compute_kernel_file = "tt_eager/tt_dnn/op_library/moreh_adamw/kernels/" "moreh_adamw.cpp"; auto compute_kernel_1_id = CreateComputeKernel( - program, compute_kernel_file, {core_group_1, num_tiles_per_core_group_1, compute_args_group_1}, compute_defines); + program, + compute_kernel_file, + {core_group_1, num_tiles_per_core_group_1, compute_args_group_1}, + compute_defines); KernelHandle compute_kernel_2_id = -1; if (!core_group_2.ranges().empty()) { - const std::vector compute_args_group_2{ - num_tiles_per_core_group_2}; + const std::vector compute_args_group_2{num_tiles_per_core_group_2}; compute_kernel_2_id = CreateComputeKernel( program, @@ -170,14 +177,24 @@ operation::ProgramWithCallbacks moreh_adamw_( } const std::vector reader_runtime_args{ - param_addr, grad_addr, exp_avg_addr, exp_avg_sq_addr, max_exp_avg_sq_addr, - f2u_lr.u, f2u_beta1.u, f2u_beta2.u, f2u_eps.u, f2u_weight_decay.u, step, static_cast(amsgrad), - num_tiles_per_core, tile_offset}; + param_addr, + grad_addr, + exp_avg_addr, + exp_avg_sq_addr, + max_exp_avg_sq_addr, + f2u_lr.u, + f2u_beta1.u, + f2u_beta2.u, + f2u_eps.u, + f2u_weight_decay.u, + step, + static_cast(amsgrad), + num_tiles_per_core, + tile_offset}; tt_metal::SetRuntimeArgs(program, 
reader_kernel_id, core, reader_runtime_args); const std::vector writer_runtime_args{ - param_addr, exp_avg_addr, exp_avg_sq_addr, max_exp_avg_sq_addr, - num_tiles_per_core, tile_offset}; + param_addr, exp_avg_addr, exp_avg_sq_addr, max_exp_avg_sq_addr, num_tiles_per_core, tile_offset}; tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, writer_runtime_args); if (core_group_1.core_coord_in_core_ranges(core)) { @@ -191,50 +208,9 @@ operation::ProgramWithCallbacks moreh_adamw_( tile_offset += num_tiles_per_core; } - //////////////////////////////////////////////////////////////////////////// - // Callback SetUp - //////////////////////////////////////////////////////////////////////////// - auto override_runtime_args_callback = [reader_kernel_id = reader_kernel_id, - writer_kernel_id = writer_kernel_id, - num_cores = num_cores, - num_cores_y = num_cores_y]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto param_buffer = input_buffers.at(0); - auto grad_buffer = input_buffers.at(1); - auto exp_avg_buffer = input_buffers.at(2); - auto exp_avg_sq_buffer = input_buffers.at(3); - auto max_exp_avg_sq_buffer = input_buffers.at(4); - - for (uint32_t i = 0; i < num_cores; ++i) { - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = param_buffer->address(); - runtime_args[1] = grad_buffer->address(); - runtime_args[2] = exp_avg_buffer->address(); - runtime_args[3] = exp_avg_sq_buffer->address(); - if (max_exp_avg_sq_buffer != nullptr) { - runtime_args[4] = max_exp_avg_sq_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = param_buffer->address(); - runtime_args[1] = grad_buffer->address(); - runtime_args[2] = exp_avg_buffer->address(); - runtime_args[3] = exp_avg_sq_buffer->address(); - if (max_exp_avg_sq_buffer != nullptr) { - runtime_args[4] = max_exp_avg_sq_buffer->address(); - } - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + std::move(program), + create_override_addresses_callback(reader_kernel_id, writer_kernel_id, num_cores, num_cores_y)}; } } // namespace primary From 0f30a5c613092d3b31022df38d5b743edf0c70c2 Mon Sep 17 00:00:00 2001 From: hschoi Date: Tue, 4 Jun 2024 21:49:07 +0000 Subject: [PATCH 47/53] #9095: apply callback helper function to moreh_nll_loss --- .../moreh_nll_loss_step1.cpp | 35 +---- .../moreh_nll_loss_step2.cpp | 123 ++-------------- .../reader_moreh_nll_loss_backward_2d.cpp | 3 +- .../reader_moreh_nll_loss_backward_3d.cpp | 2 +- .../reader_moreh_nll_loss_backward_4d.cpp | 2 +- .../moreh_nll_loss_backward.cpp | 137 ++---------------- 6 files changed, 35 insertions(+), 267 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp index e2a5b0752c39..f2646aa0b862 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp @@ -146,37 +146,10 @@ operation::ProgramWithCallbacks moreh_nll_loss_step1_impl( tile_offset += num_units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const Program &program, - const std::vector 
&input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto target_dram_buffer = input_buffers.at(0); - auto weight_dram_buffer = input_buffers.at(1); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = target_dram_buffer->address(); - if (weight_dram_buffer != nullptr) { - runtime_args[1] = weight_dram_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp index fc2957d69c01..085416713c01 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp @@ -186,43 +186,10 @@ operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_2d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(input_buffers.size() == 4); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto target_dram_buffer = input_buffers.at(1); - auto weight_dram_buffer = input_buffers.at(2); - auto divisor_dram_buffer = input_buffers.at(3); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - runtime_args[1] = target_dram_buffer->address(); - if (weight_dram_buffer != nullptr) { - runtime_args[2] = weight_dram_buffer->address(); - } - if (divisor_dram_buffer != nullptr) { - runtime_args[3] = divisor_dram_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_3d( @@ -397,43 +364,10 @@ operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_3d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(input_buffers.size() == 4); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = 
input_buffers.at(0); - auto target_dram_buffer = input_buffers.at(1); - auto weight_dram_buffer = input_buffers.at(2); - auto divisor_dram_buffer = input_buffers.at(3); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - runtime_args[1] = target_dram_buffer->address(); - if (weight_dram_buffer != nullptr) { - runtime_args[2] = weight_dram_buffer->address(); - } - if (divisor_dram_buffer != nullptr) { - runtime_args[3] = divisor_dram_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_4d( @@ -616,43 +550,10 @@ operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_4d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(input_buffers.size() == 4); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto target_dram_buffer = input_buffers.at(1); - auto weight_dram_buffer = input_buffers.at(2); - auto divisor_dram_buffer = input_buffers.at(3); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - runtime_args[1] = target_dram_buffer->address(); - if (weight_dram_buffer != nullptr) { - runtime_args[2] = weight_dram_buffer->address(); - } - if (divisor_dram_buffer != nullptr) { - runtime_args[3] = divisor_dram_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } operation::ProgramWithCallbacks moreh_nll_loss_step2_impl( diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_2d.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_2d.cpp index 0fa899feb4a4..f83def0d88a5 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_2d.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_2d.cpp @@ -3,13 +3,14 @@ // SPDX-License-Identifier: Apache-2.0 #include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp" +#include "dprint.h" void kernel_main() { uint32_t i = 0; auto target_addr = get_arg_val(i++); + auto 
output_grad_addr = get_arg_val(i++); auto weight_addr = get_arg_val(i++); auto divisor_addr = get_arg_val(i++); - auto output_grad_addr = get_arg_val(i++); auto ignore_index = static_cast(get_arg_val(i++)); auto num_tiles_per_core = get_arg_val(i++); auto start_id = get_arg_val(i++); diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_3d.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_3d.cpp index 6c8697bc352d..e48c188d9c00 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_3d.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_3d.cpp @@ -7,9 +7,9 @@ void kernel_main() { uint32_t i = 0; auto target_addr = get_arg_val(i++); + auto output_grad_addr = get_arg_val(i++); auto weight_addr = get_arg_val(i++); auto divisor_addr = get_arg_val(i++); - auto output_grad_addr = get_arg_val(i++); auto ignore_index = static_cast(get_arg_val(i++)); auto num_tiles_per_core = get_arg_val(i++); auto start_id = get_arg_val(i++); diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_4d.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_4d.cpp index 073298d147ab..3ca374cf0e8d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_4d.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_4d.cpp @@ -8,9 +8,9 @@ void kernel_main() { uint32_t i = 0; auto target_addr = get_arg_val(i++); + auto output_grad_addr = get_arg_val(i++); auto weight_addr = get_arg_val(i++); auto divisor_addr = get_arg_val(i++); - auto output_grad_addr = get_arg_val(i++); auto ignore_index = static_cast(get_arg_val(i++)); auto num_tiles_per_core = get_arg_val(i++); auto start_id = get_arg_val(i++); diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp index 78570bb4f0d9..fa389814e759 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp @@ -156,9 +156,9 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_4d( std::vector reader_args = { target_addr, + output_grad_addr, weight_addr, divisor_addr, - output_grad_addr, static_cast(ignore_index), units_per_core, tile_offset, @@ -187,47 +187,12 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_4d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const void *operation, - Program &program, - const std::vector &input_tensors, - const std::vector> &optional_input_tensors, - const std::vector &output_tensors) { - TT_ASSERT(input_tensors.size() == 2); - TT_ASSERT(optional_input_tensors.size() == 2); - TT_ASSERT(output_tensors.size() == 1); - - auto target_addr = input_tensors.at(0).buffer()->address(); - auto output_grad_addr = 
input_tensors.at(1).buffer()->address(); - auto weight_addr = - optional_input_tensors.at(0).has_value() ? optional_input_tensors.at(0).value().buffer()->address() : 0; - auto divisor_addr = - optional_input_tensors.at(1).has_value() ? optional_input_tensors.at(1).value().buffer()->address() : 0; - auto input_grad_addr = output_tensors.at(0).buffer()->address(); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = target_addr; - runtime_args[1] = weight_addr; - runtime_args[2] = divisor_addr; - runtime_args[3] = output_grad_addr; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_addr; - } - } - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } - operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_3d( const Tensor &target, const std::optional weight, @@ -238,7 +203,6 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_3d( const bool reduction_mean, const CoreRange core_range, const DeviceComputeKernelConfig compute_kernel_config) { - // split work // input_grad: (N, C, W) @@ -370,9 +334,9 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_3d( std::vector reader_args = { target_addr, + output_grad_addr, weight_addr, divisor_addr, - output_grad_addr, static_cast(ignore_index), units_per_core, tile_offset, @@ -401,47 +365,12 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_3d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const void *operation, - Program &program, - const std::vector &input_tensors, - const std::vector> &optional_input_tensors, - const std::vector &output_tensors) { - TT_ASSERT(input_tensors.size() == 2); - TT_ASSERT(optional_input_tensors.size() == 2); - TT_ASSERT(output_tensors.size() == 1); - - auto target_addr = input_tensors.at(0).buffer()->address(); - auto output_grad_addr = input_tensors.at(1).buffer()->address(); - auto weight_addr = - optional_input_tensors.at(0).has_value() ? optional_input_tensors.at(0).value().buffer()->address() : 0; - auto divisor_addr = - optional_input_tensors.at(1).has_value() ? 
optional_input_tensors.at(1).value().buffer()->address() : 0; - auto input_grad_addr = output_tensors.at(0).buffer()->address(); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = target_addr; - runtime_args[1] = weight_addr; - runtime_args[2] = divisor_addr; - runtime_args[3] = output_grad_addr; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_addr; - } - } - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } - operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_2d( const Tensor &target, const std::optional weight, @@ -579,9 +508,9 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_2d( std::vector reader_args = { target_addr, + output_grad_addr, weight_addr, divisor_addr, - output_grad_addr, static_cast(ignore_index), units_per_core, tile_offset, @@ -609,48 +538,12 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_2d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const void *operation, - Program &program, - const std::vector &input_tensors, - const std::vector> &optional_input_tensors, - const std::vector &output_tensors) { - TT_ASSERT(input_tensors.size() == 2); - TT_ASSERT(optional_input_tensors.size() == 2); - TT_ASSERT(output_tensors.size() == 1); - - auto target_addr = input_tensors.at(0).buffer()->address(); - auto output_grad_addr = input_tensors.at(1).buffer()->address(); - auto weight_addr = - optional_input_tensors.at(0).has_value() ? optional_input_tensors.at(0).value().buffer()->address() : 0; - auto divisor_addr = - optional_input_tensors.at(1).has_value() ? 
optional_input_tensors.at(1).value().buffer()->address() : 0; - auto input_grad_addr = output_tensors.at(0).buffer()->address(); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = target_addr; - runtime_args[1] = weight_addr; - runtime_args[2] = divisor_addr; - runtime_args[3] = output_grad_addr; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_addr; - } - } - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } - - } // namespace operation::ProgramWithCallbacks moreh_nll_loss_backward_impl( From 24775d9b71c0c12ad9a3a0a9a64ec3343b8159f1 Mon Sep 17 00:00:00 2001 From: hschoi Date: Tue, 4 Jun 2024 08:04:33 +0000 Subject: [PATCH 48/53] #9095: apply callback helper function to moreh_softmax --- .../softmax_c_large/softmax_c_large.cpp | 37 ++---------------- .../softmax_h_large/softmax_h_large.cpp | 37 ++---------------- .../softmax_h_small/softmax_h_small.cpp | 37 ++---------------- .../softmax_w_large/softmax_w_large.cpp | 37 ++---------------- .../softmax_w_small/softmax_w_small.cpp | 37 ++---------------- .../softmax_backward_c_large.cpp | 39 ++----------------- .../softmax_backward_h_large.cpp | 39 ++----------------- .../softmax_backward_h_small.cpp | 39 ++----------------- .../softmax_backward_w_large.cpp | 39 ++----------------- .../softmax_backward_w_small.cpp | 39 ++----------------- 10 files changed, 40 insertions(+), 340 deletions(-) diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp index 79e0b62a5e08..4ee07c10ee58 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp @@ -129,39 +129,10 @@ operation::ProgramWithCallbacks moreh_softmax_c_large(const Tensor &input, const tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp 
b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp index abcb6b194099..ea2ab9945448 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp @@ -123,39 +123,10 @@ operation::ProgramWithCallbacks moreh_softmax_h_large(const Tensor &input, const tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp index 77523098d4c2..cbf825a1b5e3 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp @@ -145,39 +145,10 @@ operation::ProgramWithCallbacks moreh_softmax_h_small(const Tensor &input, const tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp index f1ae31c7dd69..7018590c32ac 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp @@ -124,39 +124,10 @@ operation::ProgramWithCallbacks 
moreh_softmax_w_large(const Tensor &input, const tile_offset += num_tiles_per_core * Wt; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp index bf90b8d47b0e..1dcf9f818dcb 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp @@ -145,39 +145,10 @@ operation::ProgramWithCallbacks moreh_softmax_w_small(const Tensor &input, const tile_offset += num_tiles_per_core * Wt; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp index 5752781a8934..2447581d0f4a 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp @@ -135,41 +135,10 @@ operation::ProgramWithCallbacks moreh_softmax_backward_c_large(const Tensor &out tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - 
const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp index 859867d17f00..638ed5dc7e7c 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp @@ -130,41 +130,10 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_large(const Tensor &out tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp index 44df27586980..b17cff78ce4b 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp @@ -152,41 +152,10 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_small(const Tensor &out tile_offset += num_tiles_per_core; } - auto 
override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp index 78ce4ceecfae..a46b647d51f4 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp @@ -130,41 +130,10 @@ operation::ProgramWithCallbacks moreh_softmax_backward_w_large(const Tensor &out tile_offset += num_tiles_per_core * Wt; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp index a834f5e4acd3..8488ca725468 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp @@ -153,41 +153,10 @@ 
operation::ProgramWithCallbacks moreh_softmax_backward_w_small(const Tensor &out tile_offset += num_tiles_per_core * Wt; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary From 169e7ff7a5c272a788aaa48f788e65b710aa4dce Mon Sep 17 00:00:00 2001 From: hschoi Date: Tue, 4 Jun 2024 21:59:56 +0000 Subject: [PATCH 49/53] #9095: change shape from tuple to list --- .../unit_testing/misc/test_moreh_adamw.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py index 6826e3724d98..f7f615b66a7a 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py @@ -194,10 +194,10 @@ def forward(self, x): @pytest.mark.parametrize( "shape", - ( - (1, 1, 32, 32), # single - (12, 6, 64, 64), # multi tile - ), + [ + [1, 1, 32, 32], # single + [12, 6, 64, 64], # multi tile + ], ) @pytest.mark.parametrize("lr", [0.0, 1e-2]) @pytest.mark.parametrize("betas", ((0.9, 0.999), (0.5, 0.555))) @@ -213,7 +213,7 @@ def test_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device) @pytest.mark.parametrize( "shape", - ((1, 1, 32, 32),), # single + [[1, 1, 32, 32]], # single ) @pytest.mark.parametrize("lr", [1e-2]) @pytest.mark.parametrize("betas", [[0.9, 0.999], [0.5, 0.555]]) From a3a4a716131f667ac48be7830a52fbe5de97d66b Mon Sep 17 00:00:00 2001 From: hschoi Date: Wed, 5 Jun 2024 03:28:47 +0000 Subject: [PATCH 50/53] #9095: change arg_map from ref to value copy --- tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp b/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp index dd882937a0d7..cc5f31427f9a 100644 --- a/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp +++ b/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp @@ -185,7 +185,7 @@ create_override_runtime_arguments_callback( KernelHandle writer_kernel_id, uint32_t num_cores, uint32_t core_h, - CallbackArgMap &arg_map) { + CallbackArgMap arg_map) { return [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, arg_map, num_cores, core_h]( const void *operation, Program &program, From 
a197fbcb80157916bf49658922d5e9f83f229401 Mon Sep 17 00:00:00 2001 From: KalaivaniMCW Date: Mon, 3 Jun 2024 04:03:51 +0000 Subject: [PATCH 51/53] #5044: Add optional output to where op --- .../python_api_testing/sweep_tests/op_map.py | 8 ++ .../pytests/tt_dnn/test_eltwise_ternary.py | 46 +++++++++++ .../sweep_tests/pytorch_ops.py | 6 ++ .../sweep_tests/tt_lib_ops.py | 22 ++++++ .../op_library/composite/composite_ops.cpp | 76 ++++++++++++++----- .../op_library/composite/composite_ops.hpp | 12 ++- .../tt_lib_bindings_tensor_composite_ops.cpp | 20 +++-- 7 files changed, 158 insertions(+), 32 deletions(-) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py index 4d70e6b70d6f..d6bf0b1ab5f4 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py @@ -803,6 +803,14 @@ "tt_op": tt_lib_ops.where, "pytorch_op": pytorch_ops.where, }, + "eltwise-where-optional": { + "tt_op": tt_lib_ops.where_optional, + "pytorch_op": pytorch_ops.where, + }, + "eltwise-where-scalar-optional": { + "tt_op": tt_lib_ops.where_scalar_optional, + "pytorch_op": pytorch_ops.where_scalar, + }, "where-bw": { "tt_op": tt_lib_ops.where_bw, "pytorch_op": pytorch_ops.where_bw, diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py index 4ddb18dde5ce..fee9e99be3c3 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py @@ -4,6 +4,7 @@ import pytest import torch +import random from functools import partial from math import pi @@ -36,3 +37,48 @@ def test_run_eltwise_where_test(input_shapes, device, function_level_defaults): comparison_func, device, ) + + +@pytest.mark.parametrize("input_shapes", shapes) +def test_run_eltwise_where_test_optional(input_shapes, device, function_level_defaults): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_randint, low=-100, high=+100), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-5, high=+5), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-10, high=+10), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-1, high=+1), torch.float32), + ] + comparison_func = partial(comparison_funcs.comp_pcc) + run_single_pytorch_test( + "eltwise-where-optional", + [input_shapes[0], input_shapes[0], input_shapes[0], input_shapes[0]], + datagen_func, + comparison_func, + device, + ) + + +shapes_scalar = ( + [[1, 1, 32, 32], [1, 1, 32, 32]], # Single core + [[1, 1, 320, 384], [1, 1, 320, 384]], # Multi core + [[1, 3, 320, 384], [1, 3, 320, 384]], # Multi core +) + + +@pytest.mark.parametrize("input_shapes", shapes_scalar) +def test_run_eltwise_where_scalar_optional(input_shapes, device, function_level_defaults): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_randint, low=-100, high=+100), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-1, high=+1), torch.float32), + ] + test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0] + test_args.update({"scalar_true": random.uniform(0.5, 75.5), "scalar_false": random.uniform(0.5, 95.5)}) + + 
comparison_func = partial(comparison_funcs.comp_pcc) + run_single_pytorch_test( + "eltwise-where-scalar-optional", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index 6a8047855130..1b0f4c27a1a9 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -96,6 +96,12 @@ def where(x, y, z, *args, **kwargs): return torch.where(x > 0, y, z) +def where_scalar(x, *args, **kwargs): + y = kwargs.pop("scalar_true") + z = kwargs.pop("scalar_false") + return torch.where(x > 0, y, z) + + def where_bw(x, y, z, w, *args, **kwargs): grad_data = x in_data = y diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index e22d65583297..b9dac18fd1b5 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -1518,6 +1518,28 @@ def where(x, y, z, device, dtype, layout, input_mem_config, output_mem_config, * return tt2torch_tensor(t3) +@setup_host_and_device +def where_optional(x, y, z, out, device, dtype, layout, input_mem_config, output_mem_config, **kwargs): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1]) + t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2]) + t3 = setup_tt_tensor(out, device, layout[3], input_mem_config[3], dtype[3]) + ttl.tensor.where(t0, t1, t2, output_mem_config=output_mem_config, output_tensor=t3) + + return tt2torch_tensor(t3) + + +@setup_host_and_device +def where_scalar_optional( + x, out, device, dtype, layout, input_mem_config, output_mem_config, scalar_true, scalar_false, **kwargs +): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t3 = setup_tt_tensor(out, device, layout[1], input_mem_config[1], dtype[1]) + ttl.tensor.where(t0, scalar_true, scalar_false, output_mem_config=output_mem_config, output_tensor=t3) + + return tt2torch_tensor(t3) + + @setup_host_and_device def eltwise_div_unary( x, diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index 7db4638049f9..97bd34762386 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -1228,48 +1228,84 @@ Tensor _where( const Tensor& predicate, const Tensor& value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config) { + const MemoryConfig& output_mem_config, + std::optional output_tensor) { + Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config); - Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + if(output_tensor.has_value()) + { + mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } + else + { + Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); + output_tensor = add(t2, t1, std::nullopt, 
output_mem_config); + } + return output_tensor.value(); } Tensor _where_v1( - const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config) { + const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config); - Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + + if(output_tensor.has_value()){ + mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } + else + { + Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } Tensor _where_v2( - const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config) { - Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config); + const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + + if(output_tensor.has_value()){ + mul(gtz(predicate, output_mem_config), value_true, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + add(output_tensor.value(), t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } + else + { + Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config); + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } Tensor _where_v3( - const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config) { + const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config); Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + if(output_tensor.has_value()){ + add(t2, t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } else { + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } - Tensor where( const Tensor& predicate, const Tensor& value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where)(predicate, value_true, value_false, output_mem_config); + const MemoryConfig& output_mem_config, + std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where)(predicate, value_true, value_false, output_mem_config, output_tensor); } Tensor where( - const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config) { - return 
operation::decorate_as_composite(__func__, _where_v1)(predicate, value_true, value_false, output_mem_config); + const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where_v1)(predicate, value_true, value_false, output_mem_config, output_tensor); } Tensor where( - const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where_v2)(predicate, value_true, value_false, output_mem_config); + const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where_v2)(predicate, value_true, value_false, output_mem_config, output_tensor); } Tensor where( - const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where_v3)(predicate, value_true, value_false, output_mem_config); + const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where_v3)(predicate, value_true, value_false, output_mem_config, output_tensor); } // on-device tensor creation 0s like @reference_tensor diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index 0d79d22a44ea..45edd04a6aca 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -316,22 +316,26 @@ Tensor where( const Tensor& predicate, const Tensor& value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); Tensor where( const Tensor& predicate, const float value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); Tensor where( const Tensor& predicate, const Tensor& value_true, const float value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); Tensor where( const Tensor& predicate, const float value_true, const float value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); // on-device tensor creation 0s like @reference_tensor Tensor zeros_like( diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp index b3750d8cdd80..5ea5a87f8ecf 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp @@ -72,8 +72,8 @@ namespace tt::tt_metal::detail{ "output_mem_config", "Layout of tensor in TT Accelerator device memory 
banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); - m_tensor.def("where", py::overload_cast(&where), - py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + m_tensor.def("where", py::overload_cast >(&where), + py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc( Perform an ternary where operation on two tensors based on third @predicate. where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value. @@ -89,9 +89,10 @@ namespace tt::tt_metal::detail{ "true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" "false_value", "False Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + "output_tensor", "optional output tensor", "Tensor", "default is None", "No" )doc"); - m_tensor.def("where", py::overload_cast(&where), - py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + m_tensor.def("where", py::overload_cast >(&where), + py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc( Perform an ternary where operation on two tensors based on third @predicate. where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value. @@ -107,9 +108,10 @@ namespace tt::tt_metal::detail{ "true_value", "float", "float", "float scalar", "Yes" "false_value", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + "output_tensor", "optional output tensor", "Tensor", "default is None", "No" )doc"); - m_tensor.def("where", py::overload_cast(&where), - py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + m_tensor.def("where", py::overload_cast >(&where), + py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc( Perform an ternary where operation on two tensors based on third @predicate. where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value. 
@@ -125,9 +127,10 @@ namespace tt::tt_metal::detail{ "true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes" "false_value", "float", "float", "float scalar", "Yes" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + "output_tensor", "optional output tensor", "Tensor", "default is None", "No" )doc"); - m_tensor.def("where", py::overload_cast(&where), - py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + m_tensor.def("where", py::overload_cast >(&where), + py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc( Perform an ternary where operation on two tensors based on third @predicate. where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value. @@ -143,6 +146,7 @@ namespace tt::tt_metal::detail{ "true_value", "float", "float", "Tensor of shape [W, Z, Y, X]", "Yes" "false_value", "float", "float", "float scalar", "Yes" "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" + "output_tensor", "optional output tensor", "Tensor", "default is None", "No" )doc"); // *** composite unary ops *** detail::bind_unary_op(m_tensor, "normalize_hw", tt::tt_metal::normalize_hw, R"doc(Returns a new tensor with the Gaussian normalize of the elements of the input tensor ``{0}`` on H,W axes.)doc"); From 3023ec0f70b5673d6052bc92199820ed0d135495 Mon Sep 17 00:00:00 2001 From: Jack Cai Date: Tue, 4 Jun 2024 15:38:32 -0500 Subject: [PATCH 52/53] #0: enable multi-device tensor support for moreh sum op --- tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp b/tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp index f8c787a970e3..aa350191a30d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp @@ -53,7 +53,7 @@ Tensor _moreh_sum( std::optional compute_kernel_config) { std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input}))}; - TT_FATAL(input.storage_type() == StorageType::DEVICE); + TT_FATAL(input.storage_type() == StorageType::DEVICE || input.storage_type() == StorageType::MULTI_DEVICE); auto kernel_config_val = init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4); operation::launch_op( From a40994414738fe9f11270d0b95d4531cb6eac05a Mon Sep 17 00:00:00 2001 From: Stuti Raizada Date: Wed, 5 Jun 2024 09:16:41 +0000 Subject: [PATCH 53/53] #5337: dense matmul after all-gather --- .../t3000/mixtral8x7b/tt/mixtral_attention.py | 34 ++++++------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index d22af394cf0e..4b10f62a6ad6 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -90,11 +90,11 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.device_mesh, - mesh_mapper=ShardTensorToMesh(self.device_mesh, 
dim=-2), + mesh_mapper=ReplicateTensorToMesh(self.device_mesh), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_W_LAYOUT_TILE"], - cache_file_name=cache_name(f"wo_multidevice4d"), + cache_file_name=cache_name(f"wo_multidevice4d_H"), ) cache_k = torch.zeros( @@ -129,17 +129,6 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype): self.scale = self.head_dim**-0.5 - reduce_mask_torch = torch.zeros(1, 1, self.max_batch_size, self.max_batch_size * 8) - for i in range(self.max_batch_size): - reduce_mask_torch[:, :, i, range(i, self.max_batch_size * 8, self.max_batch_size)] = 1 - self.reduce_mask = ttnn.from_torch( - reduce_mask_torch, - device=self.device_mesh, - mesh_mapper=ReplicateTensorToMesh(self.device_mesh), - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - ) - self.compute_kernel = self.model_args.get_compute_kernel_config() self.compute_kernel_attn = self.model_args.get_compute_kernel_attn_config() @@ -300,16 +289,19 @@ def forward( ) attn_output_1B4D.deallocate(True) - # attn_output_11BH = ttnn.experimental.tensor.sharded_to_interleaved( - # attn_output_11BH, output_mem_config=ttnn.L1_MEMORY_CONFIG - # ) + attn_output_11BH = ttnn.experimental.tensor.sharded_to_interleaved( + attn_output_11BH, output_mem_config=ttnn.L1_MEMORY_CONFIG + ) ### # Output matmul ### + # All gather + dense_outputs_11BH_gathered = ttnn.all_gather(attn_output_11BH, dim=3, num_links=1) - dense_out_11BH = ttnn.experimental.operations.primary.matmul( - attn_output_11BH, + # return the sum of the outputs + dense_outputs_11BH = ttnn.experimental.operations.primary.matmul( + dense_outputs_11BH_gathered, wo, output_mem_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"], # compute_with_storage_grid_size=(8, 8), @@ -317,10 +309,6 @@ def forward( compute_kernel_config=self.compute_kernel, output_dtype=ttnn.bfloat8_b, ) - attn_output_11BH.deallocate(True) - # All gather - dense_outputs_11BH = ttnn.all_gather(dense_out_11BH, dim=2, num_links=1) - # return the sum of the outputs - dense_outputs_11BH = ttnn.experimental.operations.primary.matmul(self.reduce_mask, dense_outputs_11BH) + dense_outputs_11BH_gathered.deallocate(True) return dense_outputs_11BH
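
Note on the data flow in the last patch ("dense matmul after all-gather"): the width-fractured attention output is gathered across devices first (dim=3), and the dense output projection is then applied once against the `wo` weight, which this patch loads replicated on every device; the removed path instead ran a per-device matmul against fractured `wo`, gathered on dim=2, and summed the partials with a `reduce_mask` matmul. Below is a minimal sketch of the new flow using only the calls that appear in the hunk above; the wrapper function name and its argument list are illustrative assumptions and are not part of the patch.

import ttnn

# Sketch (assumed wrapper) of the post-patch flow: all-gather first, then one dense matmul.
# `attn_output_11BH` is the per-device [1, 1, B, H/num_devices] attention output;
# `wo` is the output-projection weight replicated on each device (see the wo load above).
def dense_projection_after_all_gather(attn_output_11BH, wo, model_config, compute_kernel):
    # Convert the sharded attention output to interleaved L1 before gathering.
    attn_output_11BH = ttnn.experimental.tensor.sharded_to_interleaved(
        attn_output_11BH, output_mem_config=ttnn.L1_MEMORY_CONFIG
    )

    # Gather along the hidden dimension so every device holds the full activation.
    gathered_11BH = ttnn.all_gather(attn_output_11BH, dim=3, num_links=1)

    # One dense matmul against the replicated wo; no reduce-mask summation is needed.
    dense_out_11BH = ttnn.experimental.operations.primary.matmul(
        gathered_11BH,
        wo,
        output_mem_config=model_config["LM_HEAD_OUTPUT_MEMCFG"],
        compute_kernel_config=compute_kernel,
        output_dtype=ttnn.bfloat8_b,
    )
    gathered_11BH.deallocate(True)
    return dense_out_11BH

Compared with the removed path, this keeps a single matmul per device and drops the `reduce_mask` tensor and its extra matmul entirely.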