From 0e98d33aa967b7866f2619b9461fe4b4084aebe4 Mon Sep 17 00:00:00 2001 From: asaigal Date: Thu, 11 Apr 2024 15:04:31 +0000 Subject: [PATCH] #0: Add Async-Mode support for Falcon 7B: - Uplift all falcon ops to use launch_op API - Resolve issues in async mode - Add async mode tests to multichip CI --- .../tests/multi_chip/test_falcon_attention.py | 11 +++ .../tests/multi_chip/test_falcon_causallm.py | 72 ++++++++------ .../tests/multi_chip/test_falcon_decoder.py | 11 +++ .../tests/multi_chip/test_falcon_mlp.py | 11 +++ .../run_frequent_regressions_multi_device.sh | 1 + tt_eager/tensor/tensor.cpp | 15 ++- tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp | 78 ++++++++------- ...op_multi_core_reuse_mcast_1d_optimized.cpp | 2 - .../eltwise_unary/eltwise_unary_op.cpp | 6 -- .../op_library/layernorm/layernorm_op.hpp | 95 ++++++++++--------- .../tt_dnn/op_library/nlp_tms/nlp_tms.hpp | 14 ++- .../tt_dnn/op_library/operation_history.cpp | 6 +- .../tt_dnn/op_library/operation_history.hpp | 1 + .../tt_dnn/op_library/permute/permute_op.cpp | 28 +++--- .../rotary_embedding/rotary_embedding_op.hpp | 65 +++++++------ tt_eager/tt_dnn/op_library/run_operation.hpp | 5 + .../tt_dnn/op_library/softmax/softmax_op.cpp | 36 ++++--- .../transformer_tms/transformer_tms.hpp | 13 ++- tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp | 31 +++--- .../update_cache/update_cache_op.hpp | 16 +++- tt_metal/impl/dispatch/work_executor.hpp | 4 +- tt_metal/tt_metal.cpp | 3 +- 22 files changed, 321 insertions(+), 203 deletions(-) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py index 2ab45982015..9c1a4a0c943 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py @@ -65,6 +65,10 @@ def torch_model(): ], indirect=True, ) +@pytest.mark.parametrize( + "enable_async", + [True, False], +) def test_falcon_attention( device_mesh, model_name, @@ -75,7 +79,11 @@ def test_falcon_attention( expected_pcc, model_config_str, torch_model, + enable_async, ): + for device in device_mesh.get_device_ids(): + device_mesh.get_device(device).enable_async(enable_async) + torch.manual_seed(0) batch = device_batch_size * device_mesh.get_num_devices() if llm_mode == "decode": @@ -178,3 +186,6 @@ def test_falcon_attention( assert_with_pcc( pytorch_layer_present[1].squeeze(1), tt_layer_present[1].to(pytorch_layer_present[1].dtype), expected_pcc ) + + for device in device_mesh.get_device_ids(): + device_mesh.get_device(device).enable_async(False) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py index 308597dd424..6c62fd0e9ac 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py @@ -62,6 +62,10 @@ ], indirect=True, ) +@pytest.mark.parametrize( + "enable_async, num_loops", + ((True, 20), (False, 1)), +) def test_falcon_causal_lm( device_mesh, use_program_cache, @@ -73,7 +77,12 @@ def test_falcon_causal_lm( num_layers, expected_pcc, model_config_str, + enable_async, + num_loops, ): + for device in device_mesh.get_device_ids(): + device_mesh.get_device(device).enable_async(enable_async) + torch.manual_seed(0) batch = device_batch_size * device_mesh.get_num_devices() if llm_mode == "decode": @@ -159,35 +168,41 @@ def convert_to_ttnn(model, name): ) # TODO: Generate 
embeddings and attention_mask on device if llm_mode == "prefill": - tt_outs = [] - tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing( - llm_mode, model_input, kv_cache_len, num_input_tokens=seq_len - ) - tt_out, tt_layer_present = tt_FalconCausalLM( - input_embeddings=tt_embeddings, - llm_mode=llm_mode, - attention_mask=tt_attention_mask, - user_id=0, - layer_past=tt_layer_past, - layer_past_len=kv_cache_len, - use_cache=True, - ) - tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(device_mesh, dim=shard_dim)).squeeze(1) + for loop in range(num_loops): + tt_outs = [] + tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing( + llm_mode, model_input, kv_cache_len, num_input_tokens=seq_len + ) + tt_out, tt_layer_present = tt_FalconCausalLM( + input_embeddings=tt_embeddings, + llm_mode=llm_mode, + attention_mask=tt_attention_mask, + user_id=0, + layer_past=tt_layer_past, + layer_past_len=kv_cache_len, + use_cache=True, + ) + # Explicitly move tensor to host ... in async mode this is faster than calling from torch directly, + # due to parallelization of tensor shards + tt_out = ttnn.from_device(tt_out) + tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(device_mesh, dim=shard_dim)).squeeze(1) elif llm_mode == "decode": - tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing( - llm_mode, model_input, kv_cache_len, num_input_tokens=kv_len - ) - tt_out, tt_layer_present = tt_FalconCausalLM( - input_embeddings=tt_embeddings, - llm_mode=llm_mode, - attention_mask=tt_attention_mask, - layer_past=tt_layer_past, - layer_past_len=kv_cache_len, - use_cache=True, - ) - tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(device_mesh, dim=shard_dim)).squeeze(1) - tt_out = tt_out.transpose(0, 1) + for loop in range(num_loops): + tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing( + llm_mode, model_input, kv_cache_len, num_input_tokens=kv_len + ) + tt_out, tt_layer_present = tt_FalconCausalLM( + input_embeddings=tt_embeddings, + llm_mode=llm_mode, + attention_mask=tt_attention_mask, + layer_past=tt_layer_past, + layer_past_len=kv_cache_len, + use_cache=True, + ) + tt_out = ttnn.from_device(tt_out) + tt_out = ttnn.to_torch(tt_out, mesh_composer=ConcatMeshToTensor(device_mesh, dim=shard_dim)).squeeze(1) + tt_out = tt_out.transpose(0, 1) passed, pcc = assert_with_pcc(pytorch_out, tt_out.to(pytorch_out.dtype), expected_pcc) logger.success(f"Passed: pcc: {pcc}, expected: {expected_pcc}") @@ -223,3 +238,6 @@ def convert_to_ttnn(model, name): logger.success(f"Passed: pcc: {pcc}, expected: {expected_pcc}") logger.info("Falcon CausalLM Passed!") + + for device in device_mesh.get_device_ids(): + device_mesh.get_device(device).enable_async(False) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py index 0363ce12e37..fb6dd258f07 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_decoder.py @@ -65,6 +65,10 @@ def torch_model(): ], indirect=True, ) +@pytest.mark.parametrize( + "enable_async", + [True, False], +) def test_falcon_decoder( device_mesh, model_name, @@ -75,7 +79,11 @@ def test_falcon_decoder( expected_pcc, model_config_str, torch_model, + enable_async, ): + for device in device_mesh.get_device_ids(): + device_mesh.get_device(device).enable_async(enable_async) + torch.manual_seed(0) batch = device_batch_size 
* device_mesh.get_num_devices() if llm_mode == "decode": @@ -176,3 +184,6 @@ def test_falcon_decoder( assert_with_pcc( pytorch_layer_present[1].squeeze(1), tt_layer_present[1].to(pytorch_layer_present[1].dtype), expected_pcc ) + + for device in device_mesh.get_device_ids(): + device_mesh.get_device(device).enable_async(False) diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py index 6a57c606e83..f5584a1b879 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py @@ -56,6 +56,10 @@ def torch_model(): ], indirect=True, ) +@pytest.mark.parametrize( + "enable_async", + [True, False], +) def test_falcon_mlp( device_mesh, model_name, @@ -64,7 +68,11 @@ def test_falcon_mlp( expected_pcc, model_config_str, torch_model, + enable_async, ): + for device in device_mesh.get_device_ids(): + device_mesh.get_device(device).enable_async(enable_async) + torch.manual_seed(0) configuration = transformers.FalconConfig.from_pretrained(PRETRAINED_MODEL_NAME) @@ -100,3 +108,6 @@ def test_falcon_mlp( expected_pcc, ) logger.success(f"Passed: pcc: {pcc}, expected: {expected_pcc}") + + for device in device_mesh.get_device_ids(): + device_mesh.get_device(device).enable_async(False) diff --git a/tests/scripts/run_frequent_regressions_multi_device.sh b/tests/scripts/run_frequent_regressions_multi_device.sh index 6d646736871..9af83da8c86 100755 --- a/tests/scripts/run_frequent_regressions_multi_device.sh +++ b/tests/scripts/run_frequent_regressions_multi_device.sh @@ -19,6 +19,7 @@ pytest tests/ttnn/unit_tests/test_multi_device.py pytest models/demos/ttnn_falcon7b/tests/multi_chip -k test_falcon_mlp pytest models/demos/ttnn_falcon7b/tests/multi_chip -k test_falcon_attention pytest models/demos/ttnn_falcon7b/tests/multi_chip -k test_falcon_decoder +pytest models/demos/ttnn_falcon7b/tests/multi_chip -k test_falcon_causallm # Llama2_70b related cached files and tests (the test should parse env variables similar to these) export LLAMA_CKPT_DIR=/mnt/MLPerf/tt_dnn-models/llama-2/llama-2-70b-repacked/ diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index b426eba658f..67d1953e603 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -136,10 +136,12 @@ void Tensor::deallocate(bool force) { std::visit([force, worker] (auto&& s) { using type = std::decay_t; if constexpr (std::is_same_v) { - if (force or s.buffers.at(worker->id()).use_count() == 1) { - DeallocateBuffer(*(s.buffers.at(worker->id()))); + if (s.num_buffers()) { + if (force or s.buffers.at(worker->id()).use_count() == 1) { + DeallocateBuffer(*(s.buffers.at(worker->id()))); + } + s.buffers.at(worker->id()).reset(); } - s.buffers.at(worker->id()).reset(); } }, this->tensor_attributes->storage); }); @@ -392,7 +394,12 @@ Tensor Tensor::cpu(bool blocking) const { else { host_tensor.set_populated(target_device); } - }, blocking); + }); + } + if (blocking) { + for (auto target_device : workers) { + target_device->synchronize(); + } } // Update main_thread_ref_count for tensor after pushing to queue. 
this->tensor_attributes->update_main_thread_ref_count(workers.at(0), original_tensor_ref_count);
diff --git a/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp b/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
index cf02bb55a07..476bf91d0ed 100644
--- a/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
+++ b/tt_eager/tt_dnn/op_library/bcast/bcast_op.hpp
@@ -74,42 +74,48 @@ inline Tensor bcast(
     BcastOpMath bcast_op,
     BcastOpDim bcast_dim,
     const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) {
-    using tt::constants::TILE_HEIGHT;
-    using tt::constants::TILE_WIDTH;
-
-    if (bcast_dim == BcastOpDim::W) {
-        TT_FATAL(input_tensor_a.get_legacy_shape()[2] == input_tensor_b.get_legacy_shape()[2]);
-        if (input_tensor_b.get_layout() == Layout::TILE) {
-            TT_FATAL(input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
-        } else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
-            TT_FATAL(input_tensor_b.get_legacy_shape()[3] == 1 || input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
-        } else {
-            TT_FATAL(false, "Unsupported layout");
-        }
-    } else if (bcast_dim == BcastOpDim::H) {
-        TT_FATAL(input_tensor_a.get_legacy_shape()[3] == input_tensor_b.get_legacy_shape()[3]);
-        if (input_tensor_b.get_layout() == Layout::TILE) {
-            TT_FATAL(input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT);
-        } else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
-            TT_FATAL(input_tensor_b.get_legacy_shape()[2] == 1 || input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT);
-        } else {
-            TT_FATAL(false, "Unsupported layout");
-        }
-    } else if (bcast_dim == BcastOpDim::HW) {
-        if (input_tensor_b.get_layout() == Layout::TILE) {
-            TT_FATAL(
-                input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT &&
-                input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
-        } else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
-            TT_FATAL(
-                (input_tensor_b.get_legacy_shape()[2] == 1 && input_tensor_b.get_legacy_shape()[3] == 1) ||
-                (input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT &&
-                input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH));
-        }
-    }
-    return operation::run_with_autoformat(
-        EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config, false}, {input_tensor_a, input_tensor_b})
-        .at(0);
+
+    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))};
+    operation::launch_with_autoformat(
+        [bcast_op, bcast_dim, output_mem_config] (const std::vector<Tensor>& input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
+            using tt::constants::TILE_HEIGHT;
+            using tt::constants::TILE_WIDTH;
+            auto& input_tensor_a = input_tensors.at(0);
+            auto& input_tensor_b = input_tensors.at(1);
+            if (bcast_dim == BcastOpDim::W) {
+                TT_FATAL(input_tensor_a.get_legacy_shape()[2] == input_tensor_b.get_legacy_shape()[2]);
+                if (input_tensor_b.get_layout() == Layout::TILE) {
+                    TT_FATAL(input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
+                } else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
+                    TT_FATAL(input_tensor_b.get_legacy_shape()[3] == 1 || input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH);
+                } else {
+                    TT_FATAL(false, "Unsupported layout");
+                }
+            } else if (bcast_dim == BcastOpDim::H) {
+                TT_FATAL(input_tensor_a.get_legacy_shape()[3] == input_tensor_b.get_legacy_shape()[3]);
+                if (input_tensor_b.get_layout() == Layout::TILE) {
+                    TT_FATAL(input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT);
+                } else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) {
+                    TT_FATAL(input_tensor_b.get_legacy_shape()[2] == 1 || input_tensor_b.get_legacy_shape()[2] == 
TILE_HEIGHT); + } else { + TT_FATAL(false, "Unsupported layout"); + } + } else if (bcast_dim == BcastOpDim::HW) { + if (input_tensor_b.get_layout() == Layout::TILE) { + TT_FATAL( + input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT && + input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH); + } else if (input_tensor_b.get_layout() == Layout::ROW_MAJOR) { + TT_FATAL( + (input_tensor_b.get_legacy_shape()[2] == 1 && input_tensor_b.get_legacy_shape()[3] == 1) || + (input_tensor_b.get_legacy_shape()[2] == TILE_HEIGHT && + input_tensor_b.get_legacy_shape()[3] == TILE_WIDTH)); + } + } + return operation::run_with_autoformat( + EltwiseBinaryBroadcast{bcast_op, bcast_dim, output_mem_config}, {input_tensor_a, input_tensor_b}); + }, {input_tensor_a, input_tensor_b}, output_tensors); + return output_tensors.at(0); } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp index 6a1ad696c04..03fb0eb6f44 100644 --- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_1d_optimized/bmm_op_multi_core_reuse_mcast_1d_optimized.cpp @@ -291,7 +291,6 @@ operation::ProgramWithCallbacks create_program_mcast_in0( all_cores, tt_metal::DataMovementConfig{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = in1_noc, .compile_args = in1_sender_writer_compile_time_args, .defines = mm_kernel_in1_sender_writer_defines}); - KernelHandle mm_kernel_in0_receiver_id = 0; if (!in0_is_sharded) { mm_kernel_in0_receiver_id = tt_metal::CreateKernel( @@ -872,7 +871,6 @@ operation::ProgramWithCallbacks create_program_mcast_in1( mcast_sender, tt_metal::DataMovementConfig{.processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = in1_noc, .compile_args = in1_sender_writer_compile_time_args, .defines = mm_kernel_in1_sender_writer_defines}); - auto mm_kernel_in1_receiver_writer_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_receiver_writer_padding.cpp", diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index f24b3f4ddeb..2449384cd66 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -339,12 +339,6 @@ const operation::Hash EltwiseUnary::compute_program_hash(const std::vector Tensor tie_binop_to_unary(const Tensor& input_tensor, float value, const MemoryConfig& output_mem_config) { - if (is_multi_device_tensor(input_tensor)) { - return transform(input_tensor, [&](const Tensor& tensor) { - return tie_binop_to_unary(tensor, value, output_mem_config); - }); - } - Tensor t_value = mk_tiled_scalar(value); return bcast(input_tensor, t_value, OP, BcastOpDim::HW); } diff --git a/tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp b/tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp index a7d010925bd..a4105ad05b7 100644 --- a/tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp +++ b/tt_eager/tt_dnn/op_library/layernorm/layernorm_op.hpp @@ -95,38 +95,44 @@ template struct make_layernorm { Tensor operator()( const Tensor &a, float eps, std::optional gamma = std::nullopt, std::optional beta = std::nullopt, const MemoryConfig& mem_config = 
operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::optional compute_kernel_config = std::nullopt) const { - TT_FATAL(a.get_legacy_shape()[-1] % TILE_WIDTH == 0, "Normalizing on last dim cannot be padded"); - - if (gamma.has_value() and gamma.value().get_layout() == Layout::TILE) { - TT_FATAL( - gamma.value().get_legacy_shape()[-1] == a.get_legacy_shape()[-1], - "Gamma width must be equal to input width"); - } - if (beta.has_value() and beta.value().get_layout() == Layout::TILE) { - TT_FATAL( - beta.value().get_legacy_shape()[-1] == a.get_legacy_shape()[-1], - "Beta width must be equal to input width"); - } - - auto original_shape = a.get_shape(); - auto a_4D = ttnn::unsqueeze_to_4D(a); - std::optional gamma_4D = gamma.has_value() ? ttnn::unsqueeze_to_4D(gamma.value()) : gamma; - std::optional beta_4D = beta.has_value() ? ttnn::unsqueeze_to_4D(beta.value()) : beta; - - auto arch = - a.storage_type() == StorageType::DEVICE ? a.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); - auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, true, false, false); - auto output = operation::run_with_autoformat( - LayerNorm{ - .norm_type = norm_type, - .eps = eps, - .output_mem_config = mem_config, - .program_config = LayerNormDefaultProgramConfig(), - .compute_kernel_config = kernel_config_val}, - {a_4D}, - {std::nullopt, gamma_4D, beta_4D}) - .at(0); - return ttnn::reshape(output, original_shape); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a}))}; + operation::launch_with_autoformat( + [eps, gamma, beta, mem_config, compute_kernel_config] (std::vector input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + auto& a = input_tensors.at(0); + TT_FATAL(a.get_legacy_shape()[-1] % TILE_WIDTH == 0, "Normalizing on last dim cannot be padded"); + + if (gamma.has_value() and gamma.value().get_layout() == Layout::TILE) { + TT_FATAL( + gamma.value().get_legacy_shape()[-1] == a.get_legacy_shape()[-1], + "Gamma width must be equal to input width"); + } + if (beta.has_value() and beta.value().get_layout() == Layout::TILE) { + TT_FATAL( + beta.value().get_legacy_shape()[-1] == a.get_legacy_shape()[-1], + "Beta width must be equal to input width"); + } + + auto original_shape = a.get_shape(); + auto a_4D = ttnn::unsqueeze_to_4D(a); + std::optional gamma_4D = gamma.has_value() ? ttnn::unsqueeze_to_4D(gamma.value()) : gamma; + std::optional beta_4D = beta.has_value() ? ttnn::unsqueeze_to_4D(beta.value()) : beta; + + auto arch = + a.storage_type() == StorageType::DEVICE ? 
a.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); + auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, true, false, false); + auto output = operation::run_with_autoformat( + LayerNorm{ + .norm_type = norm_type, + .eps = eps, + .output_mem_config = mem_config, + .program_config = LayerNormDefaultProgramConfig(), + .compute_kernel_config = kernel_config_val}, + {a_4D}, + {std::nullopt, gamma_4D, beta_4D}) + .at(0); + return {ttnn::reshape(output, original_shape)}; + }, {a}, output_tensors); + return output_tensors.at(0); } }; @@ -210,17 +216,15 @@ struct make_layernorm { template struct make_add_layernorm { Tensor operator()( - const Tensor& a, - const Tensor& b, - float eps, - std::optional gamma = std::nullopt, - std::optional beta = std::nullopt, - const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - const LayerNormProgramConfig& program_config = LayerNormDefaultProgramConfig{}, - std::optional compute_kernel_config = std::nullopt) const { - auto arch = a.storage_type() == StorageType::DEVICE ? a.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); - auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, true, false, false); - return operation::run( + const Tensor &a, const Tensor& b, float eps, std::optional gamma = std::nullopt, std::optional beta = std::nullopt, const MemoryConfig& mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, const LayerNormProgramConfig& program_config = LayerNormDefaultProgramConfig{}, std::optional compute_kernel_config = std::nullopt) const { + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a, b}))}; + operation::launch_op( + [eps, gamma, beta, mem_config, program_config, compute_kernel_config] (std::vector input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + const auto& a = input_tensors.at(0); + const auto& b = input_tensors.at(1); + auto arch = a.storage_type() == StorageType::DEVICE ? 
a.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); + auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, true, false, false); + return operation::run( LayerNorm{ .norm_type = layernorm_type, .eps = eps, @@ -228,8 +232,9 @@ struct make_add_layernorm { .program_config = program_config, .compute_kernel_config = kernel_config_val}, {a}, - {b, gamma, beta}) - .at(0); + {b, gamma, beta}); + }, {a, b}, output_tensors); + return output_tensors.at(0); } }; diff --git a/tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.hpp b/tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.hpp index db1b32cea3f..89734cdd7fb 100644 --- a/tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.hpp +++ b/tt_eager/tt_dnn/op_library/nlp_tms/nlp_tms.hpp @@ -80,7 +80,12 @@ struct NlpConcatHeads { inline std::vector nlp_create_qkv_heads_falcon7b(const Tensor &input_tensor_a, const MemoryConfig& mem_config) { // TODO: hard-coded for falcon-7b; can delete if we switch to the more generic one (but perf may be worse) - return operation::run(NlpCreateHeadsFalcon7B{mem_config}, {input_tensor_a}); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a})), Tensor(operation::get_workers_for_op_output({input_tensor_a})), Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; + operation::launch_op( + [mem_config] (std::vector input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + return operation::run(NlpCreateHeadsFalcon7B{mem_config}, input_tensors); + }, {input_tensor_a}, output_tensors); + return output_tensors; } inline std::vector nlp_create_qkv_heads( const Tensor &input_tensor, std::optional input_tensor_kv, @@ -105,7 +110,12 @@ inline std::vector nlp_create_qkv_heads( return operation::run(NlpCreateHeads{num_heads, num_kv_heads_val, head_dim, transpose_k_heads, mem_config}, {input_tensor}, {input_tensor_kv}); } inline Tensor nlp_concat_heads(const Tensor &input_tensor_a, const MemoryConfig& mem_config) { - return operation::run(NlpConcatHeads{mem_config}, {input_tensor_a}).at(0); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; + operation::launch_op( + [mem_config] (std::vector input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + return operation::run(NlpConcatHeads{mem_config}, input_tensors); + }, {input_tensor_a}, output_tensors); + return output_tensors.at(0); } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/operation_history.cpp b/tt_eager/tt_dnn/op_library/operation_history.cpp index 4a64cb60a0f..f5b5086fa91 100644 --- a/tt_eager/tt_dnn/op_library/operation_history.cpp +++ b/tt_eager/tt_dnn/op_library/operation_history.cpp @@ -20,6 +20,7 @@ OperationHistory::~OperationHistory() { } void OperationHistory::append(OperationRecord&& record) { + std::scoped_lock lock(op_history_mutex); TT_ASSERT(record.input_tensor_records.size() <= 5); this->records.push_back(std::move(record)); } @@ -126,7 +127,10 @@ void OperationHistory::dump_to_csv() { } } -void OperationHistory::clear() { this->records.clear(); } +void OperationHistory::clear() { + std::scoped_lock lock(op_history_mutex); + this->records.clear(); +} } // namespace detail diff --git a/tt_eager/tt_dnn/op_library/operation_history.hpp b/tt_eager/tt_dnn/op_library/operation_history.hpp index 81ea0de8d77..be338f507bf 100644 --- a/tt_eager/tt_dnn/op_library/operation_history.hpp +++ b/tt_eager/tt_dnn/op_library/operation_history.hpp @@ -50,6 +50,7 @@ struct 
OperationHistory { void clear(); private: + std::mutex op_history_mutex; std::vector records; }; diff --git a/tt_eager/tt_dnn/op_library/permute/permute_op.cpp b/tt_eager/tt_dnn/op_library/permute/permute_op.cpp index 6e4730bb8dc..1409d0f9978 100644 --- a/tt_eager/tt_dnn/op_library/permute/permute_op.cpp +++ b/tt_eager/tt_dnn/op_library/permute/permute_op.cpp @@ -105,20 +105,20 @@ Tensor permute_(const Tensor &a, std::vector dims, const MemoryConfig& } Tensor permute(const Tensor &a, std::vector dims, const MemoryConfig& output_mem_config) { - if (is_multi_device_tensor(a)) { - return transform(a, [&](const Tensor& tensor) { - return permute(tensor, dims, output_mem_config); - }); - } - - std::vector normalized_dims(dims.size()); - std::transform(dims.begin(), dims.end(), normalized_dims.begin(), [a](std::int64_t idx) {return a.get_legacy_shape().get_normalized_index(idx);}); - std::vector seq_dims(dims.size()); - std::iota(seq_dims.begin(), seq_dims.end(), 0); - if (normalized_dims == seq_dims) { - return AutoFormat::move_tensor_to_mem_config(a, output_mem_config); - } - return operation::decorate_as_composite(__func__, permute_)(a, normalized_dims, output_mem_config); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({a}))}; + operation::launch_with_autoformat( + [dims, output_mem_config] (std::vector input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + auto& a = input_tensors.at(0); + std::vector normalized_dims(dims.size()); + std::transform(dims.begin(), dims.end(), normalized_dims.begin(), [a](std::int64_t idx) {return a.get_legacy_shape().get_normalized_index(idx);}); + std::vector seq_dims(dims.size()); + std::iota(seq_dims.begin(), seq_dims.end(), 0); + if (normalized_dims == seq_dims) { + return {AutoFormat::move_tensor_to_mem_config(a, output_mem_config)}; + } + return {operation::decorate_as_composite(__func__, permute_)(a, normalized_dims, output_mem_config)}; + }, {a}, output_tensors); + return output_tensors.at(0); } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp b/tt_eager/tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp index fb04db3d772..674c7db8638 100644 --- a/tt_eager/tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp +++ b/tt_eager/tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp @@ -50,35 +50,42 @@ inline Tensor rotary_embedding( std::optional token_idx = std::nullopt, const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::optional compute_kernel_config = std::nullopt) { - TT_FATAL(input_tensor.get_legacy_shape()[-1] % (TILE_WIDTH * 2) == 0, "Input X dim must be divisible into tiles"); - uint32_t seq_len = input_tensor.get_legacy_shape()[-2]; - uint32_t B = input_tensor.get_legacy_shape()[0]; - uint32_t X = input_tensor.get_legacy_shape()[-1]; - TT_FATAL(cos.get_legacy_shape() == sin.get_legacy_shape(), "Cos and Sin dims must match"); - TT_FATAL(cos.get_legacy_shape()[0] == 1 && cos.get_legacy_shape()[1] == 1 && cos.get_legacy_shape()[-1] == X, "Cos dims must match input dims"); - if (token_idx.has_value()) { - seq_len = input_tensor.get_legacy_shape()[0]; - TT_FATAL(seq_len == 1); - TT_FATAL(cos.get_legacy_shape()[-2] >= token_idx, "Cos dims must match input dims"); - } else { - TT_FATAL(cos.get_legacy_shape()[-2] >= seq_len, "Cos dims must match input dims"); - } - - auto arch = input_tensor.storage_type() == StorageType::DEVICE ? 
input_tensor.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); - auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, true, false, false); - - Shape input_pad_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape()); - FormatParams input_format_params = {.pad_shape = input_pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE}; - Shape cos_pad_shape = AutoFormat::pad_to_tile_shape(cos.get_legacy_shape()); - FormatParams cos_format_params = {.pad_shape = cos_pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE}; - Shape sin_pad_shape = AutoFormat::pad_to_tile_shape(sin.get_legacy_shape()); - FormatParams sin_format_params = {.pad_shape = sin_pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE}; - return operation::run_with_autoformat( - RotaryEmbedding{seq_len, token_idx, output_mem_config, kernel_config_val}, - {input_tensor, cos, sin}, - {input_format_params, cos_format_params, sin_format_params}, - {Layout::TILE}) - .at(0); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor, cos, sin}))}; + operation::launch_with_autoformat( + [token_idx, output_mem_config, compute_kernel_config] (std::vector input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + auto& input_tensor = input_tensors.at(0); + auto& cos = input_tensors.at(1); + auto& sin = input_tensors.at(2); + TT_FATAL(input_tensor.get_legacy_shape()[-1] % (TILE_WIDTH * 2) == 0, "Input X dim must be divisible into tiles"); + uint32_t seq_len = input_tensor.get_legacy_shape()[-2]; + uint32_t B = input_tensor.get_legacy_shape()[0]; + uint32_t X = input_tensor.get_legacy_shape()[-1]; + TT_FATAL(cos.get_legacy_shape() == sin.get_legacy_shape(), "Cos and Sin dims must match"); + TT_FATAL(cos.get_legacy_shape()[0] == 1 && cos.get_legacy_shape()[1] == 1 && cos.get_legacy_shape()[-1] == X, "Cos dims must match input dims"); + if (token_idx.has_value()) { + seq_len = input_tensor.get_legacy_shape()[0]; + TT_FATAL(seq_len == 1); + TT_FATAL(cos.get_legacy_shape()[-2] >= token_idx, "Cos dims must match input dims"); + } else { + TT_FATAL(cos.get_legacy_shape()[-2] >= seq_len, "Cos dims must match input dims"); + } + + auto arch = input_tensor.storage_type() == StorageType::DEVICE ? 
input_tensor.device()->arch() : AutoFormat::GetDefaultDevice()->arch(); + auto kernel_config_val = init_device_compute_kernel_config(arch, compute_kernel_config, MathFidelity::HiFi4, true, false, false); + + Shape input_pad_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape()); + FormatParams input_format_params = {.pad_shape = input_pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE}; + Shape cos_pad_shape = AutoFormat::pad_to_tile_shape(cos.get_legacy_shape()); + FormatParams cos_format_params = {.pad_shape = cos_pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE}; + Shape sin_pad_shape = AutoFormat::pad_to_tile_shape(sin.get_legacy_shape()); + FormatParams sin_format_params = {.pad_shape = sin_pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE}; + return operation::run_with_autoformat( + RotaryEmbedding{seq_len, token_idx, output_mem_config, kernel_config_val}, + {input_tensor, cos, sin}, + {input_format_params, cos_format_params, sin_format_params}, + {Layout::TILE}); + }, {input_tensor, cos, sin}, output_tensors); + return output_tensors.at(0); } } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/run_operation.hpp b/tt_eager/tt_dnn/op_library/run_operation.hpp index 71395a0cbfc..df0a8a32594 100644 --- a/tt_eager/tt_dnn/op_library/run_operation.hpp +++ b/tt_eager/tt_dnn/op_library/run_operation.hpp @@ -49,22 +49,27 @@ struct RunOperationState { RunOperationState() {} void push_composite_parent_name(const char* parent_name) { + std::scoped_lock lock(parent_name_mutex); this->composite_parent_names.push_back(parent_name); } void pop_composite_parent_name() { + std::scoped_lock lock(parent_name_mutex); this->composite_parent_names.pop_back(); } bool is_composite_operation() const { + std::scoped_lock lock(parent_name_mutex); return not composite_parent_names.empty(); } const auto& get_composite_parent_names() const { + std::scoped_lock lock(parent_name_mutex); return this->composite_parent_names; } private: + mutable std::mutex parent_name_mutex; std::vector composite_parent_names{}; }; diff --git a/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp b/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp index 093509166e2..30e8425eeb3 100644 --- a/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp +++ b/tt_eager/tt_dnn/op_library/softmax/softmax_op.cpp @@ -205,21 +205,27 @@ Tensor softmax(const Tensor& input_tensor, const MemoryConfig& output_mem_config namespace transformers { Tensor scale_mask_softmax(const Tensor& input_tensor, std::optional scale, std::optional mask, const MemoryConfig& output_mem_config, const bool is_causal_mask, std::optional compute_kernel_config) { - Shape input_pad_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape()); - FormatParams input_format_params = {.pad_shape=input_pad_shape, .pad_value=-std::numeric_limits::infinity(), .target_layout=Layout::TILE}; - std::optional mask_format_params = std::nullopt; - if (mask.has_value()) { - TT_FATAL(input_tensor.get_legacy_shape()[-1] == mask.value().get_legacy_shape()[-1]); - TT_FATAL(input_tensor.get_legacy_shape()[0] == mask.value().get_legacy_shape()[0]); - TT_FATAL(mask.value().get_legacy_shape()[-2] == 1 or mask.value().get_legacy_shape()[-2] == TILE_HEIGHT); - for (uint32_t i = 1; i < input_tensor.get_legacy_shape().rank() - 2; i++) { - TT_FATAL(mask.value().get_legacy_shape()[i] == 1); - } - Shape mask_pad_shape = AutoFormat::pad_to_tile_shape(mask.value().get_legacy_shape()); - mask_format_params = {.pad_shape=mask_pad_shape, 
.pad_value=-std::numeric_limits::infinity(), .target_layout=Layout::TILE}; - } - auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false); - return operation::run_with_autoformat(tt::operations::primary::Softmax{.scale=scale, .inplace=false, .output_mem_config=output_mem_config, .is_causal_mask=is_causal_mask, .compute_kernel_config=kernel_config_val, .is_scale_causal_mask_hw_dims_softmax=false}, {input_tensor}, {input_format_params}, {Layout::TILE}, {mask}, {mask_format_params}).at(0); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + operation::launch_with_autoformat( + [scale, mask, output_mem_config, is_causal_mask, compute_kernel_config] (std::vector input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + auto& input_tensor = input_tensors.at(0); + Shape input_pad_shape = AutoFormat::pad_to_tile_shape(input_tensor.get_legacy_shape()); + FormatParams input_format_params = {.pad_shape=input_pad_shape, .pad_value=-std::numeric_limits::infinity(), .target_layout=Layout::TILE}; + std::optional mask_format_params = std::nullopt; + if (mask.has_value()) { + TT_FATAL(input_tensor.get_legacy_shape()[-1] == mask.value().get_legacy_shape()[-1]); + TT_FATAL(input_tensor.get_legacy_shape()[0] == mask.value().get_legacy_shape()[0]); + TT_FATAL(mask.value().get_legacy_shape()[-2] == 1 or mask.value().get_legacy_shape()[-2] == TILE_HEIGHT); + for (uint32_t i = 1; i < input_tensor.get_legacy_shape().rank() - 2; i++) { + TT_FATAL(mask.value().get_legacy_shape()[i] == 1); + } + Shape mask_pad_shape = AutoFormat::pad_to_tile_shape(mask.value().get_legacy_shape()); + mask_format_params = {.pad_shape=mask_pad_shape, .pad_value=-std::numeric_limits::infinity(), .target_layout=Layout::TILE}; + } + auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4, true, false, false); + return operation::run_with_autoformat(tt::operations::primary::Softmax{.scale=scale, .inplace=false, .output_mem_config=output_mem_config, .is_causal_mask=is_causal_mask, .compute_kernel_config=kernel_config_val}, {input_tensor}, {input_format_params}, {Layout::TILE}, {mask}, {mask_format_params}); + }, {input_tensor}, output_tensors); + return output_tensors.at(0); } } // namespace transformers } // namespace tt_metal diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp index e3be4f8104d..51b08327375 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.hpp @@ -40,7 +40,11 @@ struct SplitFusedQKVAndSplitHeads { }; inline std::tuple split_query_key_value_and_split_heads(const Tensor &input_tensor, const CoreCoord& compute_with_storage_grid_size, const MemoryConfig& mem_config, const uint32_t num_heads = 16) { - auto output_tensors = operation::run(SplitFusedQKVAndSplitHeads{compute_with_storage_grid_size, mem_config, num_heads}, {input_tensor}); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor})), Tensor(operation::get_workers_for_op_output({input_tensor})), Tensor(operation::get_workers_for_op_output({input_tensor}))}; + operation::launch_op( + [compute_with_storage_grid_size, mem_config, num_heads] (std::vector input_tensors, const std::vector>& optional_input_tensors) mutable 
-> std::vector { + return operation::run(SplitFusedQKVAndSplitHeads{compute_with_storage_grid_size, mem_config, num_heads}, input_tensors); + }, {input_tensor}, output_tensors); return {output_tensors.at(0), output_tensors.at(1), output_tensors.at(2)}; } @@ -56,7 +60,12 @@ struct ConcatenateHeads { }; inline Tensor concatenate_heads(const Tensor &input_tensor, const CoreCoord& compute_with_storage_grid_size, const MemoryConfig& mem_config) { - return operation::run(ConcatenateHeads{compute_with_storage_grid_size, mem_config}, {input_tensor}).at(0); + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; + operation::launch_op( + [compute_with_storage_grid_size, mem_config] (std::vector input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + return operation::run(ConcatenateHeads{compute_with_storage_grid_size, mem_config}, input_tensors); + }, {input_tensor}, output_tensors); + return output_tensors.at(0); } struct AttnMatmul { diff --git a/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp b/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp index d81d97d16ad..665c756b25f 100644 --- a/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp +++ b/tt_eager/tt_dnn/op_library/unpad/unpad_op.cpp @@ -209,19 +209,24 @@ const operation::Hash Unpad::compute_program_hash ( Tensor unpad(const Tensor &input_tensor_a, const Shape &output_tensor_start, const Shape &output_tensor_end, const MemoryConfig& output_mem_config) { // No-op (Will do a tensor copy) // TODO: We need to run asserts before this - auto input_tensor_shape = input_tensor_a.get_legacy_shape(); - const Shape output_tensor_shape = { - output_tensor_end[0] - output_tensor_start[0] + 1, - output_tensor_end[1] - output_tensor_start[1] + 1, - output_tensor_end[2] - output_tensor_start[2] + 1, - output_tensor_end[3] - output_tensor_start[3] + 1, - }; - if (input_tensor_a.get_legacy_shape() == output_tensor_shape) { - return AutoFormat::move_tensor_to_mem_config(input_tensor_a, output_mem_config); - } - - return operation::run(Unpad{output_tensor_start, output_tensor_end, output_mem_config, output_tensor_shape, input_tensor_shape}, {input_tensor_a}).at(0); - + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a}))}; + operation::launch_op( + [output_tensor_start, output_tensor_end, output_mem_config] (const std::vector& input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + auto& input_tensor_a = input_tensors.at(0); + auto input_tensor_shape = input_tensor_a.get_legacy_shape(); + const Shape output_tensor_shape = { + output_tensor_end[0] - output_tensor_start[0] + 1, + output_tensor_end[1] - output_tensor_start[1] + 1, + output_tensor_end[2] - output_tensor_start[2] + 1, + output_tensor_end[3] - output_tensor_start[3] + 1, + }; + if (input_tensor_a.get_legacy_shape() == output_tensor_shape) { + return {AutoFormat::move_tensor_to_mem_config(input_tensor_a, output_mem_config)}; + } + return operation::run(Unpad{output_tensor_start, output_tensor_end, output_mem_config, output_tensor_shape, input_tensor_shape}, {input_tensor_a}); + }, + {input_tensor_a}, output_tensors); + return output_tensors.at(0); } void UnpadOnHost::validate(const std::vector &input_tensors) const { diff --git a/tt_eager/tt_dnn/op_library/update_cache/update_cache_op.hpp b/tt_eager/tt_dnn/op_library/update_cache/update_cache_op.hpp index ba52efb62bb..fcd99d9620c 100644 --- a/tt_eager/tt_dnn/op_library/update_cache/update_cache_op.hpp +++ 
b/tt_eager/tt_dnn/op_library/update_cache/update_cache_op.hpp
@@ -55,13 +55,23 @@ struct UpdateCache {
 };
 
 inline Tensor fill_cache(const Tensor& cache_tensor, const Tensor& input_tensor, const uint32_t batch_idx) {
-    operation::run(UpdateCache{batch_idx, 0, 0, UpdateCacheOpType::FILL}, {cache_tensor, input_tensor});
+    std::vector<Tensor> dummy_output_tensors = {Tensor(operation::get_workers_for_op_output({cache_tensor, input_tensor}))};
+    operation::launch_op(
+        [batch_idx] (std::vector<Tensor> input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
+            return operation::run(UpdateCache{batch_idx, 0, 0, UpdateCacheOpType::FILL}, input_tensors);
+        }, {cache_tensor, input_tensor}, dummy_output_tensors);
     return cache_tensor;
 }
 
 inline Tensor update_cache(const Tensor& cache_tensor, const Tensor& input_tensor, const uint32_t update_idx, const uint32_t batch_offset, std::optional<const DeviceComputeKernelConfig> compute_kernel_config = std::nullopt) {
-    auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config);
-    operation::run(UpdateCache{0, update_idx, batch_offset, UpdateCacheOpType::UPDATE, kernel_config_val}, {cache_tensor, input_tensor});
+    std::vector<Tensor> dummy_output_tensors = {Tensor(operation::get_workers_for_op_output({cache_tensor, input_tensor}))};
+    operation::launch_op(
+        [update_idx, batch_offset, compute_kernel_config] (std::vector<Tensor> input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
+            auto& cache_tensor = input_tensors.at(0);
+            auto& input_tensor = input_tensors.at(1);
+            auto kernel_config_val = init_device_compute_kernel_config(input_tensor.device()->arch(), compute_kernel_config);
+            return operation::run(UpdateCache{0, update_idx, batch_offset, UpdateCacheOpType::UPDATE, kernel_config_val}, {cache_tensor, input_tensor});
+        }, {cache_tensor, input_tensor}, dummy_output_tensors);
     return cache_tensor;
 }
 
diff --git a/tt_metal/impl/dispatch/work_executor.hpp b/tt_metal/impl/dispatch/work_executor.hpp
index c347aceacc3..e3ff8ab2923 100644
--- a/tt_metal/impl/dispatch/work_executor.hpp
+++ b/tt_metal/impl/dispatch/work_executor.hpp
@@ -80,8 +80,8 @@ class WorkExecutor {
     }
 
     inline void synchronize() {
-        if (this->worker_queue_mode == WorkExecutorMode::ASYNCHRONOUS) {
-            // Blocking = wait for queue flushed
+        if (this->worker_queue_mode == WorkExecutorMode::ASYNCHRONOUS and std::hash<std::thread::id>{}(std::this_thread::get_id()) == worker_queue.parent_thread_id.load()) {
+            // Blocking = wait for queue flushed. Only the main thread can explicitly insert a synchronize, otherwise we have a deadlock.
             this->worker_queue.push([](){}); // Send flush command (i.e. empty function)
             // Wait for queue empty, i.e. flush command picked up
             while(not this->worker_queue.empty()) {
diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp
index de1cb436eeb..40dffafcee2 100644
--- a/tt_metal/tt_metal.cpp
+++ b/tt_metal/tt_metal.cpp
@@ -44,8 +44,7 @@ void ConfigureKernelGroup(const Program &program, const KernelGroup *kernel_grou
 std::optional<uint32_t> get_semaphore_address(const Program &program, const CoreRange &core_range) {
     std::optional<uint32_t> address = nullopt;
-    static std::vector<uint32_t> semaphore_histogram(NUM_SEMAPHORES, 0);
-    std::fill(semaphore_histogram.begin(), semaphore_histogram.end(), 0);
+    std::vector<uint32_t> semaphore_histogram(NUM_SEMAPHORES, 0);
     for (auto x = core_range.start.x; x <= core_range.end.x; x++) {
         for (auto y = core_range.start.y; y <= core_range.end.y; y++) {
             CoreCoord logical_core(x, y);
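
The change repeated across the op headers in this patch is mechanical: each synchronous operation::run(...).at(0) call is wrapped in operation::launch_op (or operation::launch_with_autoformat for ops that need autoformatting), with output tensors pre-bound to the worker devices of the inputs so the host thread can return a handle immediately when async mode is enabled. The sketch below illustrates that pattern in isolation. It is not part of the patch: ExampleOp and example_op are hypothetical placeholders, and the launch_op signature is assumed to be the one implied by the calls visible in the hunks above.

// Illustrative sketch only: the launch_op wrapping pattern applied throughout
// this change, under the assumption that operation::launch_op takes a work
// lambda, the input tensors, and pre-allocated output tensors. ExampleOp and
// example_op are hypothetical names, not ops from this patch.
inline Tensor example_op(const Tensor& input_tensor, const MemoryConfig& mem_config) {
    // Allocate an empty output handle bound to the devices (workers) that own
    // the input, so the caller gets a tensor back before the op has executed.
    std::vector<Tensor> output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))};
    operation::launch_op(
        // Op attributes are captured by value; the lambda runs on the worker
        // thread(s) when async mode is on, or inline when it is off.
        [mem_config] (std::vector<Tensor> input_tensors, const std::vector<std::optional<const Tensor>>& optional_input_tensors) mutable -> std::vector<Tensor> {
            return operation::run(ExampleOp{mem_config}, input_tensors);
        }, {input_tensor}, output_tensors);
    return output_tensors.at(0);
}

Ops with no tensor output (fill_cache, update_cache) follow the same shape but pass dummy output tensors, and the tests gate async execution by calling enable_async(True) on every device in the mesh before running and enable_async(False) afterwards.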