diff --git a/models/demos/falcon7b/demo/demo.py b/models/demos/falcon7b/demo/demo.py index 5943021a7e86..edf081573d66 100644 --- a/models/demos/falcon7b/demo/demo.py +++ b/models/demos/falcon7b/demo/demo.py @@ -268,7 +268,7 @@ def run_falcon_demo_kv( tt_prefill_attention_mask[i].deallocate() tt_logits[i].deallocate() - # exit(0) + synchronize_devices(devices) logger.info("Finished 1st run prefill stage with compile!") @@ -409,60 +409,56 @@ def run_falcon_demo_kv( else: N = 15 N_warmup = 5 - for i in range(1): - kv_cache_len = num_input_tokens # This will increment by one after each decode - for output_token_index in range(N): - time_decode_inference_start = time.time() - ( - tt_decode_embeddings, - tt_decode_attention_mask, - ) = tt_FalconCausalLM.model_preprocessing( - "decode", decode_ids, kv_cache_len, num_input_tokens=kv_cache_len + 1 - ) - assert tt_decode_attention_mask is not None - - tt_logits, kv_cache = tt_FalconCausalLM( - input_embeddings=tt_decode_embeddings, - llm_mode="decode", - attention_mask=tt_decode_attention_mask, - layer_past=kv_cache, - layer_past_len=kv_cache_len, - use_cache=use_cache, - ) - synchronize_devices(devices) - time_decode_inference_end = time.time() - if output_token_index >= N_warmup: - time_decode_inference += time_decode_inference_end - time_decode_inference_start - - logits = torch.concat([tt2torch_tensor(tt_logits[i]).squeeze(1) for i in range(num_devices)], dim=-2) - - for i in range(num_devices): - tt_decode_embeddings[i].deallocate() - if tt_decode_attention_mask is not None: - tt_decode_attention_mask[i].deallocate() - tt_logits[i].deallocate() - - if not perf_mode: - if greedy_sampling: - decode_ids = post_processor(logits=logits, index=...).reshape(global_batch, 1) - else: - decode_ids = top_pk_logits_efficient(logits.reshape(global_batch, -1)).reshape(global_batch, 1) - - for user_id, user_decode_id in enumerate(decode_ids[:num_users]): - if user_decode_id == END_OF_TEXT: - prompt_is_done[user_id] = True - if prompt_is_done[user_id]: - decode_ids[user_id] = SPACE - - if all(prompt_is_done): - break - - generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1) - kv_cache_len += 1 - - # TODO: Remove if we don't want to print per generated token - os.system("clear") - print_output_prompts(generated_ids, tokenizer, batch_size) + for output_token_index in range(N): + time_decode_inference_start = time.time() + ( + tt_decode_embeddings, + tt_decode_attention_mask, + ) = tt_FalconCausalLM.model_preprocessing("decode", decode_ids, kv_cache_len, num_input_tokens=kv_cache_len + 1) + assert tt_decode_attention_mask is not None + + tt_logits, kv_cache = tt_FalconCausalLM( + input_embeddings=tt_decode_embeddings, + llm_mode="decode", + attention_mask=tt_decode_attention_mask, + layer_past=kv_cache, + layer_past_len=kv_cache_len, + use_cache=use_cache, + ) + synchronize_devices(devices) + time_decode_inference_end = time.time() + if output_token_index >= N_warmup: + time_decode_inference += time_decode_inference_end - time_decode_inference_start + + logits = torch.concat([tt2torch_tensor(tt_logits[i]).squeeze(1) for i in range(num_devices)], dim=-2) + + for i in range(num_devices): + tt_decode_embeddings[i].deallocate() + if tt_decode_attention_mask is not None: + tt_decode_attention_mask[i].deallocate() + tt_logits[i].deallocate() + + if not perf_mode: + if greedy_sampling: + decode_ids = post_processor(logits=logits, index=...).reshape(global_batch, 1) + else: + decode_ids = top_pk_logits_efficient(logits.reshape(global_batch, -1)).reshape(global_batch, 1) + + for user_id, user_decode_id in enumerate(decode_ids[:num_users]): + if user_decode_id == END_OF_TEXT: + prompt_is_done[user_id] = True + if prompt_is_done[user_id]: + decode_ids[user_id] = SPACE + + if all(prompt_is_done): + break + + generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1) + kv_cache_len += 1 + + # TODO: Remove if we don't want to print per generated token + os.system("clear") + print_output_prompts(generated_ids, tokenizer, batch_size) logger.info("Finished inference decode stage!") num_tokens_generated_decode = global_batch * (output_token_index - N_warmup + 1) diff --git a/models/demos/falcon7b/tt/falcon_decoder.py b/models/demos/falcon7b/tt/falcon_decoder.py index ad82cf25561d..40ce703cbb37 100644 --- a/models/demos/falcon7b/tt/falcon_decoder.py +++ b/models/demos/falcon7b/tt/falcon_decoder.py @@ -103,7 +103,6 @@ def forward( assert not output_attentions layernorm_output = [] - # print("=== Running Decoder Layernorm ===") for i in range(self.num_devices): layernorm_output.append( tt_lib.tensor.layernorm( @@ -112,7 +111,6 @@ def forward( output_mem_config=self.model_config["INPUT_LAYERNORM_OUTPUT_MEMCFG"], ) ) - # print("=== Running Decoder BcastMul ===") for i in range(self.num_devices): layernorm_output[i] = tt_lib.tensor.bcast( layernorm_output[i], @@ -121,7 +119,6 @@ def forward( tt_lib.tensor.BcastOpDim.H, output_mem_config=self.model_config["INPUT_LAYERNORM_OUTPUT_MEMCFG"], ) - # print("=== Running Decoder BcastAdd ===") for i in range(self.num_devices): layernorm_output[i] = tt_lib.tensor.bcast( layernorm_output[i], @@ -134,7 +131,6 @@ def forward( residual = hidden_states # Self Attention - # print("=== Running Decoder SelfAttn ===") attn_outputs = self.self_attn( hidden_states=layernorm_output, alibi=alibi, @@ -150,7 +146,6 @@ def forward( # MLP # mlp will deallocate layernorm_output - # print("=== Running Decoder MLP ===") mlp_output = self.mlp(layernorm_output) output = [] diff --git a/models/demos/falcon7b/tt/falcon_model.py b/models/demos/falcon7b/tt/falcon_model.py index 8dc1acea5930..1854a053267a 100644 --- a/models/demos/falcon7b/tt/falcon_model.py +++ b/models/demos/falcon7b/tt/falcon_model.py @@ -235,14 +235,12 @@ def forward( layer_output = layer_output[0] # apply final norm layer - # print(" === Running Layernorm ===") for i in range(self.num_devices): layer_output[i] = tt_lib.tensor.layernorm( layer_output[i], self.layernorm_eps, output_mem_config=self.model_config["LN_F_OUTPUT_MEMCFG"], ) - # print(" === Running BcastMul ===") for i in range(self.num_devices): layer_output[i] = tt_lib.tensor.bcast( layer_output[i], @@ -251,7 +249,6 @@ def forward( tt_lib.tensor.BcastOpDim.H, output_mem_config=self.model_config["LN_F_OUTPUT_MEMCFG"], ) - # print(" === Running BcastAdd ===") for i in range(self.num_devices): layer_output[i] = tt_lib.tensor.bcast( layer_output[i], diff --git a/models/demos/t3000/falcon7b/demo_t3000.py b/models/demos/t3000/falcon7b/demo_t3000.py index 518eb786bf73..53e3c4227593 100644 --- a/models/demos/t3000/falcon7b/demo_t3000.py +++ b/models/demos/t3000/falcon7b/demo_t3000.py @@ -8,6 +8,7 @@ @pytest.mark.parametrize("perf_mode", (True,)) # Option to measure perf using max seq length (with invalid outputs) +@pytest.mark.parametrize("async_mode", (False,)) # Option to run Falcon in Async mode @pytest.mark.parametrize("num_devices", (1, 2, 3, 4, 5, 6, 7, 8)) def test_demo_multichip( perf_mode, @@ -17,14 +18,17 @@ def test_demo_multichip( get_tt_cache_path, all_devices, use_program_cache, + async_mode, ): assert is_wormhole_b0(), "Multi-chip is only supported for Wormhole B0" devices = get_devices_for_t3000(all_devices, num_devices) + for device in devices: + device.enable_async(async_mode) return run_falcon_demo_kv( user_input=user_input, batch_size=32, - max_seq_len=128, + max_seq_len=1024, model_config_strs_prefill_decode=["BFLOAT16-DRAM", "BFLOAT16-L1_SHARDED"], model_location_generator=model_location_generator, get_tt_cache_path=get_tt_cache_path, diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index 20cac115266b..c01e377f6110 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -82,7 +82,7 @@ Tensor::~Tensor() { } void Tensor::deallocate(bool force) { - ZoneScopedN("Deallocate"); + ZoneScopedN("TensorDeallocate"); if (this->tensor_attributes.use_count()) { // Check if the attributes didn't get moved to another tensor. // If not, we can deallocate this tensor. @@ -92,7 +92,6 @@ void Tensor::deallocate(bool force) { // This is a special case, where storage type cannot change for multi // device tensors (see assert in launch_op). Hence, this only applies // to the single device case, where metadata populated == tensor populated. - ZoneScopedN("WaitForMDPopulated"); this->wait_for_tensor_metadata_populated(); } std::visit( @@ -105,10 +104,11 @@ void Tensor::deallocate(bool force) { } else if constexpr (std::is_same_v) { if (this->workers.at(0)->in_main_thread()) { // If owned by the main thread, deallocate this tensor only from the main thread - uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count.load(); + uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count; if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) { this->tensor_attributes->deallocated = true; - uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(this->workers.at(0)); + // Record ref count before sending to worker + uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(); this->workers.at(0)->push_work([force, *this] () mutable { std::visit([force, this] (auto&& s) { using type = std::decay_t; @@ -123,6 +123,7 @@ void Tensor::deallocate(bool force) { } }, this->tensor_attributes->storage); }); + // Update ref count after sending to worker this->tensor_attributes->update_main_thread_ref_count(this->workers.at(0), device_tensor_ref_count); } } else { @@ -134,9 +135,11 @@ void Tensor::deallocate(bool force) { } } else if constexpr (std::is_same_v) { if (this->workers.at(0)->in_main_thread()) { - uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count.load(); + uint32_t ref_count_to_use = (this->workers.at(0)->get_worker_mode() == WorkExecutorMode::SYNCHRONOUS) ? this->tensor_attributes.use_count() : this->tensor_attributes->main_thread_ref_count; if ((force or ref_count_to_use == 1) and not this->tensor_attributes->deallocated) { this->tensor_attributes->deallocated = true; + // Record ref count before sending to workers + uint32_t device_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(); for (auto worker : this->workers) { worker->push_work([force, *this, worker] () mutable { std::visit([force, worker] (auto&& s) { @@ -152,6 +155,8 @@ void Tensor::deallocate(bool force) { }, this->tensor_attributes->storage); }); } + // Update ref count after sending to workers + this->tensor_attributes->update_main_thread_ref_count(this->workers.at(0), device_tensor_ref_count); } } } else if constexpr (std::is_same_v) { @@ -318,8 +323,8 @@ Tensor Tensor::to(CommandQueue & queue, const MemoryConfig & mem_config) const { // functions running in main can get storage type without blocking Tensor device_tensor({target_device}); // Record main thread ref count for tensors before pushing to queue. - uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(device_tensor.workers.at(0)); - uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(device_tensor.workers.at(0)); + uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); + uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(); queue.device()->push_work([async_safe_tensor, device_tensor, mem_config, target_device] () mutable { if (async_safe_tensor.storage_type() == StorageType::DEVICE) { TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices"); @@ -347,8 +352,8 @@ Tensor Tensor::to(Device *target_device, const MemoryConfig &mem_config) const { // functions running in main can get storage type without blocking Tensor device_tensor({target_device}); // Record main thread ref count for tensors before pushing to queue. - uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(device_tensor.workers.at(0)); - uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(device_tensor.workers.at(0)); + uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); + uint32_t original_tensor_ref_count = async_safe_tensor.tensor_attributes->record_main_thread_ref_count(); target_device->push_work([async_safe_tensor, device_tensor, mem_config, target_device] () mutable { if (async_safe_tensor.storage_type() == StorageType::DEVICE) { TT_ASSERT(async_safe_tensor.device() == target_device && "Currently do not support moving between devices"); @@ -374,7 +379,7 @@ Tensor Tensor::to(DeviceMesh *device_mesh, const MemoryConfig &mem_config) const auto workers = std::vector(all_workers.begin(), all_workers.end()); TT_FATAL(validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); Tensor multi_device_tensor = Tensor(workers); - uint32_t device_tensor_ref_count = multi_device_tensor.tensor_attributes->record_main_thread_ref_count(multi_device_tensor.workers.at(0)); + uint32_t device_tensor_ref_count = multi_device_tensor.tensor_attributes->record_main_thread_ref_count(); for (auto& target_device : workers) { target_device->push_work([*this, multi_device_tensor, mem_config, target_device] () mutable { @@ -407,7 +412,7 @@ Tensor Tensor::cpu(bool blocking) const { } TT_FATAL(validate_worker_modes(workers), "All device threads/workers must be running in the same mode (ASYNC or SYNC)"); Tensor host_tensor({}, workers.size()); - uint32_t original_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(workers.at(0)); + uint32_t original_tensor_ref_count = this->tensor_attributes->record_main_thread_ref_count(); for (auto target_device : workers) { target_device->push_work([host_tensor, blocking, target_device, *this, workers] () mutable { TT_ASSERT(this->storage_type() == StorageType::DEVICE or this->storage_type() == StorageType::MULTI_DEVICE, "Can only use worker queue for cpu call if tensor is on device."); diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index fca9885f2943..2111925864d5 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -33,10 +33,10 @@ struct Tensor { Layout layout; std::mutex populated_mutex; std::vector tensor_populated = {}; - std::atomic main_thread_ref_count = 0; + uint32_t main_thread_ref_count = 0; bool deallocated = false; // Set to true if device side storage was deallocated bool dynamic_storage = false; // Storage type can change, depending on op behaviour - bool track = false; + bool track_ref_count = false; TensorAttributes(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout) : storage(storage), shape(shape), dtype(dtype), layout(layout) {} TensorAttributes() : shape({0xff, 0xff, 0xff, 0xff}), dtype(DataType::INVALID), layout(Layout::INVALID) {} ~TensorAttributes() = default; @@ -56,8 +56,8 @@ struct Tensor { void increment_main_thread_ref_count(Device* worker) { if (worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS and worker->in_main_thread()) { main_thread_ref_count++; - if (track) { - std::cout << "Inc: " << this << " " << main_thread_ref_count << " " << shared_from_this().use_count() << std::endl; + if (track_ref_count) { + tt::log_info("Inc Ref Count on tensor {}. Main Thread Ref Count: {}. Total Ref Count: {}.", reinterpret_cast(this), main_thread_ref_count, shared_from_this().use_count()); } } } @@ -65,25 +65,20 @@ struct Tensor { void decrement_main_thread_ref_count(Device* worker) { if (worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS and worker->in_main_thread()) { main_thread_ref_count--; - if (track) { - std::cout << "Dec: " << this << " " << main_thread_ref_count << " " << shared_from_this().use_count() << std::endl; + if (track_ref_count) { + tt::log_info("Dec Ref Count on tensor {}. Main Thread Ref Count: {}. Total Ref Count: {}.", reinterpret_cast(this), main_thread_ref_count, shared_from_this().use_count()); } } } - uint32_t record_main_thread_ref_count(Device* worker) { - if (worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS and worker->in_main_thread()) { - if (track) { - std::cout << "Record: " << this << " " << main_thread_ref_count << " " << shared_from_this().use_count() << std::endl; - } - } + uint32_t record_main_thread_ref_count() { return main_thread_ref_count; } void update_main_thread_ref_count(Device* worker, uint32_t ref_count) { if (worker->get_worker_mode() == WorkExecutorMode::ASYNCHRONOUS and worker->in_main_thread()) { - if (track) { - std::cout << "Update: " << this << " " << main_thread_ref_count << " " << shared_from_this().use_count() << std::endl; + if (track_ref_count) { + tt::log_info("Update Ref Count on tensor {}. Main Thread Ref Count: {}. Total Ref Count: {}.", reinterpret_cast(this), main_thread_ref_count, shared_from_this().use_count()); } main_thread_ref_count = ref_count; } @@ -171,9 +166,8 @@ struct Tensor { ~Tensor(); - void track() { - this->tensor_attributes->track = true; - } + void track_ref_count() { this->tensor_attributes->track_ref_count = true; } + void deepcopy(const Tensor& other); void populate_buffers_and_metadata(const Tensor& other); diff --git a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp index 91c122a681f6..0578c88760fa 100644 --- a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp @@ -804,21 +804,37 @@ Tensor falcon_dense_h_to_4h_matmul(const Tensor &input_tensor_a, const Tensor &i Tensor falcon_lm_head_matmul(const Tensor &input_tensor_a, const Tensor &input_tensor_b, std::optional bias, const MemoryConfig& mem_config, std::optional output_dtype) { auto seq_len = input_tensor_a.get_legacy_shape()[2]; - + std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor_a, input_tensor_b}, {bias}))}; if (seq_len > 512) { // TODO: Check support for seq_len == 128, 256, 512, ..., 2048 - TT_FATAL(seq_len % TILE_HEIGHT == 0, "Falcon mm's seq_len must be a multiple of 32!"); - TT_FATAL(seq_len >= 128, "Falcon mm's seq_len must be greater than 128!"); - TT_FATAL((input_tensor_a.get_legacy_shape() == Shape({1, 1, seq_len, 4544})), "Unsupported input shape"); - TT_FATAL((input_tensor_b.get_legacy_shape() == Shape({1, 1, 4544, 65024})), "Unsupported input shape"); - return operation::run_with_autoformat(Matmul{.bcast_batch=true, .output_mem_config=mem_config, .output_dtype=output_dtype.value_or(input_tensor_a.get_dtype())}, {input_tensor_a, input_tensor_b}, {bias}).at(0); + operation::launch_with_autoformat( + [seq_len, mem_config, output_dtype] (const std::vector& input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + auto& input_tensor_a = input_tensors.at(0); + auto& input_tensor_b = input_tensors.at(1); + auto& bias = optional_input_tensors.at(0); + TT_FATAL(seq_len % TILE_HEIGHT == 0, "Falcon mm's seq_len must be a multiple of 32!"); + TT_FATAL(seq_len >= 128, "Falcon mm's seq_len must be greater than 128!"); + TT_FATAL((input_tensor_a.get_legacy_shape() == Shape({1, 1, seq_len, 4544})), "Unsupported input shape"); + TT_FATAL((input_tensor_b.get_legacy_shape() == Shape({1, 1, 4544, 65024})), "Unsupported input shape"); + return operation::run_with_autoformat(Matmul{.bcast_batch=true, .output_mem_config=mem_config, .output_dtype=output_dtype.value_or(input_tensor_a.get_dtype())}, {input_tensor_a, input_tensor_b}, {bias}); + }, + {input_tensor_a, input_tensor_b}, output_tensors, {bias}); + } else { - CoreCoord grid_size = get_falcon_matmul_grid_size(input_tensor_a.device()); - auto program_config = bmm_op_utils::get_mcast_1d_config(input_tensor_a, input_tensor_b, true, std::nullopt, true, mem_config.is_sharded(), grid_size); - std::optional config = std::nullopt; - auto compute_kernel_config = init_device_compute_kernel_config(input_tensor_a.device()->arch(), config, MathFidelity::LoFi, true /* math_approx_mode */, false /* fp32_dest_acc_en */, true /* packer_l1_acc */); - return operations::primary::matmul_1d(input_tensor_a, input_tensor_b, bias, program_config, mem_config, output_dtype, compute_kernel_config); + operation::launch_op( + [mem_config, output_dtype] (const std::vector& input_tensors, const std::vector>& optional_input_tensors) mutable -> std::vector { + auto& input_tensor_a = input_tensors.at(0); + auto& input_tensor_b = input_tensors.at(1); + auto& bias = optional_input_tensors.at(0); + CoreCoord grid_size = get_falcon_matmul_grid_size(input_tensor_a.device()); + auto program_config = bmm_op_utils::get_mcast_1d_config(input_tensor_a, input_tensor_b, true, std::nullopt, true, mem_config.is_sharded(), grid_size); + std::optional config = std::nullopt; + auto compute_kernel_config = init_device_compute_kernel_config(input_tensor_a.device()->arch(), config, MathFidelity::LoFi, true /* math_approx_mode */, false /* fp32_dest_acc_en */, true /* packer_l1_acc */); + return {operations::primary::matmul_1d(input_tensor_a, input_tensor_b, bias, program_config, mem_config, output_dtype, compute_kernel_config)}; + }, + {input_tensor_a, input_tensor_b}, output_tensors, {bias}); } + return output_tensors.at(0); } /** diff --git a/tt_eager/tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp b/tt_eager/tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp index 3aeef227b272..674c7db8638c 100644 --- a/tt_eager/tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp +++ b/tt_eager/tt_dnn/op_library/rotary_embedding/rotary_embedding_op.hpp @@ -79,12 +79,11 @@ inline Tensor rotary_embedding( FormatParams cos_format_params = {.pad_shape = cos_pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE}; Shape sin_pad_shape = AutoFormat::pad_to_tile_shape(sin.get_legacy_shape()); FormatParams sin_format_params = {.pad_shape = sin_pad_shape, .pad_value = 0.0, .target_layout = Layout::TILE}; - auto rval = operation::run_with_autoformat( + return operation::run_with_autoformat( RotaryEmbedding{seq_len, token_idx, output_mem_config, kernel_config_val}, {input_tensor, cos, sin}, {input_format_params, cos_format_params, sin_format_params}, {Layout::TILE}); - return rval; }, {input_tensor, cos, sin}, output_tensors); return output_tensors.at(0); } diff --git a/tt_eager/tt_dnn/op_library/run_operation.cpp b/tt_eager/tt_dnn/op_library/run_operation.cpp index 600b0b021eea..d012ad38888d 100644 --- a/tt_eager/tt_dnn/op_library/run_operation.cpp +++ b/tt_eager/tt_dnn/op_library/run_operation.cpp @@ -644,20 +644,9 @@ void launch_with_autoformat( ) { // Mark each output tensor as having dynamic storage (can be on host or device, depending // on autoformat behaviour). Multi device tensors do not support dynamic storage. - uint32_t min_num_input_workers = std::numeric_limits::max(); - for (const auto& input : input_tensors) { - min_num_input_workers = std::min(min_num_input_workers, static_cast(input.workers.size())); - } - for (const auto& input : optional_input_tensors) { - if (input.has_value()) { - min_num_input_workers = std::min(min_num_input_workers, static_cast(input.value().workers.size())); - } - } - for (auto& output_tensor : output_tensors) { output_tensor.tensor_attributes->dynamic_storage = (output_tensor.workers.size() <= 1); } - launch_op(std::move(op_func), input_tensors, output_tensors, optional_input_tensors); } @@ -687,12 +676,12 @@ void launch_op( // copy borrowed tensors to owned storage. for (int i = 0; i < input_tensors.size(); i++) { async_safe_input_tensors.push_back(copy_borrowed_tensor_in_async_mode(workers.at(0), input_tensors.at(i))); - input_tensor_ref_count.push_back(async_safe_input_tensors[i].tensor_attributes->record_main_thread_ref_count(workers.at(0))); + input_tensor_ref_count.push_back(async_safe_input_tensors[i].tensor_attributes->record_main_thread_ref_count()); } for (int i = 0; i < optional_input_tensors.size(); i++) { if (optional_input_tensors[i].has_value()) { async_safe_optional_input_tensors.push_back(copy_borrowed_tensor_in_async_mode(workers.at(0), optional_input_tensors[i].value())); - optional_input_tensor_ref_count.push_back(async_safe_optional_input_tensors[i].value().tensor_attributes->record_main_thread_ref_count(workers.at(0))); + optional_input_tensor_ref_count.push_back(async_safe_optional_input_tensors[i].value().tensor_attributes->record_main_thread_ref_count()); } else { async_safe_optional_input_tensors.push_back(std::nullopt); @@ -700,11 +689,10 @@ void launch_op( } } for (int i = 0; i < output_tensors.size(); i++) { - output_tensor_ref_count.push_back(output_tensors[i].tensor_attributes->record_main_thread_ref_count(workers.at(0))); + output_tensor_ref_count.push_back(output_tensors[i].tensor_attributes->record_main_thread_ref_count()); } { ZoneScopedN("PushOpToWorkers"); - // Remove push backs, use reserve for (auto target_device : workers) { target_device->push_work([target_device, workers, op_func, async_safe_optional_input_tensors, inputs = async_safe_input_tensors, outputs = output_tensors] () mutable { std::vector input_shards = {}; diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp index 512a3c8b0cb8..a23f09915cf8 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_pytensor.cpp @@ -864,8 +864,11 @@ Tensor convert_python_tensors_to_tt_tensors(py::list tensor_shards, std::optiona tt_tensor = tt_tensor.to(tt_device) )doc") - .def("track", - [](Tensor &self) { return self.track(); }) + .def("track_ref_count", + [](Tensor &self) { return self.track_ref_count(); }, + R"doc( + Log the reference count (as seen by the main and worker threads) of a tensor as it evolves during runtime. + )doc") .def( "to", py::overload_cast(&Tensor::to, py::const_), diff --git a/tt_metal/common/executor.hpp b/tt_metal/common/executor.hpp index 6090335c2eed..bdb5ac267a72 100644 --- a/tt_metal/common/executor.hpp +++ b/tt_metal/common/executor.hpp @@ -8,7 +8,7 @@ #include namespace tt::tt_metal::detail { - static const size_t EXECUTOR_NTHREADS = std::getenv("TT_METAL_THREADCOUNT") ? std::stoi( std::getenv("TT_METAL_THREADCOUNT") ) : std::thread::hardware_concurrency();; + static const size_t EXECUTOR_NTHREADS = std::getenv("TT_METAL_THREADCOUNT") ? std::stoi( std::getenv("TT_METAL_THREADCOUNT") ) : std::thread::hardware_concurrency(); using Executor = tf::Executor; using ExecTask = tf::Task; diff --git a/tt_metal/detail/kernel_cache.hpp b/tt_metal/detail/kernel_cache.hpp index 2f03f7d3ef59..7bb9c4eac62b 100644 --- a/tt_metal/detail/kernel_cache.hpp +++ b/tt_metal/detail/kernel_cache.hpp @@ -13,11 +13,11 @@ namespace tt::tt_metal::detail } bool exists(size_t khash) { - std::scoped_lock lock(mutex_); + unique_lock lock(mutex_); return hashes_.find(khash) != hashes_.end(); } bool add(size_t khash) { - std::scoped_lock lock(mutex_); + unique_lock lock(mutex_); bool ret = false; if (hashes_.find(khash) == hashes_.end() ){ hashes_.insert(khash); @@ -27,7 +27,7 @@ namespace tt::tt_metal::detail } void clear() { - std::scoped_lock lock(mutex_); + unique_lock lock(mutex_); hashes_.clear(); } diff --git a/tt_metal/impl/dispatch/work_executor.hpp b/tt_metal/impl/dispatch/work_executor.hpp index b8a954d4b863..65e615fe6614 100644 --- a/tt_metal/impl/dispatch/work_executor.hpp +++ b/tt_metal/impl/dispatch/work_executor.hpp @@ -121,6 +121,7 @@ class WorkExecutor { this->worker_state = WorkerState::RUNNING; this->worker_thread = std::thread(&WorkExecutor::run_worker, this); this->worker_queue.worker_thread_id = std::hash{}(this->worker_thread.get_id()); + // Bind a worker tied to a device to a specific CPU core in round robin fashion. Thread affinity == Better Perf. cpu_set_t cpuset; CPU_ZERO(&cpuset); CPU_SET(managed_device_id % std::thread::hardware_concurrency(), &cpuset); diff --git a/tt_metal/impl/program/program.hpp b/tt_metal/impl/program/program.hpp index 21d90c3b9990..d0b0820fa421 100644 --- a/tt_metal/impl/program/program.hpp +++ b/tt_metal/impl/program/program.hpp @@ -154,7 +154,6 @@ class Program { uint64_t id; // Need to make non-const due to move constructor static std::atomic program_counter; - // inline static std::vector compile_workers = std::vector(32, WorkExecutor(0)); std::unordered_map >> kernels_; std::unordered_map grid_extent_; diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index b98b056b9748..40dffafcee2d 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -622,9 +622,6 @@ void CloseDevices(std::map devices) { void DeallocateBuffer(Buffer *buffer) { detail::DispatchStateCheck(not buffer->device()->using_slow_dispatch()); - // if (buffer->device()->in_main_thread()) { - // std::cout << "deallocate called in main thread" << std::endl; - // } EnqueueDeallocateBuffer(buffer->device()->command_queue(), *(buffer->device()->allocator_), buffer->address(), buffer->buffer_type(), false); }