From 9e5384a19786d651f1824acd2fdd274cf2ae2062 Mon Sep 17 00:00:00 2001 From: Evan Smal <159198142+esmalTT@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:02:33 -0400 Subject: [PATCH 01/58] #13542: Disable demo performance checks in Mamba functional tests (#13542) #0: Disable demo performance checks in Mamba functional tests Don't check performance in nightly test pipeline, since we only want to check this on perf pipelines. Also fixes Mamba time-to-first-token measurement. Updates performance tests to use device kernel time instead of FW time. Force-merging because others are on vacation. --- models/demos/wormhole/mamba/demo/demo.py | 14 +++++--- .../wormhole/mamba/tests/test_mamba_demo.py | 2 ++ .../wormhole/mamba/tests/test_mamba_perf.py | 35 +++++++++---------- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/models/demos/wormhole/mamba/demo/demo.py b/models/demos/wormhole/mamba/demo/demo.py index c18c22314ef..d5a4017a0b8 100644 --- a/models/demos/wormhole/mamba/demo/demo.py +++ b/models/demos/wormhole/mamba/demo/demo.py @@ -219,6 +219,7 @@ def run_mamba_demo( cache_dir: Optional[str] = None, display: bool = True, prefill_chunk_size: int = 32, + assert_on_performance_measurements: bool = True, ): profiler = BenchmarkProfiler() profiler.start("run") @@ -345,6 +346,8 @@ def callback(token: torch.Tensor, inference_time: float) -> None: prefill_time_to_token_per_user = prefill_stats.mean_throughput_per_user decode_time_to_token_per_user = decode_stats.mean_throughput_per_user + time_to_first_token = 1 / (prefill_time_to_token_per_user + decode_time_to_token_per_user) # t/s/u + measurements = { "total_demo_time": profiler.get_duration("run"), "compile_prefill": profiler.get_duration("compile_prefill"), @@ -352,10 +355,10 @@ def callback(token: torch.Tensor, inference_time: float) -> None: "inference_prefill": prefill_stats.total_time, "inference_decode": decode_stats.total_time, "prefill_t/s": prefill_stats.mean_throughput, - "prefill_time_to_token": prefill_stats.total_time, + "prefill_time_to_token": time_to_first_token, "decode_t/s": decode_stats.mean_throughput, "decode_t/s/u": decode_stats.mean_throughput_per_user, - "prefill_decode_t/s/u": 1 / (prefill_time_to_token_per_user + decode_time_to_token_per_user), # t/s/u + "prefill_decode_t/s/u": time_to_first_token, "token_verification": 1, # This is checked by the caller - but we could also do a match here } @@ -367,7 +370,7 @@ def callback(token: torch.Tensor, inference_time: float) -> None: logger.info( f"Decode throughput: {decode_stats.mean_throughput:.1f} t/s, {decode_stats.mean_throughput_per_user:.2f} t/s/u" ) - logger.info(f"Time to first token: {(1e3 * measurements['prefill_decode_t/s/u']):.2f} ms") + logger.info(f"Time to first token: {(1e3 * time_to_first_token):.2f} ms") chunk_size_to_prefill_targets_tok_per_s = {32: 135.0, 128: 270.0} # perf is different for different chunk sizes targets = { @@ -390,7 +393,10 @@ def callback(token: torch.Tensor, inference_time: float) -> None: output_sequence_length=tokenized_prompts.shape[1] + generated_sequence_length, ) - verify_perf(measurements, targets) + if assert_on_performance_measurements: + verify_perf(measurements, targets) + else: + logger.warning(f"Skipping performance checks (this is expected for functional tests)") return DemoResult(generated_text=token_display.sequences) diff --git a/models/demos/wormhole/mamba/tests/test_mamba_demo.py b/models/demos/wormhole/mamba/tests/test_mamba_demo.py index f5496420b51..1680c933fb8 100644 --- a/models/demos/wormhole/mamba/tests/test_mamba_demo.py +++ b/models/demos/wormhole/mamba/tests/test_mamba_demo.py @@ -44,6 +44,7 @@ def test_demo( get_tt_cache_path, max_gen_len, prefill_chunk_size, + reset_seeds, ): assert len(user_input) == len(expected_output) @@ -55,6 +56,7 @@ def test_demo( display=True, cache_dir=get_tt_cache_path(model_version), prefill_chunk_size=prefill_chunk_size, + assert_on_performance_measurements=False, # Don't check performance for functional tests ) expected = user_input[0] + expected_output[0] diff --git a/models/demos/wormhole/mamba/tests/test_mamba_perf.py b/models/demos/wormhole/mamba/tests/test_mamba_perf.py index 5e69de74619..1d66d1cbf8e 100644 --- a/models/demos/wormhole/mamba/tests/test_mamba_perf.py +++ b/models/demos/wormhole/mamba/tests/test_mamba_perf.py @@ -36,8 +36,8 @@ def is_nearby(actual: float, expected: float, lower_margin: float = 0.03, upper_ @pytest.mark.parametrize( "model_version, mode, batch_size, sequence_length, iterations, expected_compile_time, expected_inference_time", ( - ("state-spaces/mamba-2.8b", ModelMode.DECODE, 32, 1, 8, 15.0, 0.110), - ("state-spaces/mamba-2.8b", ModelMode.PREFILL, 1, 128, 8, 27.0, 0.520), + ("state-spaces/mamba-2.8b", ModelMode.DECODE, 32, 1, 8, 18.0, 0.110), + ("state-spaces/mamba-2.8b", ModelMode.PREFILL, 1, 128, 8, 30.0, 0.520), ), ) @pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) @@ -129,7 +129,7 @@ def test_mamba_perf_e2e( upper_margin = MARGIN if not is_nearby(inference_time, expected_inference_time, lower_margin=lower_margin, upper_margin=upper_margin): logger.warning( - "Inference time does not match (within some margin) the expected value (was {inference_time:2f} but expected {expected_inference_time:2f})" + f"Inference time does not match (within some margin) the expected value (was {inference_time:2f} but expected {expected_inference_time:2f})" ) if not is_nearby(compile_time, expected_compile_time, lower_margin=lower_margin, upper_margin=upper_margin): @@ -142,33 +142,30 @@ def test_mamba_perf_e2e( @pytest.mark.timeout(600) @pytest.mark.models_device_performance_bare_metal @pytest.mark.parametrize( - "batch, warmup, expected_device_fw_duration_ms", - ((32, True, 1.66),), + "batch, expected_layer_duration_ms", + ((32, 1.71),), ) -def test_mamba_perf_device(batch, warmup, expected_device_fw_duration_ms, reset_seeds): +def test_mamba_perf_device(batch, expected_layer_duration_ms): subdir = "ttnn_mamba" - margin = 0.03 - if warmup: - inference_iterations = 2 - else: - inference_iterations = 1 - command = f"pytest models/demos/wormhole/mamba/tests/test_mamba_model.py::test_device_perf[{inference_iterations}]" + margin = 0.01 + command = f"pytest models/demos/wormhole/mamba/tests/test_mamba_model.py::test_device_perf[1]" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] # Convert expected perf (ms) to samples/s - expected_device_fw_duration_ns = expected_device_fw_duration_ms * 1e6 # ms to ns - expected_total_device_fw_samples = get_samples_per_s(expected_device_fw_duration_ns * inference_iterations, batch) - - inference_time_key = "AVG DEVICE FW SAMPLES/S" - expected_perf_cols = {inference_time_key: expected_total_device_fw_samples} + expected_layer_duration_ns = expected_layer_duration_ms * 1e6 # ms to ns + expected_total_layer_samples_per_s = get_samples_per_s(expected_layer_duration_ns, batch) + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_total_layer_samples_per_s} post_processed_results = run_device_perf(command, subdir, 1, cols, batch) + logger.info( + f"Checking device performance... Expecting {expected_total_layer_samples_per_s} samples/sec (equivalent to {expected_layer_duration_ms} ms per layer)" + ) expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True) - comment = "" prep_device_perf_report( model_name=f"mamba-2.8b_batch_{batch}", batch_size=batch, post_processed_results=post_processed_results, expected_results=expected_results, - comments=comment, + comments="", ) From f521af0061bf53567942b7a27fd89aa300ec16ce Mon Sep 17 00:00:00 2001 From: Eyon Land <41128502+eyonland@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:52:19 -0500 Subject: [PATCH 02/58] #0: Transferring ownership of model governance to @uaydonat and @esmalTT (#13559) --- CODEOWNERS | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/CODEOWNERS b/CODEOWNERS index 81e62bd088a..2f5b8e2b2ac 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -126,10 +126,10 @@ tests/ttnn/distributed/ @cfjchu @ayerofieiev-tt @dmakoviichuk-tt /models/ @tt-rkim @uaydonat /models/*/** models/conv_on_device_utils*.py @mywoodstock @shwetankTT @sankarmanoj-tt -functional_*/ @eyonland @patrickroberts @yan-zaretskiy @cfjchu @xanderchin -models/demos @eyonland @patrickroberts @yan-zaretskiy @cfjchu @xanderchin +functional_*/ @uaydonat @esmalTT +models/demos @uaydonat @tt-rkim models/demos/metal_BERT_large_11 @tt-aho @TT-BrianLiu -models/demos/wormhole @uaydonat @eyonland +models/demos/wormhole @uaydonat @tt-rkim models/demos/t3000 @uaydonat models/demos/falcon7b_common @skhorasganiTT @djordje-tt @uaydonat models/demos/wormhole/mamba @esmalTT @uaydonat @kpaigwar @@ -145,7 +145,7 @@ models/demos/t3000/llama3_70b @cglagovichTT @uaydonat @johanna-rock-tt @djordje- models/demos/t3000/mixtral8x7b @yieldthought @mtairum @uaydonat models/demos/tg/llama3_70b @cglagovichTT @uaydonat @johanna-rock-tt @djordje-tt @kpaigwar models/demos/tg/falcon7b @skhorasganiTT @djordje-tt @uaydonat -models/demos/grayskull @uaydonat @eyonland +models/demos/grayskull @uaydonat @tt-rkim models/demos/**/*resnet* @mywoodstock @shwetankTT @tt-aho models/experimental/functional_unet @esmalTT @uaydonat @mywoodstock models/perf/ @uaydonat @tt-rkim From 05f52b6ce403f7f73e64ffd88285d24d13b28f46 Mon Sep 17 00:00:00 2001 From: Anil Mahmud Date: Sun, 29 Sep 2024 17:37:02 -0400 Subject: [PATCH 03/58] #13026: Enable full dst synchronization scheme --- .../test_copy_block_matmul_partials.cpp | 96 +++-- .../unit_tests/compute/test_reconfig.cpp | 110 +++-- .../unit_tests/compute/test_reduce.cpp | 347 ++++++++-------- .../compute/test_untilize_tilize.cpp | 209 +++++----- .../compute/matmul/test_matmul_X_tile.cpp | 377 +++++++++--------- .../ttnn/unit_tests/operations/test_matmul.py | 27 ++ .../blackhole/metal/common/chlkc_list.h | 3 + .../metal/llk_api/llk_math_binary_api.h | 4 +- .../metal/llk_api/llk_math_binary_sfpu_api.h | 8 +- .../metal/llk_api/llk_math_common_api.h | 6 +- .../llk_api/llk_math_unary_datacopy_api.h | 4 +- .../blackhole/metal/llk_api/llk_pack_api.h | 11 +- .../llk_math_eltwise_unary_sfpu_params.h | 2 +- .../grayskull/metal/common/chlkc_list.h | 3 + .../metal/llk_api/llk_math_binary_api.h | 4 +- .../metal/llk_api/llk_math_common_api.h | 6 +- .../llk_api/llk_math_unary_datacopy_api.h | 2 +- .../metal/llk_api/llk_math_unary_sfpu_api.h | 14 +- .../grayskull/metal/llk_api/llk_pack_api.h | 11 +- ..._math_eltwise_unary_sfpu_common_includes.h | 2 +- .../wormhole_b0/metal/common/chlkc_list.h | 3 + .../metal/llk_api/llk_math_binary_api.h | 4 +- .../metal/llk_api/llk_math_binary_sfpu_api.h | 8 +- .../metal/llk_api/llk_math_common_api.h | 6 +- .../llk_api/llk_math_unary_datacopy_api.h | 4 +- .../wormhole_b0/metal/llk_api/llk_pack_api.h | 11 +- tt_metal/impl/kernels/kernel.cpp | 6 +- tt_metal/impl/kernels/kernel_types.hpp | 1 + tt_metal/jit_build/genfiles.cpp | 18 + tt_metal/jit_build/settings.hpp | 2 + .../moreh_layernorm/moreh_layernorm_op.cpp | 2 +- ...reh_layernorm_backward_gamma_beta_grad.cpp | 2 +- .../moreh_layernorm_backward_input_grad.cpp | 2 +- .../multi_core/moreh_matmul_op_multi_core.cpp | 2 +- .../softmax_c_large/softmax_c_large.cpp | 2 +- .../softmax_h_large/softmax_h_large.cpp | 2 +- .../softmax_h_small/softmax_h_small.cpp | 4 +- .../softmax_w_large/softmax_w_large.cpp | 2 +- .../softmax_w_small/softmax_w_small.cpp | 4 +- .../softmax_backward_c_large.cpp | 2 +- .../softmax_backward_h_large.cpp | 2 +- .../softmax_backward_h_small.cpp | 2 +- .../softmax_backward_w_large.cpp | 2 +- .../softmax_backward_w_small.cpp | 2 +- .../moreh_sum_h_impl/moreh_int_sum_h_impl.cpp | 4 +- .../moreh_sum_h_impl/moreh_sum_h_impl.cpp | 2 +- .../moreh_int_sum_nc_impl.cpp | 2 +- .../moreh_sum_nc_impl/moreh_sum_nc_impl.cpp | 2 +- .../moreh_sum_w_impl/moreh_int_sum_w_impl.cpp | 2 +- .../moreh_sum_w_impl/moreh_sum_w_impl.cpp | 2 +- .../moreh_sum_backward_impl.cpp | 2 +- .../compute_kernel/compute_kernel_config.cpp | 22 +- .../compute_kernel/compute_kernel_config.hpp | 7 +- .../device/fast_reduce_nc_program_factory.cpp | 2 +- .../matmul_op_multi_core_program_factory.cpp | 10 +- .../device/moreh_adam_program_factory.cpp | 2 +- .../device/multi_core_program_factory.cpp | 2 +- .../device/moreh_dot_program_factory.cpp | 2 +- .../moreh_layer_norm_program_factory.cpp | 2 +- ...ckward_gamma_beta_grad_program_factory.cpp | 2 +- ...rm_backward_input_grad_program_factory.cpp | 2 +- ...ar_backward_multi_core_program_factory.cpp | 3 +- ...r_backward_single_core_program_factory.cpp | 2 +- .../device/moreh_matmul_program_factory.cpp | 2 +- .../device/moreh_mean_h_program_factory.cpp | 2 +- .../device/moreh_mean_nc_program_factory.cpp | 2 +- .../device/moreh_mean_w_program_factory.cpp | 2 +- .../moreh_mean_backward_program_factory.cpp | 2 +- .../moreh_nll_loss_step1_program_factory.cpp | 2 +- .../moreh_nll_loss_step2_program_factory.cpp | 6 +- ...oreh_nll_loss_backward_program_factory.cpp | 6 +- ...oss_unreduced_backward_program_factory.cpp | 6 +- .../device/moreh_norm_program_factory_h.cpp | 2 +- .../moreh_norm_program_factory_other.cpp | 2 +- .../device/moreh_norm_program_factory_w.cpp | 2 +- .../moreh_norm_backward_program_factory.cpp | 2 +- .../device/moreh_sgd_program_factory.cpp | 2 +- .../device/moreh_softmax_device_operation.cpp | 4 +- .../softmax_c_large/softmax_c_large.cpp | 2 +- .../softmax_h_large/softmax_h_large.cpp | 2 +- .../softmax_h_small/softmax_h_small.cpp | 2 +- .../softmax_w_large/softmax_w_large.cpp | 2 +- .../softmax_w_small/softmax_w_small.cpp | 2 +- .../softmax_backward_c_large.cpp | 2 +- .../softmax_backward_h_large.cpp | 2 +- .../softmax_backward_h_small.cpp | 2 +- .../softmax_backward_w_large.cpp | 2 +- .../softmax_backward_w_small.cpp | 2 +- .../moreh_int_sum_h_program_factory.cpp | 2 +- .../moreh_int_sum_nc_program_factory.cpp | 2 +- .../moreh_int_sum_w_program_factory.cpp | 2 +- .../device/moreh_sum_h_program_factory.cpp | 2 +- .../device/moreh_sum_nc_program_factory.cpp | 2 +- .../device/moreh_sum_w_program_factory.cpp | 2 +- .../moreh_sum_backward_program_factory.cpp | 2 +- ...ple_bilinear_program_factory_multicore.cpp | 2 +- .../multi_core_h/reduce_op_multi_core_h.cpp | 2 +- .../multi_core_w/reduce_op_multi_core_w.cpp | 2 +- .../reduce_op_single_core_hw.cpp | 2 +- 99 files changed, 846 insertions(+), 655 deletions(-) diff --git a/tests/tt_metal/tt_metal/unit_tests/compute/test_copy_block_matmul_partials.cpp b/tests/tt_metal/tt_metal/unit_tests/compute/test_copy_block_matmul_partials.cpp index bc1da6a8f75..5cd7d0f4f24 100644 --- a/tests/tt_metal/tt_metal/unit_tests/compute/test_copy_block_matmul_partials.cpp +++ b/tests/tt_metal/tt_metal/unit_tests/compute/test_copy_block_matmul_partials.cpp @@ -16,6 +16,7 @@ struct CopyBlockMatmulPartialsConfig { uint32_t compute_ublock; uint32_t src0_cb_index; uint32_t ouput_cb_index; + bool dst_full_sync_en; }; void run_single_core_copy_block_matmul_partials(tt_metal::Device* device, const CopyBlockMatmulPartialsConfig& test_config) { @@ -81,7 +82,8 @@ void run_single_core_copy_block_matmul_partials(tt_metal::Device* device, const program, "tests/tt_metal/tt_metal/test_kernels/compute/eltwise_copy_block_matmul_partials.cpp", core, - tt_metal::ComputeConfig{.compile_args = compute_kernel_args} + tt_metal::ComputeConfig{.dst_full_sync_en = test_config.dst_full_sync_en, + .compile_args = compute_kernel_args} ); @@ -153,53 +155,65 @@ void run_single_core_copy_block_matmul_partials(tt_metal::Device* device, const // //////////////////////////////////////////////////////////////////////////// TEST_F(DeviceFixture, ComputeCopyBlockMatmulPartialsR8W8C8) { - unit_tests::compute::matmul_partials::CopyBlockMatmulPartialsConfig test_config = { - .single_tile_size = 2 * 1024, - .num_tiles = 8, - .reader_ublock = 8, - .writer_ublock = 8, - .compute_ublock = 8, - .src0_cb_index = 0, - .ouput_cb_index = 16 - }; - unit_tests::compute::matmul_partials::run_single_core_copy_block_matmul_partials(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::matmul_partials::CopyBlockMatmulPartialsConfig test_config = { + .single_tile_size = 2 * 1024, + .num_tiles = 8, + .reader_ublock = 8, + .writer_ublock = 8, + .compute_ublock = 8, + .src0_cb_index = 0, + .ouput_cb_index = 16, + .dst_full_sync_en = dst_full_sync_en + }; + unit_tests::compute::matmul_partials::run_single_core_copy_block_matmul_partials(this->devices_.at(0), test_config); + } } TEST_F(DeviceFixture, ComputeCopyBlockMatmulPartialsR8W8C1) { - unit_tests::compute::matmul_partials::CopyBlockMatmulPartialsConfig test_config = { - .single_tile_size = 2 * 1024, - .num_tiles = 8, - .reader_ublock = 8, - .writer_ublock = 8, - .compute_ublock = 1, - .src0_cb_index = 0, - .ouput_cb_index = 16 - }; - unit_tests::compute::matmul_partials::run_single_core_copy_block_matmul_partials(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::matmul_partials::CopyBlockMatmulPartialsConfig test_config = { + .single_tile_size = 2 * 1024, + .num_tiles = 8, + .reader_ublock = 8, + .writer_ublock = 8, + .compute_ublock = 1, + .src0_cb_index = 0, + .ouput_cb_index = 16, + .dst_full_sync_en = dst_full_sync_en + }; + unit_tests::compute::matmul_partials::run_single_core_copy_block_matmul_partials(this->devices_.at(0), test_config); + } } TEST_F(DeviceFixture, ComputeCopyBlockMatmulPartialsR8W1C1) { - unit_tests::compute::matmul_partials::CopyBlockMatmulPartialsConfig test_config = { - .single_tile_size = 2 * 1024, - .num_tiles = 8, - .reader_ublock = 8, - .writer_ublock = 1, - .compute_ublock = 1, - .src0_cb_index = 0, - .ouput_cb_index = 16 - }; - unit_tests::compute::matmul_partials::run_single_core_copy_block_matmul_partials(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::matmul_partials::CopyBlockMatmulPartialsConfig test_config = { + .single_tile_size = 2 * 1024, + .num_tiles = 8, + .reader_ublock = 8, + .writer_ublock = 1, + .compute_ublock = 1, + .src0_cb_index = 0, + .ouput_cb_index = 16, + .dst_full_sync_en = dst_full_sync_en + }; + unit_tests::compute::matmul_partials::run_single_core_copy_block_matmul_partials(this->devices_.at(0), test_config); + } } TEST_F(DeviceFixture, ComputeCopyBlockMatmulPartialsR1W1C1) { - unit_tests::compute::matmul_partials::CopyBlockMatmulPartialsConfig test_config = { - .single_tile_size = 2 * 1024, - .num_tiles = 1, - .reader_ublock = 1, - .writer_ublock = 1, - .compute_ublock = 1, - .src0_cb_index = 0, - .ouput_cb_index = 16 - }; - unit_tests::compute::matmul_partials::run_single_core_copy_block_matmul_partials(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::matmul_partials::CopyBlockMatmulPartialsConfig test_config = { + .single_tile_size = 2 * 1024, + .num_tiles = 1, + .reader_ublock = 1, + .writer_ublock = 1, + .compute_ublock = 1, + .src0_cb_index = 0, + .ouput_cb_index = 16, + .dst_full_sync_en = dst_full_sync_en + }; + unit_tests::compute::matmul_partials::run_single_core_copy_block_matmul_partials(this->devices_.at(0), test_config); + } } diff --git a/tests/tt_metal/tt_metal/unit_tests/compute/test_reconfig.cpp b/tests/tt_metal/tt_metal/unit_tests/compute/test_reconfig.cpp index ecaf175f9a9..9d48d09ccaa 100644 --- a/tests/tt_metal/tt_metal/unit_tests/compute/test_reconfig.cpp +++ b/tests/tt_metal/tt_metal/unit_tests/compute/test_reconfig.cpp @@ -17,6 +17,7 @@ struct ReconfigConfig { bool explicit_reconfig = false; bool split_src_reconfig = false; bool l1_acc = false; + bool dst_full_sync_en = false; }; /// @brief Does Dramx3 --> Reader --> CB --> Add with acc --> CB --> Writer --> Dram @@ -142,7 +143,8 @@ bool single_core_reconfig(tt_metal::Device* device, const ReconfigConfig& test_c program, "tests/tt_metal/tt_metal/test_kernels/compute/reconfig.cpp", core, - tt_metal::ComputeConfig{.compile_args = compute_kernel_args, .defines = defines}); + tt_metal::ComputeConfig{.dst_full_sync_en = test_config.dst_full_sync_en, + .compile_args = compute_kernel_args, .defines = defines}); SetRuntimeArgs( program, @@ -275,14 +277,18 @@ TEST_F(DeviceFixture, TileCopyReconfigExplicitSplit) { if (arch == tt::ARCH::GRAYSKULL) { GTEST_SKIP(); } - unit_tests::compute::reconfig::ReconfigConfig test_config = { - .num_tiles = 1, - .ublock_size_tiles = 1, - .explicit_reconfig = true, - .split_src_reconfig = true - }; - for (unsigned int id = 0; id < num_devices_; id++) { - ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::reconfig::ReconfigConfig test_config = { + .num_tiles = 1, + .ublock_size_tiles = 1, + .explicit_reconfig = true, + .split_src_reconfig = true, + .dst_full_sync_en = dst_full_sync_en + }; + for (unsigned int id = 0; id < num_devices_; id++) { + ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + } } } @@ -291,14 +297,18 @@ TEST_F(DeviceFixture, TileCopyReconfigExplicitJoined) { if (arch == tt::ARCH::GRAYSKULL) { GTEST_SKIP(); } - unit_tests::compute::reconfig::ReconfigConfig test_config = { - .num_tiles = 1, - .ublock_size_tiles = 1, - .explicit_reconfig = true, - .split_src_reconfig = false - }; - for (unsigned int id = 0; id < num_devices_; id++) { - ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::reconfig::ReconfigConfig test_config = { + .num_tiles = 1, + .ublock_size_tiles = 1, + .explicit_reconfig = true, + .split_src_reconfig = false, + .dst_full_sync_en = dst_full_sync_en + }; + for (unsigned int id = 0; id < num_devices_; id++) { + ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + } } } @@ -307,14 +317,18 @@ TEST_F(DeviceFixture, TileCopyReconfigImplicitSplit) { if (arch == tt::ARCH::GRAYSKULL) { GTEST_SKIP(); } - unit_tests::compute::reconfig::ReconfigConfig test_config = { - .num_tiles = 1, - .ublock_size_tiles = 1, - .explicit_reconfig = false, - .split_src_reconfig = true - }; - for (unsigned int id = 0; id < num_devices_; id++) { - ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::reconfig::ReconfigConfig test_config = { + .num_tiles = 1, + .ublock_size_tiles = 1, + .explicit_reconfig = false, + .split_src_reconfig = true, + .dst_full_sync_en = dst_full_sync_en + }; + for (unsigned int id = 0; id < num_devices_; id++) { + ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + } } } @@ -323,14 +337,18 @@ TEST_F(DeviceFixture, TileCopyReconfigImplicitJoined) { if (arch == tt::ARCH::GRAYSKULL) { GTEST_SKIP(); } - unit_tests::compute::reconfig::ReconfigConfig test_config = { - .num_tiles = 1, - .ublock_size_tiles = 1, - .explicit_reconfig = false, - .split_src_reconfig = false - }; - for (unsigned int id = 0; id < num_devices_; id++) { - ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::reconfig::ReconfigConfig test_config = { + .num_tiles = 1, + .ublock_size_tiles = 1, + .explicit_reconfig = false, + .split_src_reconfig = false, + .dst_full_sync_en = dst_full_sync_en + }; + for (unsigned int id = 0; id < num_devices_; id++) { + ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + } } } @@ -339,16 +357,20 @@ TEST_F(DeviceFixture, TileCopyReconfigL1Acc) { if (arch == tt::ARCH::GRAYSKULL) { GTEST_SKIP(); } - unit_tests::compute::reconfig::ReconfigConfig test_config = { - .num_tiles = 1, - .ublock_size_tiles = 1, - }; - for (unsigned int id = 0; id < num_devices_; id++) { - test_config.l1_acc = false; - ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); - log_info(LogTest, "Passed without L1 accumulation"); - test_config.l1_acc = true; - ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); - log_info(LogTest, "Passed with L1 accumulation"); + + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::reconfig::ReconfigConfig test_config = { + .num_tiles = 1, + .ublock_size_tiles = 1, + .dst_full_sync_en = dst_full_sync_en + }; + for (unsigned int id = 0; id < num_devices_; id++) { + test_config.l1_acc = false; + ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + log_info(LogTest, "Passed without L1 accumulation"); + test_config.l1_acc = true; + ASSERT_TRUE(unit_tests::compute::reconfig::single_core_reconfig(devices_.at(id), test_config)); + log_info(LogTest, "Passed with L1 accumulation"); + } } } diff --git a/tests/tt_metal/tt_metal/unit_tests/compute/test_reduce.cpp b/tests/tt_metal/tt_metal/unit_tests/compute/test_reduce.cpp index 961d5fd111c..c12dfb809be 100644 --- a/tests/tt_metal/tt_metal/unit_tests/compute/test_reduce.cpp +++ b/tests/tt_metal/tt_metal/unit_tests/compute/test_reduce.cpp @@ -52,6 +52,7 @@ struct ReduceConfig { std::vector result_shape; bool math_only_reduce = false; bool fp32_dest_acc_en = false; + bool dst_full_sync_en = false; MathFidelity math_fidelity = MathFidelity::HiFi4; }; @@ -315,6 +316,7 @@ void run_single_core_reduce_program(tt_metal::Device* device, const ReduceConfig core, tt_metal::ComputeConfig{.math_fidelity = test_config.math_fidelity, .fp32_dest_acc_en = test_config.fp32_dest_acc_en, + .dst_full_sync_en = test_config.dst_full_sync_en, .compile_args = compute_kernel_args, .defines = reduce_defines}); @@ -382,22 +384,25 @@ TEST_F(DeviceFixture, ComputeReduceH) { if (math_fid == 1) continue; for (uint8_t reduce_type = uint8_t(ReduceType::SUM); reduce_type <= uint8_t(ReduceType::MAX); reduce_type++) { for (bool fp32_dest_acc_en : {true, false}) { - log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}", math_fid, reduce_type, fp32_dest_acc_en); - ReduceConfig test_config = { - .shape = shape, - .reduce_dim = ReduceDim::H, - .reduce_type = ReduceType(reduce_type), - .data_gen_rand_max = 10.0f, - .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), - .data_gen_offset = -10.0f, - .atol = 1e-2f, - .rtol = 0.08f, - .golden_function = unit_tests::compute::gold_reduce_h, - .result_shape = result_shape, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_fidelity = MathFidelity(math_fid), - }; - run_single_core_reduce_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}, DstSyncFull = {}", math_fid, reduce_type, fp32_dest_acc_en, dst_full_sync_en); + ReduceConfig test_config = { + .shape = shape, + .reduce_dim = ReduceDim::H, + .reduce_type = ReduceType(reduce_type), + .data_gen_rand_max = 10.0f, + .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), + .data_gen_offset = -10.0f, + .atol = 1e-2f, + .rtol = 0.08f, + .golden_function = unit_tests::compute::gold_reduce_h, + .result_shape = result_shape, + .fp32_dest_acc_en = fp32_dest_acc_en, + .dst_full_sync_en = dst_full_sync_en, + .math_fidelity = MathFidelity(math_fid), + }; + run_single_core_reduce_program(this->devices_.at(0), test_config); + } } } } @@ -411,23 +416,26 @@ TEST_F(DeviceFixture, ComputeReduceW) { if (math_fid == 1) continue; for (uint8_t reduce_type = uint8_t(ReduceType::SUM); reduce_type <= uint8_t(ReduceType::MAX); reduce_type++) { for (bool fp32_dest_acc_en : {true, false}) { - if ((fp32_dest_acc_en == true) && (this->arch_ == tt::ARCH::GRAYSKULL)) continue; - log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}", math_fid, reduce_type, fp32_dest_acc_en); - ReduceConfig test_config = { - .shape = shape, - .reduce_dim = ReduceDim::W, - .reduce_type = ReduceType(reduce_type), - .data_gen_rand_max = 10.0f, - .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), - .data_gen_offset = -10.0f, - .atol = 1e-2f, - .rtol = 0.08f, - .golden_function = unit_tests::compute::gold_reduce_w, - .result_shape = result_shape, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_fidelity = MathFidelity(math_fid), - }; - run_single_core_reduce_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + if ((fp32_dest_acc_en == true) && (this->arch_ == tt::ARCH::GRAYSKULL)) continue; + log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}, DstSyncFull = {}", math_fid, reduce_type, fp32_dest_acc_en, dst_full_sync_en); + ReduceConfig test_config = { + .shape = shape, + .reduce_dim = ReduceDim::W, + .reduce_type = ReduceType(reduce_type), + .data_gen_rand_max = 10.0f, + .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), + .data_gen_offset = -10.0f, + .atol = 1e-2f, + .rtol = 0.08f, + .golden_function = unit_tests::compute::gold_reduce_w, + .result_shape = result_shape, + .fp32_dest_acc_en = fp32_dest_acc_en, + .dst_full_sync_en = dst_full_sync_en, + .math_fidelity = MathFidelity(math_fid), + }; + run_single_core_reduce_program(this->devices_.at(0), test_config); + } } } } @@ -441,24 +449,27 @@ TEST_F(DeviceFixture, ComputeReduceHW) { if (math_fid == 1) continue; for (uint8_t reduce_type = uint8_t(ReduceType::SUM); reduce_type <= uint8_t(ReduceType::MAX); reduce_type++) { for (bool fp32_dest_acc_en : {true, false}) { - // Currently fp32 dest unsupported with reduce scalar - if (fp32_dest_acc_en) continue; - log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}", math_fid, reduce_type, fp32_dest_acc_en); - ReduceConfig test_config = { - .shape = shape, - .reduce_dim = ReduceDim::HW, - .reduce_type = ReduceType(reduce_type), - .data_gen_rand_max = 10.0f, - .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), - .data_gen_offset = -10.0f, - .atol = 1e-2f, - .rtol = 0.08f, - .golden_function = unit_tests::compute::gold_reduce_hw, - .result_shape = result_shape, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_fidelity = MathFidelity(math_fid) - }; - run_single_core_reduce_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + // Currently fp32 dest unsupported with reduce scalar + if (fp32_dest_acc_en) continue; + log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}, DstSyncFull = {}", math_fid, reduce_type, fp32_dest_acc_en, dst_full_sync_en); + ReduceConfig test_config = { + .shape = shape, + .reduce_dim = ReduceDim::HW, + .reduce_type = ReduceType(reduce_type), + .data_gen_rand_max = 10.0f, + .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), + .data_gen_offset = -10.0f, + .atol = 1e-2f, + .rtol = 0.08f, + .golden_function = unit_tests::compute::gold_reduce_hw, + .result_shape = result_shape, + .fp32_dest_acc_en = fp32_dest_acc_en, + .dst_full_sync_en = dst_full_sync_en, + .math_fidelity = MathFidelity(math_fid) + }; + run_single_core_reduce_program(this->devices_.at(0), test_config); + } } } } @@ -476,23 +487,26 @@ TEST_F(DeviceFixture, ComputeReduceHMathOnly) { if (math_fid == 1) continue; for (uint8_t reduce_type = uint8_t(ReduceType::SUM); reduce_type <= uint8_t(ReduceType::MAX); reduce_type++) { for (bool fp32_dest_acc_en : {true, false}) { - log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}", math_fid, reduce_type, fp32_dest_acc_en); - ReduceConfig test_config = { - .shape = shape, - .reduce_dim = ReduceDim::H, - .reduce_type = ReduceType(reduce_type), - .data_gen_rand_max = 10.0f, - .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), - .data_gen_offset = -10.0f, - .atol = 1e-2f, - .rtol = 0.08f, - .golden_function = unit_tests::compute::gold_reduce_h, - .result_shape = result_shape, - .math_only_reduce = true, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_fidelity = MathFidelity(math_fid) - }; - run_single_core_reduce_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}, DstSyncFull = {}", math_fid, reduce_type, fp32_dest_acc_en, dst_full_sync_en); + ReduceConfig test_config = { + .shape = shape, + .reduce_dim = ReduceDim::H, + .reduce_type = ReduceType(reduce_type), + .data_gen_rand_max = 10.0f, + .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), + .data_gen_offset = -10.0f, + .atol = 1e-2f, + .rtol = 0.08f, + .golden_function = unit_tests::compute::gold_reduce_h, + .result_shape = result_shape, + .math_only_reduce = true, + .fp32_dest_acc_en = fp32_dest_acc_en, + .dst_full_sync_en = dst_full_sync_en, + .math_fidelity = MathFidelity(math_fid) + }; + run_single_core_reduce_program(this->devices_.at(0), test_config); + } } } } @@ -506,24 +520,27 @@ TEST_F(DeviceFixture, ComputeReduceWMathOnly) { if (math_fid == 1) continue; for (uint8_t reduce_type = uint8_t(ReduceType::SUM); reduce_type <= uint8_t(ReduceType::MAX); reduce_type++) { for (bool fp32_dest_acc_en : {true, false}) { - if ((fp32_dest_acc_en == true) && (this->arch_ == tt::ARCH::GRAYSKULL)) continue; - log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}", math_fid, reduce_type, fp32_dest_acc_en); - ReduceConfig test_config = { - .shape = shape, - .reduce_dim = ReduceDim::W, - .reduce_type = ReduceType(reduce_type), - .data_gen_rand_max = 10.0f, - .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), - .data_gen_offset = -10.0f, - .atol = 1e-2f, - .rtol = 0.08f, - .golden_function = unit_tests::compute::gold_reduce_w, - .result_shape = result_shape, - .math_only_reduce = true, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_fidelity = MathFidelity(math_fid) - }; - run_single_core_reduce_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + if ((fp32_dest_acc_en == true) && (this->arch_ == tt::ARCH::GRAYSKULL)) continue; + log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}, DstSyncFull = {}", math_fid, reduce_type, fp32_dest_acc_en, dst_full_sync_en); + ReduceConfig test_config = { + .shape = shape, + .reduce_dim = ReduceDim::W, + .reduce_type = ReduceType(reduce_type), + .data_gen_rand_max = 10.0f, + .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), + .data_gen_offset = -10.0f, + .atol = 1e-2f, + .rtol = 0.08f, + .golden_function = unit_tests::compute::gold_reduce_w, + .result_shape = result_shape, + .math_only_reduce = true, + .fp32_dest_acc_en = fp32_dest_acc_en, + .dst_full_sync_en = dst_full_sync_en, + .math_fidelity = MathFidelity(math_fid) + }; + run_single_core_reduce_program(this->devices_.at(0), test_config); + } } } } @@ -537,25 +554,28 @@ TEST_F(DeviceFixture, ComputeReduceHWMathOnly) { if (math_fid == 1) continue; for (uint8_t reduce_type = uint8_t(ReduceType::SUM); reduce_type <= uint8_t(ReduceType::MAX); reduce_type++) { for (bool fp32_dest_acc_en : {true, false}) { - // Currently fp32 dest unsupported with reduce scalar - if (fp32_dest_acc_en) continue; - log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}", math_fid, reduce_type, fp32_dest_acc_en); - ReduceConfig test_config = { - .shape = shape, - .reduce_dim = ReduceDim::HW, - .reduce_type = ReduceType(reduce_type), - .data_gen_rand_max = 10.0f, - .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), - .data_gen_offset = -10.0f, - .atol = 1e-2f, - .rtol = 0.08f, - .golden_function = unit_tests::compute::gold_reduce_hw, - .result_shape = result_shape, - .math_only_reduce = true, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_fidelity = MathFidelity(math_fid) - }; - run_single_core_reduce_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + // Currently fp32 dest unsupported with reduce scalar + if (fp32_dest_acc_en) continue; + log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}, DstSyncFull = {}", math_fid, reduce_type, fp32_dest_acc_en, dst_full_sync_en); + ReduceConfig test_config = { + .shape = shape, + .reduce_dim = ReduceDim::HW, + .reduce_type = ReduceType(reduce_type), + .data_gen_rand_max = 10.0f, + .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), + .data_gen_offset = -10.0f, + .atol = 1e-2f, + .rtol = 0.08f, + .golden_function = unit_tests::compute::gold_reduce_hw, + .result_shape = result_shape, + .math_only_reduce = true, + .fp32_dest_acc_en = fp32_dest_acc_en, + .dst_full_sync_en = dst_full_sync_en, + .math_fidelity = MathFidelity(math_fid) + }; + run_single_core_reduce_program(this->devices_.at(0), test_config); + } } } } @@ -573,23 +593,26 @@ TEST_F(DeviceFixture, ComputeReduceHShortInit) { if (math_fid == 1) continue; for (uint8_t reduce_type = uint8_t(ReduceType::SUM); reduce_type <= uint8_t(ReduceType::MAX); reduce_type++) { for (bool fp32_dest_acc_en : {true, false}) { - log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}", math_fid, reduce_type, fp32_dest_acc_en); - ReduceConfig test_config = { - .short_init = true, - .shape = shape, - .reduce_dim = ReduceDim::H, - .reduce_type = ReduceType(reduce_type), - .data_gen_rand_max = 10.0f, - .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), - .data_gen_offset = -10.0f, - .atol = 1e-2f, - .rtol = 0.08f, - .golden_function = unit_tests::compute::gold_reduce_h, - .result_shape = result_shape, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_fidelity = MathFidelity(math_fid) - }; - run_single_core_reduce_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}, DstSyncFull = {}", math_fid, reduce_type, fp32_dest_acc_en, dst_full_sync_en); + ReduceConfig test_config = { + .short_init = true, + .shape = shape, + .reduce_dim = ReduceDim::H, + .reduce_type = ReduceType(reduce_type), + .data_gen_rand_max = 10.0f, + .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), + .data_gen_offset = -10.0f, + .atol = 1e-2f, + .rtol = 0.08f, + .golden_function = unit_tests::compute::gold_reduce_h, + .result_shape = result_shape, + .fp32_dest_acc_en = fp32_dest_acc_en, + .dst_full_sync_en = dst_full_sync_en, + .math_fidelity = MathFidelity(math_fid) + }; + run_single_core_reduce_program(this->devices_.at(0), test_config); + } } } } @@ -603,24 +626,27 @@ TEST_F(DeviceFixture, ComputeReduceWShortInit) { if (math_fid == 1) continue; for (uint8_t reduce_type = uint8_t(ReduceType::SUM); reduce_type <= uint8_t(ReduceType::MAX); reduce_type++) { for (bool fp32_dest_acc_en : {true, false}) { - if ((fp32_dest_acc_en == true) && (this->arch_ == tt::ARCH::GRAYSKULL)) continue; - log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}", math_fid, reduce_type, fp32_dest_acc_en); - ReduceConfig test_config = { - .short_init = true, - .shape = shape, - .reduce_dim = ReduceDim::W, - .reduce_type = ReduceType(reduce_type), - .data_gen_rand_max = 10.0f, - .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), - .data_gen_offset = -10.0f, - .atol = 1e-2f, - .rtol = 0.08f, - .golden_function = unit_tests::compute::gold_reduce_w, - .result_shape = result_shape, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_fidelity = MathFidelity(math_fid) - }; - run_single_core_reduce_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + if ((fp32_dest_acc_en == true) && (this->arch_ == tt::ARCH::GRAYSKULL)) continue; + log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}, DstSyncFull = {}", math_fid, reduce_type, fp32_dest_acc_en, dst_full_sync_en); + ReduceConfig test_config = { + .short_init = true, + .shape = shape, + .reduce_dim = ReduceDim::W, + .reduce_type = ReduceType(reduce_type), + .data_gen_rand_max = 10.0f, + .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), + .data_gen_offset = -10.0f, + .atol = 1e-2f, + .rtol = 0.08f, + .golden_function = unit_tests::compute::gold_reduce_w, + .result_shape = result_shape, + .fp32_dest_acc_en = fp32_dest_acc_en, + .dst_full_sync_en = dst_full_sync_en, + .math_fidelity = MathFidelity(math_fid) + }; + run_single_core_reduce_program(this->devices_.at(0), test_config); + } } } } @@ -634,25 +660,28 @@ TEST_F(DeviceFixture, ComputeReduceHWShortInit) { if (math_fid == 1) continue; for (uint8_t reduce_type = uint8_t(ReduceType::SUM); reduce_type <= uint8_t(ReduceType::MAX); reduce_type++) { for (bool fp32_dest_acc_en : {true, false}) { - // Currently fp32 dest unsupported with reduce scalar - if (fp32_dest_acc_en) continue; - log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}", math_fid, reduce_type, fp32_dest_acc_en); - ReduceConfig test_config = { - .short_init = true, - .shape = shape, - .reduce_dim = ReduceDim::HW, - .reduce_type = ReduceType(reduce_type), - .data_gen_rand_max = 10.0f, - .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), - .data_gen_offset = -10.0f, - .atol = 1e-2f, - .rtol = 0.08f, - .golden_function = unit_tests::compute::gold_reduce_hw, - .result_shape = result_shape, - .fp32_dest_acc_en = fp32_dest_acc_en, - .math_fidelity = MathFidelity(math_fid) - }; - run_single_core_reduce_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + // Currently fp32 dest unsupported with reduce scalar + if (fp32_dest_acc_en) continue; + log_info(LogTest, "MathFid = {}, ReduceType = {}, FP32DestAcc = {}, DstSyncFull = {}", math_fid, reduce_type, fp32_dest_acc_en, dst_full_sync_en); + ReduceConfig test_config = { + .short_init = true, + .shape = shape, + .reduce_dim = ReduceDim::HW, + .reduce_type = ReduceType(reduce_type), + .data_gen_rand_max = 10.0f, + .data_gen_seed = std::chrono::system_clock::now().time_since_epoch().count(), + .data_gen_offset = -10.0f, + .atol = 1e-2f, + .rtol = 0.08f, + .golden_function = unit_tests::compute::gold_reduce_hw, + .result_shape = result_shape, + .fp32_dest_acc_en = fp32_dest_acc_en, + .dst_full_sync_en = dst_full_sync_en, + .math_fidelity = MathFidelity(math_fid) + }; + run_single_core_reduce_program(this->devices_.at(0), test_config); + } } } } diff --git a/tests/tt_metal/tt_metal/unit_tests/compute/test_untilize_tilize.cpp b/tests/tt_metal/tt_metal/unit_tests/compute/test_untilize_tilize.cpp index 7abc33e2ef1..c96a44be6d2 100644 --- a/tests/tt_metal/tt_metal/unit_tests/compute/test_untilize_tilize.cpp +++ b/tests/tt_metal/tt_metal/unit_tests/compute/test_untilize_tilize.cpp @@ -43,6 +43,7 @@ using GoldenFunc = std::variant< struct TestConfig { bool short_init = false; + bool dst_full_sync_en = false; uint32_t input_single_tile_size; uint32_t output_single_tile_size; uint32_t num_tiles_r; @@ -169,7 +170,8 @@ void run_single_core_tilize_program(tt_metal::Device* device, const TestConfig& program, compute_kernel, core, - tt_metal::ComputeConfig{.compile_args = compute_kernel_args, .defines = defines} + tt_metal::ComputeConfig{.dst_full_sync_en = test_config.dst_full_sync_en, + .compile_args = compute_kernel_args, .defines = defines} ); std::vector src0_vec = create_arange_vector_of_bfloat16(input_dram_buffer_size, false); @@ -276,15 +278,18 @@ Following tests are for Unpack Tilize TEST_F(DeviceFixture, ComputeUnpackTilize) { vector > num_tiles = {{1, 4}, {2, 2}, {4, 1}}; for(auto num_tile : num_tiles) { - unit_tests::compute::tilize::TestConfig test_config = { - .input_single_tile_size = 2 * 1024, - .output_single_tile_size = 2 * 1024, - .num_tiles_r = num_tile[0], - .num_tiles_c = num_tile[1], - .tilize_type = unit_tests::compute::tilize::TilizeType::UNPACK_A, - .golden_function = unit_tests::compute::gold_standard_tilize - }; - unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::tilize::TestConfig test_config = { + .dst_full_sync_en = dst_full_sync_en, + .input_single_tile_size = 2 * 1024, + .output_single_tile_size = 2 * 1024, + .num_tiles_r = num_tile[0], + .num_tiles_c = num_tile[1], + .tilize_type = unit_tests::compute::tilize::TilizeType::UNPACK_A, + .golden_function = unit_tests::compute::gold_standard_tilize + }; + unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + } } } @@ -293,33 +298,40 @@ TEST_F(DeviceFixture, ComputeUnpackTilizeA_B) { if (arch == tt::ARCH::GRAYSKULL) { GTEST_SKIP(); } - unit_tests::compute::tilize::TestConfig test_config = { - .input_single_tile_size = 2 * 1024, - .output_single_tile_size = 2 * 1024, - .num_tiles_r = 2, - .num_tiles_c = 8, - .tilize_type = unit_tests::compute::tilize::TilizeType::UNPACK_A_B, - .golden_function = unit_tests::compute::gold_standard_tilize_w_elwadd - }; - unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); -} -TEST_F(DeviceFixture, ComputeUnpackTilizeShortInit) { - vector > num_tiles = {{1, 4}, {2, 2}, {4, 1}}; - for(auto num_tile : num_tiles) { + for (bool dst_full_sync_en : {true, false}) { unit_tests::compute::tilize::TestConfig test_config = { - .short_init = true, + .dst_full_sync_en = dst_full_sync_en, .input_single_tile_size = 2 * 1024, .output_single_tile_size = 2 * 1024, - .num_tiles_r = num_tile[0], - .num_tiles_c = num_tile[1], - .tilize_type = unit_tests::compute::tilize::TilizeType::UNPACK_A, - .golden_function = unit_tests::compute::gold_standard_tilize + .num_tiles_r = 2, + .num_tiles_c = 8, + .tilize_type = unit_tests::compute::tilize::TilizeType::UNPACK_A_B, + .golden_function = unit_tests::compute::gold_standard_tilize_w_elwadd }; unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); } } +TEST_F(DeviceFixture, ComputeUnpackTilizeShortInit) { + vector > num_tiles = {{1, 4}, {2, 2}, {4, 1}}; + for(auto num_tile : num_tiles) { + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::tilize::TestConfig test_config = { + .short_init = true, + .dst_full_sync_en = dst_full_sync_en, + .input_single_tile_size = 2 * 1024, + .output_single_tile_size = 2 * 1024, + .num_tiles_r = num_tile[0], + .num_tiles_c = num_tile[1], + .tilize_type = unit_tests::compute::tilize::TilizeType::UNPACK_A, + .golden_function = unit_tests::compute::gold_standard_tilize + }; + unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + } + } +} + /************************************** Following tests are for Unpack Untilize ***************************************/ @@ -327,31 +339,37 @@ Following tests are for Unpack Untilize TEST_F(DeviceFixture, ComputeUnpackUntilize) { vector > num_tiles = {{1, 4}, {2, 2}, {4, 1}}; for(auto num_tile : num_tiles) { - unit_tests::compute::tilize::TestConfig test_config = { - .input_single_tile_size = 2 * 1024, - .output_single_tile_size = 2 * 1024, - .num_tiles_r = num_tile[0], - .num_tiles_c = num_tile[1], - .untilize_type = unit_tests::compute::tilize::UntilizeType::UNPACK, - .golden_function = unit_tests::compute::gold_standard_untilize - }; - unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::tilize::TestConfig test_config = { + .dst_full_sync_en = dst_full_sync_en, + .input_single_tile_size = 2 * 1024, + .output_single_tile_size = 2 * 1024, + .num_tiles_r = num_tile[0], + .num_tiles_c = num_tile[1], + .untilize_type = unit_tests::compute::tilize::UntilizeType::UNPACK, + .golden_function = unit_tests::compute::gold_standard_untilize + }; + unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + } } } TEST_F(DeviceFixture, ComputeUnpackUntilizeShortInit) { vector > num_tiles = {{1, 4}, {2, 2}, {4, 1}}; for(auto num_tile : num_tiles) { - unit_tests::compute::tilize::TestConfig test_config = { - .short_init = true, - .input_single_tile_size = 2 * 1024, - .output_single_tile_size = 2 * 1024, - .num_tiles_r = num_tile[0], - .num_tiles_c = num_tile[1], - .untilize_type = unit_tests::compute::tilize::UntilizeType::UNPACK, - .golden_function = unit_tests::compute::gold_standard_untilize - }; - unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::tilize::TestConfig test_config = { + .short_init = true, + .dst_full_sync_en = dst_full_sync_en, + .input_single_tile_size = 2 * 1024, + .output_single_tile_size = 2 * 1024, + .num_tiles_r = num_tile[0], + .num_tiles_c = num_tile[1], + .untilize_type = unit_tests::compute::tilize::UntilizeType::UNPACK, + .golden_function = unit_tests::compute::gold_standard_untilize + }; + unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + } } } @@ -361,47 +379,55 @@ Following tests are for pack untilize TEST_F(DeviceFixture, ComputePackUntilize) { vector > num_tiles = {{1, 4}, {2, 2}, {4, 1}}; for(auto num_tile : num_tiles) { - unit_tests::compute::tilize::TestConfig test_config = { - .input_single_tile_size = 2 * 1024, - .output_single_tile_size = 2 * 1024, - .num_tiles_r = num_tile[0], - .num_tiles_c = num_tile[1], - .untilize_type = unit_tests::compute::tilize::UntilizeType::PACK, - .golden_function = unit_tests::compute::gold_standard_untilize - }; - unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::tilize::TestConfig test_config = { + .dst_full_sync_en = dst_full_sync_en, + .input_single_tile_size = 2 * 1024, + .output_single_tile_size = 2 * 1024, + .num_tiles_r = num_tile[0], + .num_tiles_c = num_tile[1], + .untilize_type = unit_tests::compute::tilize::UntilizeType::PACK, + .golden_function = unit_tests::compute::gold_standard_untilize + }; + unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + } } } TEST_F(DeviceFixture, ComputePackUntilizeShortInit) { vector > num_tiles = {{1, 4}, {2, 2}, {4, 1}}; for(auto num_tile : num_tiles) { - unit_tests::compute::tilize::TestConfig test_config = { - .short_init = true, - .input_single_tile_size = 2 * 1024, - .output_single_tile_size = 2 * 1024, - .num_tiles_r = num_tile[0], - .num_tiles_c = num_tile[1], - .untilize_type = unit_tests::compute::tilize::UntilizeType::PACK, - .golden_function = unit_tests::compute::gold_standard_untilize - }; - unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::tilize::TestConfig test_config = { + .short_init = true, + .dst_full_sync_en = dst_full_sync_en, + .input_single_tile_size = 2 * 1024, + .output_single_tile_size = 2 * 1024, + .num_tiles_r = num_tile[0], + .num_tiles_c = num_tile[1], + .untilize_type = unit_tests::compute::tilize::UntilizeType::PACK, + .golden_function = unit_tests::compute::gold_standard_untilize + }; + unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + } } - } TEST_F(DeviceFixture, ComputePackUntilizeDst) { vector > num_tiles = {{1, 4}, {2, 2}, {4, 1}}; for(auto num_tile : num_tiles) { - unit_tests::compute::tilize::TestConfig test_config = { - .input_single_tile_size = 2 * 1024, - .output_single_tile_size = 2 * 1024, - .num_tiles_r = num_tile[0], - .num_tiles_c = num_tile[1], - .untilize_type = unit_tests::compute::tilize::UntilizeType::DST, - .golden_function = unit_tests::compute::gold_standard_untilize - }; - unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + unit_tests::compute::tilize::TestConfig test_config = { + .dst_full_sync_en = dst_full_sync_en, + .input_single_tile_size = 2 * 1024, + .output_single_tile_size = 2 * 1024, + .num_tiles_r = num_tile[0], + .num_tiles_c = num_tile[1], + .untilize_type = unit_tests::compute::tilize::UntilizeType::DST, + .golden_function = unit_tests::compute::gold_standard_untilize + }; + unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + } } } @@ -412,19 +438,22 @@ TEST_F(DeviceFixture, ComputePackUntilizeDstTinyTile) { vector > test_config_values = {{1, 1, 1, 1}, {1, 1, 2, 1}, {1, 2, 2, 1}}; uint32_t face_c_dim = 16; for(auto test_config_value : test_config_values) { - uint32_t num_faces_per_tile = test_config_value[2]; - uint32_t face_r_dim = test_config_value[3]; - unit_tests::compute::tilize::TestConfig test_config = { - .short_init = true, - .input_single_tile_size = 2 * 1024, - .output_single_tile_size = 2 * num_faces_per_tile * face_r_dim * face_c_dim, - .num_tiles_r = test_config_value[0], - .num_tiles_c = test_config_value[1], - .num_faces_per_tile = num_faces_per_tile, - .face_r_dim = face_r_dim, - .untilize_type = unit_tests::compute::tilize::UntilizeType::DST, - .golden_function = unit_tests::compute::gold_standard_untilize - }; - unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + for (bool dst_full_sync_en : {true, false}) { + uint32_t num_faces_per_tile = test_config_value[2]; + uint32_t face_r_dim = test_config_value[3]; + unit_tests::compute::tilize::TestConfig test_config = { + .short_init = true, + .dst_full_sync_en = dst_full_sync_en, + .input_single_tile_size = 2 * 1024, + .output_single_tile_size = 2 * num_faces_per_tile * face_r_dim * face_c_dim, + .num_tiles_r = test_config_value[0], + .num_tiles_c = test_config_value[1], + .num_faces_per_tile = num_faces_per_tile, + .face_r_dim = face_r_dim, + .untilize_type = unit_tests::compute::tilize::UntilizeType::DST, + .golden_function = unit_tests::compute::gold_standard_untilize + }; + unit_tests::compute::tilize::run_single_core_tilize_program(this->devices_.at(0), test_config); + } } } diff --git a/tests/tt_metal/tt_metal/unit_tests_common/compute/matmul/test_matmul_X_tile.cpp b/tests/tt_metal/tt_metal/unit_tests_common/compute/matmul/test_matmul_X_tile.cpp index 5b2ec8fcc43..1ef579e6ed5 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/compute/matmul/test_matmul_X_tile.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_common/compute/matmul/test_matmul_X_tile.cpp @@ -27,6 +27,7 @@ struct MatmulTileConfig { bool with_bias = false; bool test_init_short = false; bool with_dt = true; + bool dst_full_sync_en = false; string reader_kernel; string compute_kernel; vector compute_kernel_args; @@ -215,6 +216,7 @@ bool matmul_tile(CommonFixture *fixture, tt_metal::Device *device, const MatmulT cfg.compute_kernel, core, tt_metal::ComputeConfig{.math_fidelity = cfg.math_fidelity, + .dst_full_sync_en = cfg.dst_full_sync_en, .compile_args = cfg.compute_kernel_args, .defines = compute_defines} ); @@ -292,206 +294,221 @@ bool matmul_tile(CommonFixture *fixture, tt_metal::Device *device, const MatmulT } // namespace unit_tests_common::matmul::test_matmul_X_tile TEST_F(CommonFixture, MatmulSingleTile){ - for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { - if (i == 1) continue; - unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { - .M = 1, .K = 1, .N = 1, - .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_blocked.cpp", - .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul.cpp", - .compute_kernel_args = { - 1, // block_tile_dim - 1, // dst_tile_rows - 1, // dst_tile_cols - 1, // block_cnt - 1, // in0_block_tile_cnt - 1, // in1_block_tile_cnt - 1 // out_block_tile_cnt - }, - .math_fidelity = MathFidelity(i) - }; - SHAPE shape = {1, 1, 32, 32}; - tt::log_info(tt::LogTest, "Math Fidelity = {}", i); - tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); - auto activations_tile_layout = convert_to_tile_layout(tensor.get_values()); - auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); - - auto identity = create_identity_matrix(32, 32, 32); //bfloat16 32x32 identity - auto weights_tile_layout = convert_to_tile_layout(identity); - auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); - - for(unsigned int id = 0; id < devices_.size(); id++){ - ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations, weights, tensor)); + for (bool dst_full_sync_en : {true, false}) { + for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { + if (i == 1) continue; + unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { + .M = 1, .K = 1, .N = 1, + .dst_full_sync_en = dst_full_sync_en, + .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_blocked.cpp", + .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul.cpp", + .compute_kernel_args = { + 1, // block_tile_dim + 1, // dst_tile_rows + 1, // dst_tile_cols + 1, // block_cnt + 1, // in0_block_tile_cnt + 1, // in1_block_tile_cnt + 1 // out_block_tile_cnt + }, + .math_fidelity = MathFidelity(i) + }; + SHAPE shape = {1, 1, 32, 32}; + tt::log_info(tt::LogTest, "Math Fidelity = {}", i); + tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); + auto activations_tile_layout = convert_to_tile_layout(tensor.get_values()); + auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); + + auto identity = create_identity_matrix(32, 32, 32); //bfloat16 32x32 identity + auto weights_tile_layout = convert_to_tile_layout(identity); + auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); + + for(unsigned int id = 0; id < devices_.size(); id++){ + ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations, weights, tensor)); + } } } } TEST_F(CommonFixture, MatmulMultiTile){ - for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { - if (i == 1) continue; - uint32_t M = 4; - uint32_t N = 4; - uint32_t K = 4; - unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { - .M = M, .K = K, .N = N, - .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_with_bias_blocked.cpp", - .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul_with_bias.cpp", - .compute_kernel_args = { - 1, // block_tile_dim, within block, how many tiles are on the K dim - M, // dst_tile_rows - N, // dst_tile_cols - K, // block_cnt, across blocks, how many tiles are on the K dim - M, // in0_block_tile_cnt, M * block_tile_dim - N, // in1_block_tile_cnt, N * block_tile_dim - (M * N), // out_block_tile_cnt - matmul_config.with_bias // whether or not to use bias - }, - .math_fidelity = MathFidelity(i) - }; - tt::log_info(tt::LogTest, "Math Fidelity = {}", i); - SHAPE shape = {1, 1, M * 32, K * 32}; - tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); - auto activations_tilized = test_utils::tilize(tensor.get_values(), M * 32, K * 32); - auto activations_tile_layout = convert_to_tile_layout(activations_tilized); - auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); - auto activations_tile_transposed = transpose_tiles(activations, M, K, 1); - - auto identity = create_identity_matrix(K * 32, N * 32, std::min(K, N) * 32); //bfloat16 32x32 identity - auto identity_tilized = test_utils::tilize(identity, K * 32, N * 32); - auto weights_tile_layout = convert_to_tile_layout(identity_tilized); - auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); - - for(unsigned int id = 0; id < devices_.size(); id++){ - ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); - log_info(LogTest, "Multi tile with no bias passed"); - matmul_config.with_bias = true; - ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); - log_info(LogTest, "Multi tile with bias passed"); + for (bool dst_full_sync_en : {true, false}) { + for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { + if (i == 1) continue; + uint32_t M = 4; + uint32_t N = 4; + uint32_t K = 4; + unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { + .M = M, .K = K, .N = N, + .dst_full_sync_en = dst_full_sync_en, + .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_with_bias_blocked.cpp", + .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul_with_bias.cpp", + .compute_kernel_args = { + 1, // block_tile_dim, within block, how many tiles are on the K dim + M, // dst_tile_rows + N, // dst_tile_cols + K, // block_cnt, across blocks, how many tiles are on the K dim + M, // in0_block_tile_cnt, M * block_tile_dim + N, // in1_block_tile_cnt, N * block_tile_dim + (M * N), // out_block_tile_cnt + matmul_config.with_bias // whether or not to use bias + }, + .math_fidelity = MathFidelity(i) + }; + tt::log_info(tt::LogTest, "Math Fidelity = {}", i); + SHAPE shape = {1, 1, M * 32, K * 32}; + tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); + auto activations_tilized = test_utils::tilize(tensor.get_values(), M * 32, K * 32); + auto activations_tile_layout = convert_to_tile_layout(activations_tilized); + auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); + auto activations_tile_transposed = transpose_tiles(activations, M, K, 1); + + auto identity = create_identity_matrix(K * 32, N * 32, std::min(K, N) * 32); //bfloat16 32x32 identity + auto identity_tilized = test_utils::tilize(identity, K * 32, N * 32); + auto weights_tile_layout = convert_to_tile_layout(identity_tilized); + auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); + + for(unsigned int id = 0; id < devices_.size(); id++){ + ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); + log_info(LogTest, "Multi tile with no bias passed"); + matmul_config.with_bias = true; + ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); + log_info(LogTest, "Multi tile with bias passed"); + } } } } TEST_F(CommonFixture, MatmulBlock){ - for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { - if (i == 1) continue; - uint32_t M = 4; - uint32_t N = 4; - uint32_t K = 4; - unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { - .M = M, .K = K, .N = N, - .test_init_short = false, - .with_dt = false, - .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_with_bias_blocked.cpp", - .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul_block.cpp", - .compute_kernel_args = { - 1, // block_tile_dim, within block, how many tiles are on the K dim - M, // dst_tile_rows - N, // dst_tile_cols - K, // block_cnt, across blocks, how many tiles are on the K dim - M, // in0_block_tile_cnt, M * block_tile_dim - N, // in1_block_tile_cnt, N * block_tile_dim - (M * N), // out_block_tile_cnt - }, - .math_fidelity = MathFidelity(i) - }; - tt::log_info(tt::LogTest, "Math Fidelity = {}", i); - SHAPE shape = {1, 1, M * 32, K * 32}; - tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); - auto activations_tilized = test_utils::tilize(tensor.get_values(), M * 32, K * 32); - auto activations_tile_layout = convert_to_tile_layout(activations_tilized); - auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); - auto activations_tile_transposed = transpose_tiles(activations, M, K, 1); - - auto identity = create_identity_matrix(K * 32, N * 32, std::min(K, N) * 32); //bfloat16 32x32 identity - auto identity_tilized = test_utils::tilize(identity, K * 32, N * 32); - auto weights_tile_layout = convert_to_tile_layout(identity_tilized); - auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); - - for(unsigned int id = 0; id < devices_.size(); id++){ - ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); + for (bool dst_full_sync_en : {true, false}) { + for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { + if (i == 1) continue; + uint32_t M = 4; + uint32_t N = 4; + uint32_t K = 4; + unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { + .M = M, .K = K, .N = N, + .test_init_short = false, + .with_dt = false, + .dst_full_sync_en = dst_full_sync_en, + .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_with_bias_blocked.cpp", + .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul_block.cpp", + .compute_kernel_args = { + 1, // block_tile_dim, within block, how many tiles are on the K dim + M, // dst_tile_rows + N, // dst_tile_cols + K, // block_cnt, across blocks, how many tiles are on the K dim + M, // in0_block_tile_cnt, M * block_tile_dim + N, // in1_block_tile_cnt, N * block_tile_dim + (M * N), // out_block_tile_cnt + }, + .math_fidelity = MathFidelity(i) + }; + tt::log_info(tt::LogTest, "Math Fidelity = {}", i); + SHAPE shape = {1, 1, M * 32, K * 32}; + tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); + auto activations_tilized = test_utils::tilize(tensor.get_values(), M * 32, K * 32); + auto activations_tile_layout = convert_to_tile_layout(activations_tilized); + auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); + auto activations_tile_transposed = transpose_tiles(activations, M, K, 1); + + auto identity = create_identity_matrix(K * 32, N * 32, std::min(K, N) * 32); //bfloat16 32x32 identity + auto identity_tilized = test_utils::tilize(identity, K * 32, N * 32); + auto weights_tile_layout = convert_to_tile_layout(identity_tilized); + auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); + + for(unsigned int id = 0; id < devices_.size(); id++){ + ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); + } } } } TEST_F(CommonFixture, MatmulBlockInitShort){ - for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { - if (i == 1) continue; - uint32_t M = 4; - uint32_t N = 4; - uint32_t K = 4; - unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { - .M = M, .K = K, .N = N, - .test_init_short = true, - .with_dt = false, - .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_with_bias_blocked.cpp", - .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul_block.cpp", - .compute_kernel_args = { - 1, // block_tile_dim, within block, how many tiles are on the K dim - M, // dst_tile_rows - N, // dst_tile_cols - K, // block_cnt, across blocks, how many tiles are on the K dim - M, // in0_block_tile_cnt, M * block_tile_dim - N, // in1_block_tile_cnt, N * block_tile_dim - (M * N), // out_block_tile_cnt - }, - .math_fidelity = MathFidelity(i) - }; - tt::log_info(tt::LogTest, "Math Fidelity = {}", i); - SHAPE shape = {1, 1, M * 32, K * 32}; - tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); - auto activations_tilized = test_utils::tilize(tensor.get_values(), M * 32, K * 32); - auto activations_tile_layout = convert_to_tile_layout(activations_tilized); - auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); - auto activations_tile_transposed = transpose_tiles(activations, M, K, 1); - - auto identity = create_identity_matrix(K * 32, N * 32, std::min(K, N) * 32); //bfloat16 32x32 identity - auto identity_tilized = test_utils::tilize(identity, K * 32, N * 32); - auto weights_tile_layout = convert_to_tile_layout(identity_tilized); - auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); - - for(unsigned int id = 0; id < devices_.size(); id++){ - ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); + for (bool dst_full_sync_en : {true, false}) { + for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { + if (i == 1) continue; + uint32_t M = 4; + uint32_t N = 4; + uint32_t K = 4; + unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { + .M = M, .K = K, .N = N, + .test_init_short = true, + .with_dt = false, + .dst_full_sync_en = dst_full_sync_en, + .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_with_bias_blocked.cpp", + .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul_block.cpp", + .compute_kernel_args = { + 1, // block_tile_dim, within block, how many tiles are on the K dim + M, // dst_tile_rows + N, // dst_tile_cols + K, // block_cnt, across blocks, how many tiles are on the K dim + M, // in0_block_tile_cnt, M * block_tile_dim + N, // in1_block_tile_cnt, N * block_tile_dim + (M * N), // out_block_tile_cnt + }, + .math_fidelity = MathFidelity(i) + }; + tt::log_info(tt::LogTest, "Math Fidelity = {}", i); + SHAPE shape = {1, 1, M * 32, K * 32}; + tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); + auto activations_tilized = test_utils::tilize(tensor.get_values(), M * 32, K * 32); + auto activations_tile_layout = convert_to_tile_layout(activations_tilized); + auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); + auto activations_tile_transposed = transpose_tiles(activations, M, K, 1); + + auto identity = create_identity_matrix(K * 32, N * 32, std::min(K, N) * 32); //bfloat16 32x32 identity + auto identity_tilized = test_utils::tilize(identity, K * 32, N * 32); + auto weights_tile_layout = convert_to_tile_layout(identity_tilized); + auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); + + for(unsigned int id = 0; id < devices_.size(); id++){ + ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); + } } } } TEST_F(CommonFixture, MatmulBlockInitShortWithDt){ - for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { - if (i == 1) continue; - uint32_t M = 4; - uint32_t N = 4; - uint32_t K = 4; - unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { - .M = M, .K = K, .N = N, - .test_init_short = true, - .with_dt = true, - .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_with_bias_blocked.cpp", - .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul_block.cpp", - .compute_kernel_args = { - 1, // block_tile_dim, within block, how many tiles are on the K dim - M, // dst_tile_rows - N, // dst_tile_cols - K, // block_cnt, across blocks, how many tiles are on the K dim - M, // in0_block_tile_cnt, M * block_tile_dim - N, // in1_block_tile_cnt, N * block_tile_dim - (M * N), // out_block_tile_cnt - }, - .math_fidelity = MathFidelity(i) - }; - tt::log_info(tt::LogTest, "Math Fidelity = {}", i); - SHAPE shape = {1, 1, M * 32, K * 32}; - tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); - auto activations_tilized = test_utils::tilize(tensor.get_values(), M * 32, K * 32); - auto activations_tile_layout = convert_to_tile_layout(activations_tilized); - auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); - auto activations_tile_transposed = transpose_tiles(activations, M, K, 1); - - auto identity = create_identity_matrix(K * 32, N * 32, std::min(K, N) * 32); //bfloat16 32x32 identity - auto identity_tilized = test_utils::tilize(identity, K * 32, N * 32); - auto weights_tile_layout = convert_to_tile_layout(identity_tilized); - auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); - - for(unsigned int id = 0; id < devices_.size(); id++){ - ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); + for (bool dst_full_sync_en : {true, false}) { + for (uint8_t i = uint8_t(MathFidelity::LoFi); i <= uint8_t(MathFidelity::HiFi4); i++) { + if (i == 1) continue; + uint32_t M = 4; + uint32_t N = 4; + uint32_t K = 4; + unit_tests_common::matmul::test_matmul_X_tile::MatmulTileConfig matmul_config = { + .M = M, .K = K, .N = N, + .test_init_short = true, + .with_dt = true, + .dst_full_sync_en = dst_full_sync_en, + .reader_kernel = "tests/tt_metal/tt_metal/test_kernels/dataflow/reader_matmul_with_bias_blocked.cpp", + .compute_kernel = "tests/tt_metal/tt_metal/test_kernels/compute/matmul_block.cpp", + .compute_kernel_args = { + 1, // block_tile_dim, within block, how many tiles are on the K dim + M, // dst_tile_rows + N, // dst_tile_cols + K, // block_cnt, across blocks, how many tiles are on the K dim + M, // in0_block_tile_cnt, M * block_tile_dim + N, // in1_block_tile_cnt, N * block_tile_dim + (M * N), // out_block_tile_cnt + }, + .math_fidelity = MathFidelity(i) + }; + tt::log_info(tt::LogTest, "Math Fidelity = {}", i); + SHAPE shape = {1, 1, M * 32, K * 32}; + tt::deprecated::Tensor tensor = tt::deprecated::initialize_tensor(shape, tt::deprecated::Initialize::RANDOM, 100, std::chrono::system_clock::now().time_since_epoch().count()); + auto activations_tilized = test_utils::tilize(tensor.get_values(), M * 32, K * 32); + auto activations_tile_layout = convert_to_tile_layout(activations_tilized); + auto activations = pack_bfloat16_vec_into_uint32_vec(activations_tile_layout); + auto activations_tile_transposed = transpose_tiles(activations, M, K, 1); + + auto identity = create_identity_matrix(K * 32, N * 32, std::min(K, N) * 32); //bfloat16 32x32 identity + auto identity_tilized = test_utils::tilize(identity, K * 32, N * 32); + auto weights_tile_layout = convert_to_tile_layout(identity_tilized); + auto weights = pack_bfloat16_vec_into_uint32_vec(weights_tile_layout); + + for(unsigned int id = 0; id < devices_.size(); id++){ + ASSERT_TRUE(unit_tests_common::matmul::test_matmul_X_tile::matmul_tile(this, devices_.at(id), matmul_config, activations_tile_transposed, weights, tensor)); + } } } } diff --git a/tests/ttnn/unit_tests/operations/test_matmul.py b/tests/ttnn/unit_tests/operations/test_matmul.py index 1b9015471d9..f0aad38592b 100644 --- a/tests/ttnn/unit_tests/operations/test_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_matmul.py @@ -1346,3 +1346,30 @@ def core_range_for_num_cores(num_cores): matmul_output = matmul_output + bias_tensor assert_with_pcc(matmul_output, tt_mm_out, pcc=0.993) + + +@pytest.mark.parametrize("M", [32, 128]) +@pytest.mark.parametrize("K", [32, 128]) +@pytest.mark.parametrize("N", [32, 128]) +def test_alternating_dst_sync_mode_matmul(device, M, K, N): + torch.manual_seed(0) + torch_input_tensor_a = torch.randn([1, 1, M, K], dtype=torch.bfloat16) + torch_input_tensor_b = torch.randn([1, 1, K, N], dtype=torch.bfloat16) + torch_output_tensor = torch.matmul(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device) + # Half sync mode + output1 = ttnn.matmul(input_tensor_a, input_tensor_b, core_grid=ttnn.CoreGrid(y=4, x=4)) + # Full sync mode + output2 = ttnn.matmul(input_tensor_a, input_tensor_b) + # Half sync mode + output3 = ttnn.matmul(input_tensor_a, input_tensor_b, core_grid=ttnn.CoreGrid(y=4, x=4)) + + pcc = 0.99 + output_tensor = ttnn.to_torch(output1) + assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) + output_tensor = ttnn.to_torch(output2) + assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) + output_tensor = ttnn.to_torch(output3) + assert_with_pcc(torch_output_tensor, output_tensor, pcc=pcc) diff --git a/tt_metal/hw/ckernels/blackhole/metal/common/chlkc_list.h b/tt_metal/hw/ckernels/blackhole/metal/common/chlkc_list.h index 51a79b3f01d..30ed491da3b 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/common/chlkc_list.h +++ b/tt_metal/hw/ckernels/blackhole/metal/common/chlkc_list.h @@ -14,6 +14,7 @@ using namespace ckernel; #ifdef UCK_CHLKC_MATH // clang-format off #include "chlkc_dst_accum_mode.h" +#include "chlkc_dst_sync_mode.h" #include "chlkc_math_approx_mode.h" #include "chlkc_math_fidelity.h" #include "chlkc_unpack_data_format.h" @@ -24,6 +25,7 @@ using namespace ckernel; #ifdef UCK_CHLKC_PACK // clang-format off #include "chlkc_dst_accum_mode.h" +#include "chlkc_dst_sync_mode.h" #include "chlkc_pack_data_format.h" #include "chlkc_pack.cpp" // clang-format on @@ -32,6 +34,7 @@ using namespace ckernel; #ifdef UCK_CHLKC_UNPACK // clang-format off #include "chlkc_dst_accum_mode.h" +#include "chlkc_dst_sync_mode.h" #include "chlkc_unpack_data_format.h" #include "chlkc_unpack.cpp" // clang-format on diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_api.h index 58e1451c48f..efead1bca4a 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_api.h @@ -54,7 +54,7 @@ inline void llk_math_eltwise_binary(uint dst_index, const bool clear_fp32_dst_ac _llk_math_eltwise_binary_< eltwise_binary_type, src_b_bcast_type, - DstSync::SyncHalf, + DST_SYNC_MODE, NUM_FIDELITY_PHASES, binary_reuse_dest, is_fp32_dest_acc_en>(num_faces, dst_index, clear_fp32_dst_acc); @@ -77,7 +77,7 @@ inline void llk_math_eltwise_binary( _llk_math_eltwise_binary_< eltwise_binary_type, src_b_bcast_type, - DstSync::SyncHalf, + DST_SYNC_MODE, NUM_FIDELITY_PHASES, binary_reuse_dest, is_fp32_dest_acc_en>(num_faces, dst_index, clear_fp32_dst_acc); diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_sfpu_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_sfpu_api.h index 3fd56e4b805..5b25f2bd06c 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_sfpu_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_binary_sfpu_api.h @@ -26,7 +26,7 @@ inline void llk_math_eltwise_binary_sfpu( const std::uint32_t num_faces = get_operand_num_faces(operand_id); const std::uint32_t face_r_dim = get_operand_face_r_dim(operand_id); - _llk_math_eltwise_binary_sfpu_( + _llk_math_eltwise_binary_sfpu_( face_r_dim, num_faces, dst_index_a, dst_index_b, vector_mode, param0, param1, param2, param3, param4, param5); } @@ -39,7 +39,7 @@ inline void llk_math_eltwise_binary_sfpu_init( template inline void llk_math_eltwise_binary_sfpu_quant_int32( const uint operand, uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(operand, dst_index_a, dst_index_b, vector_mode); + llk_math_eltwise_binary_sfpu(operand, dst_index_a, dst_index_b, vector_mode); } template @@ -50,7 +50,7 @@ inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) template inline void llk_math_eltwise_binary_sfpu_requant_int32( const uint operand, uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(operand, dst_index_a, dst_index_b, vector_mode); + llk_math_eltwise_binary_sfpu(operand, dst_index_a, dst_index_b, vector_mode); } template @@ -61,7 +61,7 @@ inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_poin template inline void llk_math_eltwise_binary_sfpu_dequant_int32( const uint operand, uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(operand, dst_index_a, dst_index_b, vector_mode); + llk_math_eltwise_binary_sfpu(operand, dst_index_a, dst_index_b, vector_mode); } template diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_common_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_common_api.h index 64dba6fb568..5a0c5c8b04b 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_common_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_common_api.h @@ -28,18 +28,18 @@ inline void llk_math_hw_configure_disaggregated() { inline void llk_math_wait_for_dest_available() { WAYPOINT("MWDW"); - _llk_math_wait_for_dest_available_(); + _llk_math_wait_for_dest_available_(); WAYPOINT("MWDD"); } template inline void llk_math_dest_section_done() { - _llk_math_dest_section_done_(); + _llk_math_dest_section_done_(); } template inline void llk_math_pack_sync_init() { - _llk_math_pack_sync_init_(); + _llk_math_pack_sync_init_(); } template diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_datacopy_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_datacopy_api.h index 38d8d26e691..d8615fb6889 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_datacopy_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_unary_datacopy_api.h @@ -18,7 +18,7 @@ template < bool unpack_to_dest = false> inline void llk_math_eltwise_unary_datacopy(uint dst_index, uint operand = 0) { const std::uint32_t operand_id = get_operand_id(operand); - _llk_math_eltwise_unary_datacopy_( + _llk_math_eltwise_unary_datacopy_( dst_index, unpack_src_format[operand_id], unpack_dst_format[operand_id]); } @@ -34,7 +34,7 @@ inline void llk_math_eltwise_unary_datacopy_block(uint start_dst_index, uint nti _llk_math_eltwise_unary_datacopy_< type, src_b_bcast_type, - DstSync::SyncHalf, + DST_SYNC_MODE, is_fp32_dest_acc_en, unpack_to_dest>(dst_index, unpack_src_format[operand_id], unpack_dst_format[operand_id]); } diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_pack_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_pack_api.h index ec0a17a466f..6539851bfac 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_pack_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_pack_api.h @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once +#include "chlkc_list.h" #include "circular_buffer.h" #include "ckernel.h" #include "ckernel_defs.h" @@ -175,7 +176,7 @@ inline void llk_pack(std::uint32_t tile_index, std::uint32_t output, std::uint32 std::uint32_t pack_tile_addr = get_output_tile_address(output_id, output_tile_index); - _llk_pack_( + _llk_pack_( tile_index, pack_tile_addr ); @@ -251,7 +252,7 @@ inline void llk_matmul_pack( std::uint32_t pack_tile_addr = get_output_tile_address(output_id, output_tile_index); - _llk_pack_(tile_index, pack_tile_addr); + _llk_pack_(tile_index, pack_tile_addr); } } @@ -271,7 +272,7 @@ inline void llk_packer_set_math_semaphore() { template inline void llk_pack_dest_section_done() { - _llk_pack_dest_section_done_(); + _llk_pack_dest_section_done_(); } template @@ -280,7 +281,7 @@ inline void llk_init_packer_dest_offset_registers(const std::uint32_t pack_outpu const std::uint32_t face_r_dim = get_output_face_r_dim(output_id); const bool narrow_tile = get_output_narrow_tile(output_id); - _llk_init_packer_dest_offset_registers_( + _llk_init_packer_dest_offset_registers_( face_r_dim, narrow_tile ); @@ -292,7 +293,7 @@ inline void llk_pack_dest_init(const std::uint32_t pack_output = 16) { const std::uint32_t face_r_dim = get_output_face_r_dim(output_id); const bool narrow_tile = get_output_narrow_tile(output_id); - _llk_pack_dest_init_( + _llk_pack_dest_init_( face_r_dim, narrow_tile ); diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_params.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_params.h index 85052c780e2..19f313debfa 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_params.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_params.h @@ -13,7 +13,7 @@ inline void llk_math_eltwise_unary_sfpu_params( int vector_mode = (int)VectorMode::RC, ARGS&& ... args) { - _llk_math_eltwise_unary_sfpu_start_(dst_index); + _llk_math_eltwise_unary_sfpu_start_(dst_index); if (vector_mode == (int)VectorMode::R) { // Do a row vector, Face0 + Face1 -- first iteration (first row) diff --git a/tt_metal/hw/ckernels/grayskull/metal/common/chlkc_list.h b/tt_metal/hw/ckernels/grayskull/metal/common/chlkc_list.h index 7bb5bb4d024..fe3a31119f5 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/common/chlkc_list.h +++ b/tt_metal/hw/ckernels/grayskull/metal/common/chlkc_list.h @@ -17,18 +17,21 @@ using namespace ckernel; #include "chlkc_math_fidelity.h" #include "chlkc_math_approx_mode.h" #include "chlkc_dst_accum_mode.h" +#include "chlkc_dst_sync_mode.h" #include "chlkc_math.cpp" #endif #ifdef UCK_CHLKC_PACK #include "chlkc_pack_data_format.h" #include "chlkc_dst_accum_mode.h" +#include "chlkc_dst_sync_mode.h" #include "chlkc_pack.cpp" #endif #ifdef UCK_CHLKC_UNPACK #include "chlkc_unpack_data_format.h" #include "chlkc_dst_accum_mode.h" +#include "chlkc_dst_sync_mode.h" #include "chlkc_unpack.cpp" #endif diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_binary_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_binary_api.h index 3d435766e0c..80af6ca9fcd 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_binary_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_binary_api.h @@ -59,7 +59,7 @@ inline void llk_math_eltwise_binary(uint dst_index, const bool clear_fp32_dst_ac _llk_math_eltwise_binary_< eltwise_binary_type, src_b_bcast_type, - DstSync::SyncHalf, + DST_SYNC_MODE, NUM_FIDELITY_PHASES, binary_reuse_dest, is_fp32_dest_acc_en>(num_faces, num_faces, dst_index, clear_fp32_dst_acc); @@ -82,7 +82,7 @@ inline void llk_math_eltwise_binary( _llk_math_eltwise_binary_< eltwise_binary_type, src_b_bcast_type, - DstSync::SyncHalf, + DST_SYNC_MODE, NUM_FIDELITY_PHASES, binary_reuse_dest, is_fp32_dest_acc_en>(num_faces, num_faces, dst_index, clear_fp32_dst_acc); diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h index d4e791e7738..ee1f4715f15 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h @@ -23,18 +23,18 @@ inline void llk_math_hw_configure_disaggregated() { /*Unused for GS*/ } inline void llk_math_wait_for_dest_available() { WAYPOINT("MWDW"); - _llk_math_wait_for_dest_available_(); + _llk_math_wait_for_dest_available_(); WAYPOINT("MWDD"); } template inline void llk_math_dest_section_done() { - _llk_math_dest_section_done_(); + _llk_math_dest_section_done_(); } template inline void llk_math_pack_sync_init() { - _llk_math_pack_sync_init_(); + _llk_math_pack_sync_init_(); } template diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_datacopy_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_datacopy_api.h index 825eed6e673..7f8e90e54dc 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_datacopy_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_datacopy_api.h @@ -17,7 +17,7 @@ template < bool unpack_to_dest = false> inline void llk_math_eltwise_unary_datacopy(uint dst_index, uint operand = 0 /* unused */) { - _llk_math_eltwise_unary_datacopy_(dst_index); + _llk_math_eltwise_unary_datacopy_(dst_index); } template diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h index 3975dfe89f5..5ab568c012e 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_unary_sfpu_api.h @@ -30,7 +30,7 @@ namespace ckernel { template inline void llk_math_eltwise_unary_sfpu_log(uint dst_index, int vector_mode = VectorMode::RC) { - llk_math_eltwise_unary_sfpu(dst_index, vector_mode); + llk_math_eltwise_unary_sfpu(dst_index, vector_mode); } template @@ -41,7 +41,7 @@ inline void llk_math_eltwise_unary_sfpu_log_init() { //abs template inline void llk_math_eltwise_unary_sfpu_abs(uint dst_index, int vector_mode = VectorMode::RC) { - llk_math_eltwise_unary_sfpu(dst_index, vector_mode); + llk_math_eltwise_unary_sfpu(dst_index, vector_mode); } template @@ -52,7 +52,7 @@ inline void llk_math_eltwise_unary_sfpu_abs_init() { //log with base template inline void llk_math_eltwise_unary_sfpu_log_with_base(uint dst_index, uint base, int vector_mode = VectorMode::RC) { - llk_math_eltwise_unary_sfpu(dst_index, vector_mode, base); + llk_math_eltwise_unary_sfpu(dst_index, vector_mode, base); } template @@ -62,7 +62,7 @@ inline void llk_math_eltwise_unary_sfpu_log_with_base_init() { template inline void llk_math_eltwise_unary_sfpu_tanh(uint dst_index, int vector_mode = VectorMode::RC) { - llk_math_eltwise_unary_sfpu(dst_index, vector_mode); + llk_math_eltwise_unary_sfpu(dst_index, vector_mode); } template @@ -73,7 +73,7 @@ inline void llk_math_eltwise_unary_sfpu_tanh_init() { inline void llk_math_eltwise_unary_sfpu_dropout( uint dst_index, int vector_mode, int integer_dropout, int scale_factor) { constexpr bool dont_care = false; - llk_math_eltwise_unary_sfpu( + llk_math_eltwise_unary_sfpu( dst_index, vector_mode, integer_dropout, scale_factor); } @@ -86,7 +86,7 @@ inline void llk_math_eltwise_unary_sfpu_dropout_init(uint seed = 0) { template inline void llk_math_eltwise_unary_sfpu_max(uint dst_index, int vector_mode = VectorMode::RC) { - llk_math_eltwise_unary_sfpu(dst_index, vector_mode); + llk_math_eltwise_unary_sfpu(dst_index, vector_mode); } template @@ -96,7 +96,7 @@ inline void llk_math_eltwise_unary_sfpu_max_init() { template inline void llk_math_eltwise_unary_sfpu_square(uint dst_index, int vector_mode = VectorMode::RC) { - llk_math_eltwise_unary_sfpu(dst_index, vector_mode); + llk_math_eltwise_unary_sfpu(dst_index, vector_mode); } template diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_pack_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_pack_api.h index 29641384ef7..9960adf52e7 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_pack_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_pack_api.h @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once +#include "chlkc_list.h" #include "ckernel.h" #include "ckernel_defs.h" #include "ckernel_template.h" @@ -112,7 +113,7 @@ inline void llk_pack(std::uint32_t tile_index, std::uint32_t output, std::uint32 std::uint32_t pack_tile_addr = get_output_tile_address(output_id, output_tile_index); - _llk_pack_( + _llk_pack_( tile_index, pack_dst_format[output_id], pack_tile_addr @@ -177,17 +178,17 @@ inline void llk_packer_set_math_semaphore() { template inline void llk_pack_dest_section_done() { - _llk_pack_dest_section_done_(); + _llk_pack_dest_section_done_(); } template inline void llk_init_packer_dest_offset_registers(const std::uint32_t pack_output = 16) { - _llk_init_packer_dest_offset_registers_(); + _llk_init_packer_dest_offset_registers_(); } template inline void llk_pack_dest_init(const std::uint32_t pack_output = 16) { - _llk_pack_dest_init_(); + _llk_pack_dest_init_(); } template @@ -281,7 +282,7 @@ inline void llk_matmul_pack(const std::uint32_t start_tile_index, const std::uin std::uint32_t pack_tile_addr = get_output_tile_address(output_id, output_tile_index); - _llk_pack_( + _llk_pack_( tile_index, pack_dst_format[output_id], pack_tile_addr diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_common_includes.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_common_includes.h index 575356fe0ff..ba547a6b5af 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_common_includes.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_common_includes.h @@ -75,7 +75,7 @@ inline void llk_math_eltwise_unary_sfpu( uint param4 = 0, uint param5 = 0) { - _llk_math_eltwise_unary_sfpu_start_(dst_index); + _llk_math_eltwise_unary_sfpu_start_(dst_index); if (vector_mode == (int)VectorMode::R) { // Do a row vector, Face0 + Face1 -- first iteration const int ITERATIONS = 1; diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/common/chlkc_list.h b/tt_metal/hw/ckernels/wormhole_b0/metal/common/chlkc_list.h index 5fad25c0062..ffb9b395721 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/common/chlkc_list.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/common/chlkc_list.h @@ -18,6 +18,7 @@ using namespace ckernel; #include "chlkc_math_fidelity.h" #include "chlkc_math_approx_mode.h" #include "chlkc_dst_accum_mode.h" +#include "chlkc_dst_sync_mode.h" #include "chlkc_math.cpp" #endif @@ -25,6 +26,7 @@ using namespace ckernel; #include "chlkc_pack_data_format.h" #include "chlkc_pack_tile_dims.h" #include "chlkc_dst_accum_mode.h" +#include "chlkc_dst_sync_mode.h" #include "chlkc_pack.cpp" #endif @@ -32,6 +34,7 @@ using namespace ckernel; #include "chlkc_unpack_data_format.h" #include "chlkc_unpack_tile_dims.h" #include "chlkc_dst_accum_mode.h" +#include "chlkc_dst_sync_mode.h" #include "chlkc_unpack.cpp" #endif diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_api.h index 58e1451c48f..efead1bca4a 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_api.h @@ -54,7 +54,7 @@ inline void llk_math_eltwise_binary(uint dst_index, const bool clear_fp32_dst_ac _llk_math_eltwise_binary_< eltwise_binary_type, src_b_bcast_type, - DstSync::SyncHalf, + DST_SYNC_MODE, NUM_FIDELITY_PHASES, binary_reuse_dest, is_fp32_dest_acc_en>(num_faces, dst_index, clear_fp32_dst_acc); @@ -77,7 +77,7 @@ inline void llk_math_eltwise_binary( _llk_math_eltwise_binary_< eltwise_binary_type, src_b_bcast_type, - DstSync::SyncHalf, + DST_SYNC_MODE, NUM_FIDELITY_PHASES, binary_reuse_dest, is_fp32_dest_acc_en>(num_faces, dst_index, clear_fp32_dst_acc); diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_sfpu_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_sfpu_api.h index 156d19e5d73..bdca47da10d 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_sfpu_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_binary_sfpu_api.h @@ -26,7 +26,7 @@ inline void llk_math_eltwise_binary_sfpu( const std::uint32_t num_faces = get_operand_num_faces(operand_id); const std::uint32_t face_r_dim = get_operand_face_r_dim(operand_id); - _llk_math_eltwise_binary_sfpu_( + _llk_math_eltwise_binary_sfpu_( face_r_dim, num_faces, dst_index_a, dst_index_b, vector_mode, param0, param1, param2, param3, param4, param5); } @@ -39,7 +39,7 @@ inline void llk_math_eltwise_binary_sfpu_init( template inline void llk_math_eltwise_binary_sfpu_quant_int32( uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(dst_index_a, dst_index_b, vector_mode); + llk_math_eltwise_binary_sfpu(dst_index_a, dst_index_b, vector_mode); } template @@ -50,7 +50,7 @@ inline void llk_math_eltwise_binary_sfpu_quant_int32_init(const uint zero_point) template inline void llk_math_eltwise_binary_sfpu_requant_int32( uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(dst_index_a, dst_index_b, vector_mode); + llk_math_eltwise_binary_sfpu(dst_index_a, dst_index_b, vector_mode); } template @@ -61,7 +61,7 @@ inline void llk_math_eltwise_binary_sfpu_requant_int32_init(const uint zero_poin template inline void llk_math_eltwise_binary_sfpu_dequant_int32( uint dst_index_a, uint dst_index_b, int vector_mode = (int)VectorMode::RC) { - llk_math_eltwise_binary_sfpu(dst_index_a, dst_index_b, vector_mode); + llk_math_eltwise_binary_sfpu(dst_index_a, dst_index_b, vector_mode); } template diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_common_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_common_api.h index db5adb372f2..90d724edbf4 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_common_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_common_api.h @@ -26,18 +26,18 @@ inline void llk_math_hw_configure_disaggregated() { /*Unused for WHB0*/ } inline void llk_math_wait_for_dest_available() { WAYPOINT("MWDW"); - _llk_math_wait_for_dest_available_(); + _llk_math_wait_for_dest_available_(); WAYPOINT("MWDD"); } template inline void llk_math_dest_section_done() { - _llk_math_dest_section_done_(); + _llk_math_dest_section_done_(); } template inline void llk_math_pack_sync_init() { - _llk_math_pack_sync_init_(); + _llk_math_pack_sync_init_(); } template diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_datacopy_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_datacopy_api.h index 1eaf67dd605..4b980b31d5b 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_datacopy_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_datacopy_api.h @@ -18,7 +18,7 @@ template < bool unpack_to_dest = false> inline void llk_math_eltwise_unary_datacopy(uint dst_index, uint operand = 0) { const std::uint32_t operand_id = get_operand_id(operand); - _llk_math_eltwise_unary_datacopy_( + _llk_math_eltwise_unary_datacopy_( dst_index, unpack_src_format[operand_id], unpack_dst_format[operand_id]); } @@ -31,7 +31,7 @@ inline void llk_math_eltwise_unary_datacopy_block(uint start_dst_index, uint nti const std::uint32_t operand_id = get_operand_id(operand); for (uint32_t dst_index = start_dst_index; dst_index < start_dst_index + ntiles; dst_index++) { - _llk_math_eltwise_unary_datacopy_( + _llk_math_eltwise_unary_datacopy_( dst_index, unpack_src_format[operand_id], unpack_dst_format[operand_id]); } } diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_pack_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_pack_api.h index 729daf16de9..2a99bc53fe6 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_pack_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_pack_api.h @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once +#include "chlkc_list.h" #include "ckernel.h" #include "ckernel_defs.h" #include "ckernel_template.h" @@ -171,7 +172,7 @@ inline void llk_pack(std::uint32_t tile_index, std::uint32_t output, std::uint32 std::uint32_t pack_tile_addr = get_output_tile_address(output_id, output_tile_index); - _llk_pack_( + _llk_pack_( tile_index, pack_tile_addr ); @@ -230,7 +231,7 @@ inline void llk_matmul_pack(std::uint32_t start_tile_index, std::uint32_t output for (uint32_t tile_index=start_tile_index; tile_index < start_tile_index + ntiles; tile_index++) { std::uint32_t pack_tile_addr = get_output_tile_address(output_id, output_tile_index); - _llk_pack_( + _llk_pack_( tile_index, pack_tile_addr ); @@ -253,7 +254,7 @@ inline void llk_packer_set_math_semaphore() { template inline void llk_pack_dest_section_done() { - _llk_pack_dest_section_done_(); + _llk_pack_dest_section_done_(); } template @@ -262,7 +263,7 @@ inline void llk_init_packer_dest_offset_registers(const std::uint32_t pack_outpu const std::uint32_t face_r_dim = get_output_face_r_dim(output_id); const bool narrow_tile = get_output_narrow_tile(output_id); - _llk_init_packer_dest_offset_registers_( + _llk_init_packer_dest_offset_registers_( face_r_dim, narrow_tile ); @@ -275,7 +276,7 @@ inline void llk_pack_dest_init(const std::uint32_t pack_output = 16) { const std::uint32_t face_r_dim = get_output_face_r_dim(output_id); const bool narrow_tile = get_output_narrow_tile(output_id); - _llk_pack_dest_init_( + _llk_pack_dest_init_( face_r_dim, narrow_tile ); diff --git a/tt_metal/impl/kernels/kernel.cpp b/tt_metal/impl/kernels/kernel.cpp index a4bf717800c..a714fc2de8b 100644 --- a/tt_metal/impl/kernels/kernel.cpp +++ b/tt_metal/impl/kernels/kernel.cpp @@ -159,10 +159,11 @@ std::string EthernetKernel::config_hash() const { std::string ComputeKernel::config_hash() const { return fmt::format( - "{}_{}_{}", + "{}_{}_{}_{}", magic_enum::enum_name(this->config_.math_fidelity), this->config_.fp32_dest_acc_en, - this->config_.math_approx_mode); + this->config_.math_approx_mode, + this->config_.dst_full_sync_en); } std::string Kernel::compute_hash() const { @@ -310,6 +311,7 @@ void ComputeKernel::set_build_options(JitBuildOptions &build_options) const { build_options.set_hlk_math_fidelity_all_cores(this->config_.math_fidelity); build_options.set_hlk_math_approx_mode_all_cores(this->config_.math_approx_mode); build_options.fp32_dest_acc_en = this->config_.fp32_dest_acc_en; + build_options.dst_full_sync_en = this->config_.dst_full_sync_en; build_options.unpack_to_dest_mode = this->config_.unpack_to_dest_mode; build_options.hlk_defines = this->defines_; } diff --git a/tt_metal/impl/kernels/kernel_types.hpp b/tt_metal/impl/kernels/kernel_types.hpp index b056b2a8213..a1edd32eb50 100644 --- a/tt_metal/impl/kernels/kernel_types.hpp +++ b/tt_metal/impl/kernels/kernel_types.hpp @@ -48,6 +48,7 @@ struct WriterDataMovementConfig : public DataMovementConfig { struct ComputeConfig { MathFidelity math_fidelity = MathFidelity::HiFi4; bool fp32_dest_acc_en = false; + bool dst_full_sync_en = false; std::vector unpack_to_dest_mode; bool math_approx_mode = false; std::vector compile_args; diff --git a/tt_metal/jit_build/genfiles.cpp b/tt_metal/jit_build/genfiles.cpp index 9832f11ee00..ca47b722499 100644 --- a/tt_metal/jit_build/genfiles.cpp +++ b/tt_metal/jit_build/genfiles.cpp @@ -473,6 +473,22 @@ static void generate_dst_accum_mode_descriptor(JitBuildOptions& options) { file_stream.close(); } +static void generate_dst_sync_mode_descriptor(JitBuildOptions& options) { + string dst_sync_mode_descriptor = options.path + "chlkc_dst_sync_mode.h"; + + ofstream file_stream; + + file_stream.open(dst_sync_mode_descriptor); + + if (options.dst_full_sync_en) { + file_stream << "#define DST_SYNC_MODE DstSync::SyncFull" << endl; + } else { + file_stream << "#define DST_SYNC_MODE DstSync::SyncHalf" << endl; + } + + file_stream.close(); +} + static void generate_math_fidelity_descriptor(JitBuildOptions& options) { string math_fidelity_descriptor = options.path + "chlkc_math_fidelity.h"; // assuming all cores within a op have the same desc @@ -509,11 +525,13 @@ void jit_build_genfiles_descriptors(const JitBuildEnv& env, JitBuildOptions& opt std::thread tm( [&]() { generate_math_fidelity_descriptor(options); } ); std::thread ta( [&]() { generate_math_approx_mode_descriptor(options); } ); std::thread tf( [&]() { generate_dst_accum_mode_descriptor(options); } ); + std::thread ts( [&]() { generate_dst_sync_mode_descriptor(options); } ); td.join(); tt.join(); tm.join(); ta.join(); tf.join(); + ts.join(); } catch (std::runtime_error& ex) { std::cerr << "EXCEPTION FROM THREADING IN GENERATE_DESCRIPTORS: " << ex.what() << std::endl; } diff --git a/tt_metal/jit_build/settings.hpp b/tt_metal/jit_build/settings.hpp index d9edb2c483b..03fc367b134 100644 --- a/tt_metal/jit_build/settings.hpp +++ b/tt_metal/jit_build/settings.hpp @@ -27,6 +27,8 @@ class JitBuildOptions { bool fp32_dest_acc_en; std::vector unpack_to_dest_mode; + bool dst_full_sync_en; + std::map hlk_defines; // preprocessor defines for HLK std::map ncrisc_defines; std::map brisc_defines; diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp index 28a6d1d3d94..bfac7fe8658 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp @@ -118,7 +118,7 @@ operation::ProgramWithCallbacks moreh_layernorm_impl( num_rows_per_core_group_2] = tt_metal::split_work_to_cores(grid, num_outer); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); // This could be inefficient. diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp index 90de1c698fc..f4ccc6d8775 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp @@ -82,7 +82,7 @@ operation::ProgramWithCallbacks moreh_layernorm_backward_gamma_beta_grad_impl( num_cols_per_core_group_2] = tt_metal::split_work_to_cores(grid, num_inner); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp index 4f219add8c2..d680ce29a48 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp @@ -93,7 +93,7 @@ operation::ProgramWithCallbacks moreh_layernorm_backward_input_grad_impl( num_rows_per_core_group_2] = tt_metal::split_work_to_cores(grid, num_outer); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); //////////////////////////////////////////////////////////////////////////// // CircularBuffer Setup diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/multi_core/moreh_matmul_op_multi_core.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/multi_core/moreh_matmul_op_multi_core.cpp index 3a50a63754f..45fc10cb821 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/multi_core/moreh_matmul_op_multi_core.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/multi_core/moreh_matmul_op_multi_core.cpp @@ -173,7 +173,7 @@ operation::ProgramWithCallbacks moreh_matmul_multi_core( need_other_mask_h, need_other_mask_w, other_mask_h, other_mask_w); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); log_debug( LogOp, "math_fidelity {} math_approx_mode {} fp32_dest_acc_en {} packer_l1_acc {}", diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp index 79dde3b8b1f..c7bebf27af8 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp @@ -36,7 +36,7 @@ operation::ProgramWithCallbacks moreh_softmax_c_large(const Tensor &input, const split_work_to_cores(core_range, num_tiles); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp index 54c64516708..d2ae2936535 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp @@ -36,7 +36,7 @@ operation::ProgramWithCallbacks moreh_softmax_h_large(const Tensor &input, const split_work_to_cores(core_range, num_cols_tiles); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp index 3a1cfa004b8..2a63ac8af88 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp @@ -26,7 +26,7 @@ bool is_moreh_softmax_h_small_available(const Tensor &tensor, const ttnn::Device int32_t Ht = (h + TILE_HEIGHT - 1) / TILE_HEIGHT; auto arch = tensor.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); auto data_format = tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; @@ -69,7 +69,7 @@ operation::ProgramWithCallbacks moreh_softmax_h_small(const Tensor &input, const split_work_to_cores(core_range, num_cols_tiles); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp index 43160f7092e..d29342441ae 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp @@ -37,7 +37,7 @@ operation::ProgramWithCallbacks moreh_softmax_w_large(const Tensor &input, const split_work_to_cores(core_range, num_kernel_rows); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp index 5e376fd4e68..2a05e42095a 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp @@ -26,7 +26,7 @@ bool is_moreh_softmax_w_small_available(const Tensor &tensor, const ttnn::Device int32_t Wt = (w + TILE_WIDTH - 1) / TILE_WIDTH; auto arch = tensor.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); auto data_format = tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; @@ -69,7 +69,7 @@ operation::ProgramWithCallbacks moreh_softmax_w_small(const Tensor &input, const split_work_to_cores(core_range, num_kernel_rows); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp index c5f9ee7cd73..632f5e8d3c4 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp @@ -37,7 +37,7 @@ operation::ProgramWithCallbacks moreh_softmax_backward_c_large(const Tensor &out split_work_to_cores(core_range, num_tiles); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp index 9cf7204055c..77dbaf76338 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp @@ -38,7 +38,7 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_large(const Tensor &out split_work_to_cores(core_range, num_cols_tiles); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp index 55e0c3d886c..cc64991b0e5 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp @@ -61,7 +61,7 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_small(const Tensor &out split_work_to_cores(core_range, num_cols_tiles); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp index 8f99f497c2f..5904de7044b 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp @@ -37,7 +37,7 @@ operation::ProgramWithCallbacks moreh_softmax_backward_w_large(const Tensor &out split_work_to_cores(core_range, num_kernel_rows); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp index 96385ae1aac..d0b07328121 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp @@ -62,7 +62,7 @@ operation::ProgramWithCallbacks moreh_softmax_backward_w_small(const Tensor &out split_work_to_cores(core_range, num_kernel_rows); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_h_impl/moreh_int_sum_h_impl.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_h_impl/moreh_int_sum_h_impl.cpp index 9cf8fa9c022..674c10156f6 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_h_impl/moreh_int_sum_h_impl.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_h_impl/moreh_int_sum_h_impl.cpp @@ -45,7 +45,7 @@ operation::ProgramWithCallbacks moreh_sum_int_h_impl(const Tensor &input, const const bool do_mask_h {(origin_H % TILE_HEIGHT) != 0}; const auto mask_h {do_mask_h ? origin_H % TILE_HEIGHT: TILE_HEIGHT}; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug( LogOp, "math_fidelity {} math_approx_mode {} fp32_dest_acc_en {} packer_l1_acc {}", @@ -224,7 +224,7 @@ operation::ProgramWithCallbacks moreh_sum_int_h_impl(const Tensor &input, const // const bool do_mask_h = (origin_H % TILE_HEIGHT) != 0; // const auto mask_h = do_mask_h ? origin_H % TILE_HEIGHT : TILE_HEIGHT; - // auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); + // auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); // log_debug( // LogOp, // "math_fidelity {} math_approx_mode {} fp32_dest_acc_en {} packer_l1_acc {}", diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_h_impl/moreh_sum_h_impl.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_h_impl/moreh_sum_h_impl.cpp index e5cdc3a08df..b86b202a00b 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_h_impl/moreh_sum_h_impl.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_h_impl/moreh_sum_h_impl.cpp @@ -37,7 +37,7 @@ operation::ProgramWithCallbacks moreh_sum_h_impl(const Tensor &a, const Tensor & const bool do_mask_h = (origin_H % TILE_HEIGHT) != 0; const auto mask_h = do_mask_h ? origin_H % TILE_HEIGHT : TILE_HEIGHT; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(a.device()->arch(), compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(a.device()->arch(), compute_kernel_config); log_debug( LogOp, "math_fidelity {} math_approx_mode {} fp32_dest_acc_en {} packer_l1_acc {}", diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/moreh_int_sum_nc_impl.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/moreh_int_sum_nc_impl.cpp index 8eea2505e7d..27348bbe1b3 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/moreh_int_sum_nc_impl.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/moreh_int_sum_nc_impl.cpp @@ -33,7 +33,7 @@ operation::ProgramWithCallbacks moreh_sum_int_nc_impl(const Tensor &input, const const auto [Wt, Ht, inner_tile_size, reduce_tile_size] = extract_and_scale_spatial_dims(input_shape, static_cast(dim)); const auto num_reduce_input_tile {input_shape[dim]}; const auto num_output_tiles {output.volume() / TILE_HW}; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug(LogOp, "reduce_tile_size {} inner_tile_size {} Ht {} Wt {}", reduce_tile_size, inner_tile_size, Ht, Wt); log_debug( diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/moreh_sum_nc_impl.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/moreh_sum_nc_impl.cpp index 03efdfee580..af1e57687c5 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/moreh_sum_nc_impl.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_nc_impl/moreh_sum_nc_impl.cpp @@ -33,7 +33,7 @@ operation::ProgramWithCallbacks moreh_sum_nc_impl(const Tensor &input, const Ten const auto [Wt, Ht, inner_tile_size, reduce_tile_size] = extract_and_scale_spatial_dims(input_shape, static_cast(dim)); const auto num_reduce_input_tile = input_shape[dim]; const auto num_output_tiles = output.volume() / TILE_HW; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug(LogOp, "reduce_tile_size {} inner_tile_size {} Ht {} Wt {}", reduce_tile_size, inner_tile_size, Ht, Wt); log_debug( diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_w_impl/moreh_int_sum_w_impl.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_w_impl/moreh_int_sum_w_impl.cpp index d8950592df7..641cd11ed9d 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_w_impl/moreh_int_sum_w_impl.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_w_impl/moreh_int_sum_w_impl.cpp @@ -44,7 +44,7 @@ operation::ProgramWithCallbacks moreh_sum_int_w_impl(const Tensor &input, const const bool do_mask_w {(origin_W % TILE_WIDTH) != 0}; const auto mask_w {do_mask_w ? origin_W % TILE_WIDTH : TILE_WIDTH}; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug( LogOp, "math_fidelity {} math_approx_mode {} fp32_dest_acc_en {} packer_l1_acc {}", diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_w_impl/moreh_sum_w_impl.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_w_impl/moreh_sum_w_impl.cpp index 10704672c88..3752c4d3009 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_w_impl/moreh_sum_w_impl.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_w_impl/moreh_sum_w_impl.cpp @@ -37,7 +37,7 @@ operation::ProgramWithCallbacks moreh_sum_w_impl(const Tensor &a, const Tensor & const bool do_mask_w = (origin_W % TILE_WIDTH) != 0; const auto mask_w = do_mask_w ? origin_W % TILE_WIDTH : TILE_WIDTH; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(a.device()->arch(), compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(a.device()->arch(), compute_kernel_config); log_debug( LogOp, "math_fidelity {} math_approx_mode {} fp32_dest_acc_en {} packer_l1_acc {}", diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_impl/moreh_sum_backward_impl.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_impl/moreh_sum_backward_impl.cpp index 120d480bb85..64617b62f2c 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_impl/moreh_sum_backward_impl.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_impl/moreh_sum_backward_impl.cpp @@ -105,7 +105,7 @@ operation::ProgramWithCallbacks moreh_sum_backward_impl( } } const auto num_input_grad_tiles = input_grad.volume() / TILE_HW; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(output_grad.device()->arch(), compute_kernel_config); for (auto i = 0; i < input_grad_rank; ++i) { diff --git a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp index 567b55b40f8..1d04fa81300 100644 --- a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp +++ b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.cpp @@ -12,7 +12,8 @@ DeviceComputeKernelConfig init_device_compute_kernel_config( const MathFidelity default_fidelity, bool default_approx_mode, bool default_fp32_acc, - bool default_l1_acc) { + bool default_l1_acc, + bool default_dst_full_sync_en) { DeviceComputeKernelConfig defaultConfig; if (device_kernel_config.has_value()) { @@ -24,19 +25,24 @@ DeviceComputeKernelConfig init_device_compute_kernel_config( TT_ASSERT(arch == tt::ARCH::GRAYSKULL, "kernel config is not for graykull"); MathFidelity math_fidelity = compute_kernel_config.math_fidelity; bool math_approx_mode = compute_kernel_config.math_approx_mode; + bool dst_full_sync_en = compute_kernel_config.dst_full_sync_en; defaultConfig = GrayskullComputeKernelConfig{ - .math_fidelity = math_fidelity, .math_approx_mode = math_approx_mode}; + .math_fidelity = math_fidelity, + .math_approx_mode = math_approx_mode, + .dst_full_sync_en = dst_full_sync_en}; } else if constexpr (std::is_same_v) { TT_ASSERT(ttnn::device::is_wormhole_or_blackhole(arch), "kernel config is not for wormhole_b0 or blackhole"); MathFidelity math_fidelity = compute_kernel_config.math_fidelity; bool math_approx_mode = compute_kernel_config.math_approx_mode; bool fp32_dest_acc_en = compute_kernel_config.fp32_dest_acc_en; bool packer_l1_acc = compute_kernel_config.packer_l1_acc; + bool dst_full_sync_en = compute_kernel_config.dst_full_sync_en; defaultConfig = WormholeComputeKernelConfig{ .math_fidelity = math_fidelity, .math_approx_mode = math_approx_mode, .fp32_dest_acc_en = fp32_dest_acc_en, - .packer_l1_acc = packer_l1_acc}; + .packer_l1_acc = packer_l1_acc, + .dst_full_sync_en = dst_full_sync_en}; } else { TT_THROW("arch not supported"); } @@ -52,7 +58,8 @@ DeviceComputeKernelConfig init_device_compute_kernel_config( .math_fidelity = default_fidelity, .math_approx_mode = default_approx_mode, .fp32_dest_acc_en = default_fp32_acc, - .packer_l1_acc = default_l1_acc}; + .packer_l1_acc = default_l1_acc, + .dst_full_sync_en = default_dst_full_sync_en}; } } } @@ -75,13 +82,14 @@ bool get_fp32_dest_acc_en(const std::optional& comput compute_kernel_config.value()); } -std::tuple get_compute_kernel_config_args( +std::tuple get_compute_kernel_config_args( tt::ARCH arch, const DeviceComputeKernelConfig compute_kernel_config) { MathFidelity math_fidelity; bool math_approx_mode; bool fp32_dest_acc_en; bool packer_l1_acc; + bool dst_full_sync_en; std::visit( [&](auto&& compute_kernel_config) { @@ -92,19 +100,21 @@ std::tuple get_compute_kernel_config_args( math_approx_mode = compute_kernel_config.math_approx_mode; fp32_dest_acc_en = false; packer_l1_acc = false; + dst_full_sync_en = false; } else if constexpr (std::is_same_v) { TT_ASSERT(ttnn::device::is_wormhole_or_blackhole(arch), "kernel config is not for wormhole_b0 or blackhole"); math_fidelity = compute_kernel_config.math_fidelity; math_approx_mode = compute_kernel_config.math_approx_mode; fp32_dest_acc_en = compute_kernel_config.fp32_dest_acc_en; packer_l1_acc = compute_kernel_config.packer_l1_acc; + dst_full_sync_en = compute_kernel_config.dst_full_sync_en; } else { TT_THROW("arch not supported"); } }, compute_kernel_config); - return std::make_tuple(math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc); + return std::make_tuple(math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en); } } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp index 800aa550744..ad38daea02f 100644 --- a/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp +++ b/ttnn/cpp/ttnn/operations/core/compute_kernel/compute_kernel_config.hpp @@ -12,6 +12,7 @@ namespace ttnn { struct GrayskullComputeKernelConfig { MathFidelity math_fidelity = MathFidelity::LoFi; bool math_approx_mode = true; + bool dst_full_sync_en = false; }; struct WormholeComputeKernelConfig { @@ -19,6 +20,7 @@ struct WormholeComputeKernelConfig { bool math_approx_mode = true; bool fp32_dest_acc_en = false; bool packer_l1_acc = false; + bool dst_full_sync_en = false; }; using BlackholeComputeKernelConfig = WormholeComputeKernelConfig; @@ -31,10 +33,11 @@ DeviceComputeKernelConfig init_device_compute_kernel_config( const MathFidelity default_fidelity = MathFidelity::LoFi, bool default_approx_mode = true, bool default_fp32_acc = false, - bool default_l1_acc = false); + bool default_l1_acc = false, + bool default_dst_full_sync_en = false); bool get_fp32_dest_acc_en(const std::optional& compute_kernel_config); -std::tuple get_compute_kernel_config_args(tt::ARCH arch, const DeviceComputeKernelConfig compute_kernel_config); +std::tuple get_compute_kernel_config_args(tt::ARCH arch, const DeviceComputeKernelConfig compute_kernel_config); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_program_factory.cpp b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_program_factory.cpp index 44f837bed5a..b6c1e3d4777 100644 --- a/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/reduction/fast_reduce_nc/device/fast_reduce_nc_program_factory.cpp @@ -60,7 +60,7 @@ operation::ProgramWithCallbacks reduce_nc_factory(const ttnn::Tensor &input, con const auto [Wt, Ht, inner_tile_size, reduce_tile_size] = extract_and_scale_spatial_dims(input_shape, static_cast(dim)); const auto num_reduce_input_tile = input_shape[dim]; const auto num_output_tiles = output.volume() / TILE_HW; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); // choose granularity as the largest factor of num_reduce_input_tile that is less than or equal to 8 uint32_t input_granularity; for (input_granularity = 8; input_granularity > 1; --input_granularity) { diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_program_factory.cpp index e10446dcbdf..09b633211ea 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op_multi_core_program_factory.cpp @@ -121,7 +121,10 @@ operation::ProgramWithCallbacks matmul_multi_core(const Tensor &a, const Tensor program, "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm.cpp", core_group_1, - tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .compile_args = compute_args_group_1}); + tt_metal::ComputeConfig{ + .math_fidelity = math_fidelity, + .dst_full_sync_en = true, + .compile_args = compute_args_group_1}); if (!core_group_2.ranges().empty()) { vector compute_args_group_2 = { @@ -136,7 +139,10 @@ operation::ProgramWithCallbacks matmul_multi_core(const Tensor &a, const Tensor program, "ttnn/cpp/ttnn/operations/matmul/device/kernels/compute/bmm.cpp", core_group_2, - tt_metal::ComputeConfig{.math_fidelity = math_fidelity, .compile_args = compute_args_group_2}); + tt_metal::ComputeConfig{ + .math_fidelity = math_fidelity, + .dst_full_sync_en = true, + .compile_args = compute_args_group_2}); } for (uint32_t i = 0, num_tiles_written = 0; i < num_cores; i++) { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp index 68e0473fb53..d05423c6e49 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp @@ -54,7 +54,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program tt::tt_metal::split_work_to_cores(grid, num_tiles); auto arch = param_in.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/multi_core_program_factory.cpp index 32e6d73e2eb..e982ac75081 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adamw/device/multi_core_program_factory.cpp @@ -55,7 +55,7 @@ MorehAdamWDeviceOperation::MultiCore::cached_program_t MorehAdamWDeviceOperation split_work_to_cores(grid, num_units); auto arch = param_in.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_program_factory.cpp index 6261d1dccf2..c6c81d47360 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_program_factory.cpp @@ -41,7 +41,7 @@ MorehDotOperation::SingleCore::cached_program_t MorehDotOperation::SingleCore::c tt::tt_metal::Device* device = input_a.device(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); const uint32_t in0_t = 2; // a diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_program_factory.cpp index 273f2ea1b4b..811d00fb3f4 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm/device/moreh_layer_norm_program_factory.cpp @@ -112,7 +112,7 @@ MorehLayerNormOperation::ProgramFactory::cached_program_t MorehLayerNormOperatio tt::tt_metal::split_work_to_cores(grid, num_outer); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); // This could be inefficient. diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_program_factory.cpp index a1705c220bc..2aa4f06203b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_gamma_beta_grad_program_factory.cpp @@ -87,7 +87,7 @@ MorehLayerNormBackwardGammaBetaGradOperation::ProgramFactory::create( tt::tt_metal::split_work_to_cores(grid, num_inner); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_program_factory.cpp index 7f5bb891f38..ee112fea330 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_layer_norm_backward/device/moreh_layer_norm_backward_input_grad_program_factory.cpp @@ -82,7 +82,7 @@ MorehLayerNormBackwardInputGradOperation::ProgramFactory::create( tt::tt_metal::split_work_to_cores(grid, num_outer); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); //////////////////////////////////////////////////////////////////////////// // CircularBuffer Setup diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_multi_core_program_factory.cpp index cea37077dd3..62c48d1b95b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_multi_core_program_factory.cpp @@ -49,8 +49,7 @@ MorehBiasAddBackwardOperation::MultiCoreProgramFactory::create( auto grid = device->compute_with_storage_grid_size(); auto arch = device->arch(); const auto num_cores_y = grid.y; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = - get_compute_kernel_config_args(arch, compute_kernel_config); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); const auto [num_cores_to_be_used, all_cores, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_single_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_single_core_program_factory.cpp index 7a2ebe0c4ba..e78f22b56fc 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_single_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_single_core_program_factory.cpp @@ -57,7 +57,7 @@ MorehBiasAddBackwardOperation::SingleCoreProgramFactory::create( Device* device = output_grad.device(); auto arch = device->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp index 2c4a3af541f..8976019c2ac 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp @@ -238,7 +238,7 @@ MorehMatmulOperation::MultiCoreProgramFactory::cached_program_t MorehMatmulOpera other_mask_h, other_mask_w); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); log_debug( tt::LogOp, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_h_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_h_program_factory.cpp index 730383768c8..87dd9ca4f98 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_h_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_h_program_factory.cpp @@ -55,7 +55,7 @@ MorehMeanOperation::MorehMeanHFactory::cached_program_t MorehMeanOperation::More split_work_to_cores(core_range, units_to_divide); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); // create circular buffers diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_nc_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_nc_program_factory.cpp index 845d42255a6..ce444641db5 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_nc_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_nc_program_factory.cpp @@ -70,7 +70,7 @@ MorehMeanOperation::MorehMeanNCFactory::cached_program_t MorehMeanOperation::Mor split_work_to_cores(core_range, units_to_divide); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_w_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_w_program_factory.cpp index 25a72eda08c..f9223f831b7 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_w_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_w_program_factory.cpp @@ -54,7 +54,7 @@ MorehMeanOperation::MorehMeanWFactory::cached_program_t MorehMeanOperation::More split_work_to_cores(core_range, units_to_divide); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); // create circular buffers diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp index f58869288d1..a8c97375ea2 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp @@ -98,7 +98,7 @@ MorehMeanBackwardOperation::MorehMeanBackwardFactory::create( } } const auto num_input_grad_tiles = input_grad.volume() / tt::constants::TILE_HW; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(output_grad.device()->arch(), compute_kernel_config); //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step1/device/moreh_nll_loss_step1_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step1/device/moreh_nll_loss_step1_program_factory.cpp index c2186bae1d3..48969c5f83a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step1/device/moreh_nll_loss_step1_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step1/device/moreh_nll_loss_step1_program_factory.cpp @@ -49,7 +49,7 @@ MorehNllLossStep1DeviceOperation::Factory::cached_program_t MorehNllLossStep1Dev auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/moreh_nll_loss_step2_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/moreh_nll_loss_step2_program_factory.cpp index 5628a866d09..dec3a5c3238 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/moreh_nll_loss_step2_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_step2/device/moreh_nll_loss_step2_program_factory.cpp @@ -46,7 +46,7 @@ MorehNllLossStep2DeviceOperation::Factory::cached_program_t moreh_nll_loss_step2 auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); @@ -225,7 +225,7 @@ MorehNllLossStep2DeviceOperation::Factory::cached_program_t moreh_nll_loss_step2 auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); @@ -413,7 +413,7 @@ MorehNllLossStep2DeviceOperation::Factory::cached_program_t moreh_nll_loss_step2 auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss_backward/device/moreh_nll_loss_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss_backward/device/moreh_nll_loss_backward_program_factory.cpp index 959fa829599..e093fdcb9f0 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss_backward/device/moreh_nll_loss_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss_backward/device/moreh_nll_loss_backward_program_factory.cpp @@ -43,7 +43,7 @@ MorehNllLossBackwardDeviceOperation::Factory::cached_program_t moreh_nll_loss_ba auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); @@ -219,7 +219,7 @@ MorehNllLossBackwardDeviceOperation::Factory::cached_program_t moreh_nll_loss_ba auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); @@ -393,7 +393,7 @@ MorehNllLossBackwardDeviceOperation::Factory::cached_program_t moreh_nll_loss_ba auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss_unreduced_backward/device/moreh_nll_loss_unreduced_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss_unreduced_backward/device/moreh_nll_loss_unreduced_backward_program_factory.cpp index 832b5249a7b..712d649eac3 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss_unreduced_backward/device/moreh_nll_loss_unreduced_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_nll_loss_unreduced_backward/device/moreh_nll_loss_unreduced_backward_program_factory.cpp @@ -40,7 +40,7 @@ MorehNllLossUnreducedBackwardDeviceOperation::Factory::cached_program_t moreh_nl auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); @@ -170,7 +170,7 @@ MorehNllLossUnreducedBackwardDeviceOperation::Factory::cached_program_t moreh_nl auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); @@ -297,7 +297,7 @@ MorehNllLossUnreducedBackwardDeviceOperation::Factory::cached_program_t moreh_nl auto [num_cores, all_cores, core_group_1, core_group_2, units_per_core_group_1, units_per_core_group_2] = split_work_to_cores(grid, units_to_divide); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(device->arch(), compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_h.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_h.cpp index 15893439a1e..20efbe66038 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_h.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_h.cpp @@ -47,7 +47,7 @@ MorehNormOperation::ProgramFactoryH::cached_program_t MorehNormOperation::Progra const auto num_cores_y = grid.y; auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, operation_attributes.compute_kernel_config); const auto diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_other.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_other.cpp index c9708989c16..1016d0811ee 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_other.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_other.cpp @@ -59,7 +59,7 @@ MorehNormOperation::ProgramFactoryOther::cached_program_t MorehNormOperation::Pr const auto num_cores_y = grid.y; auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, operation_attributes.compute_kernel_config); const auto diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_w.cpp index 0937d062e0f..4ca92f56c75 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm/device/moreh_norm_program_factory_w.cpp @@ -47,7 +47,7 @@ MorehNormOperation::ProgramFactoryW::cached_program_t MorehNormOperation::Progra const auto num_cores_y = grid.y; auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, operation_attributes.compute_kernel_config); const auto diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_program_factory.cpp index 7d19f1b49cd..5f7df3904ee 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_norm_backward/device/moreh_norm_backward_program_factory.cpp @@ -91,7 +91,7 @@ MorehNormBackwardOperation::ProgramFactory::cached_program_t MorehNormBackwardOp } const auto num_input_grad_tiles = input_grad.volume() / tt::constants::TILE_HW; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(output_grad.device()->arch(), operation_attributes.compute_kernel_config); auto [floored_p, decimal, p_is_negative] = get_floored_p_and_decimal_and_p_is_negative(p); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_program_factory.cpp index 474218ab73f..1e29f65e2c5 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sgd/device/moreh_sgd_program_factory.cpp @@ -55,7 +55,7 @@ MorehSgdOperation::ProgramFactory::cached_program_t MorehSgdOperation::ProgramFa tt::tt_metal::split_work_to_cores(grid, units_to_divide); auto arch = param_in.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp index b6751d7314e..a2e23628d2f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp @@ -13,7 +13,7 @@ bool is_moreh_softmax_w_small_available(const Tensor& tensor, const DeviceComput int32_t Wt = (w + tt::constants::TILE_WIDTH - 1) / tt::constants::TILE_WIDTH; auto arch = tensor.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); auto data_format = tt::tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); @@ -43,7 +43,7 @@ bool is_moreh_softmax_h_small_available(const Tensor& tensor, const DeviceComput int32_t Ht = (h + tt::constants::TILE_HEIGHT - 1) / tt::constants::TILE_HEIGHT; auto arch = tensor.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); auto data_format = tt::tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp index 5560d84fe17..520854fb899 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp @@ -37,7 +37,7 @@ MorehSoftmaxOperation::MorehSoftmaxCLargeFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_tiles); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp index 56493c381bf..5e1a8339063 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp @@ -36,7 +36,7 @@ MorehSoftmaxOperation::MorehSoftmaxHLargeFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_cols_tiles); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp index 4bc781e5235..1820a9f6df6 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp @@ -37,7 +37,7 @@ MorehSoftmaxOperation::MorehSoftmaxHSmallFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_cols_tiles); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp index 9384a3edb93..d2703e5a4c8 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp @@ -37,7 +37,7 @@ MorehSoftmaxOperation::MorehSoftmaxWLargeFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_kernel_rows); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp index b3a4b9f4191..627f9a2f827 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp @@ -37,7 +37,7 @@ MorehSoftmaxOperation::MorehSoftmaxWSmallFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_kernel_rows); auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp index 6e9047fe896..68c89a6cc4f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp @@ -38,7 +38,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardCLargeFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_tiles); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp index 3b9a213ec76..fc042e650bc 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp @@ -38,7 +38,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardHLargeFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_cols_tiles); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_small/softmax_backward_h_small.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_small/softmax_backward_h_small.cpp index 87d987d761e..a6b228f795a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_small/softmax_backward_h_small.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_small/softmax_backward_h_small.cpp @@ -38,7 +38,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardHSmallFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_cols_tiles); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_large/softmax_backward_w_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_large/softmax_backward_w_large.cpp index e162934ba7a..fb3db5e4216 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_large/softmax_backward_w_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_large/softmax_backward_w_large.cpp @@ -38,7 +38,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardWLargeFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_kernel_rows); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_small/softmax_backward_w_small.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_small/softmax_backward_w_small.cpp index 71c3c8143c6..f5af6071e38 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_small/softmax_backward_w_small.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_small/softmax_backward_w_small.cpp @@ -38,7 +38,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardWSmallFactory::create( tt::operations::primary::split_work_to_cores(core_range, num_kernel_rows); auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); Program program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_h_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_h_program_factory.cpp index 732771d5195..98905db038e 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_h_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_h_program_factory.cpp @@ -41,7 +41,7 @@ MorehSumOperation::MorehSumHIntFactory::cached_program_t MorehSumOperation::More const bool do_mask_h{(origin_H % tt::constants::TILE_HEIGHT) != 0}; const auto mask_h{do_mask_h ? origin_H % tt::constants::TILE_HEIGHT : tt::constants::TILE_HEIGHT}; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug( tt::LogOp, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_nc_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_nc_program_factory.cpp index 586269f9b6e..f5e8d21c80d 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_nc_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_nc_program_factory.cpp @@ -36,7 +36,7 @@ MorehSumOperation::MorehSumNCIntFactory::cached_program_t MorehSumOperation::Mor tt::operations::primary::extract_and_scale_spatial_dims(input_shape, static_cast(dim)); const auto num_reduce_input_tile{input_shape[dim]}; const auto num_output_tiles{output.volume() / tt::constants::TILE_HW}; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug( diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_w_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_w_program_factory.cpp index a6cc4020de8..8e9bf232b7e 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_w_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_w_program_factory.cpp @@ -43,7 +43,7 @@ MorehSumOperation::MorehSumWIntFactory::cached_program_t MorehSumOperation::More const bool do_mask_w{(origin_W % tt::constants::TILE_WIDTH) != 0}; const auto mask_w{do_mask_w ? origin_W % tt::constants::TILE_WIDTH : tt::constants::TILE_WIDTH}; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug( tt::LogOp, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp index 7e7a9cba93c..22bb26e37b8 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp @@ -41,7 +41,7 @@ MorehSumOperation::MorehSumHFactory::cached_program_t MorehSumOperation::MorehSu const bool do_mask_h = (origin_H % tt::constants::TILE_HEIGHT) != 0; const auto mask_h = do_mask_h ? origin_H % tt::constants::TILE_HEIGHT : tt::constants::TILE_HEIGHT; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug( tt::LogOp, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_nc_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_nc_program_factory.cpp index 0bf72ef9f31..f79ca7cf769 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_nc_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_nc_program_factory.cpp @@ -35,7 +35,7 @@ MorehSumOperation::MorehSumNCFactory::cached_program_t MorehSumOperation::MorehS tt::operations::primary::extract_and_scale_spatial_dims(input_shape, static_cast(dim)); const auto num_reduce_input_tile = input_shape[dim]; const auto num_output_tiles = output.volume() / tt::constants::TILE_HW; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug( diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp index 55ce22a917b..2c3eb23c611 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp @@ -40,7 +40,7 @@ MorehSumOperation::MorehSumWFactory::cached_program_t MorehSumOperation::MorehSu const bool do_mask_w = (origin_W % tt::constants::TILE_WIDTH) != 0; const auto mask_w = do_mask_w ? origin_W % tt::constants::TILE_WIDTH : tt::constants::TILE_WIDTH; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); log_debug( tt::LogOp, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp index 1fe7587eab9..00d0c17dd76 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum_backward/device/moreh_sum_backward_program_factory.cpp @@ -101,7 +101,7 @@ MorehSumBackwardOperation::ProgramFactory::cached_program_t MorehSumBackwardOper } } const auto num_input_grad_tiles = input_grad.volume() / tt::constants::TILE_HW; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(output_grad.device()->arch(), compute_kernel_config); for (auto i = 0; i < input_grad_rank; ++i) { diff --git a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp index addfb7ea9da..cc2d9e5ab45 100644 --- a/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp +++ b/ttnn/cpp/ttnn/operations/pool/upsample/device/upsample_bilinear_program_factory_multicore.cpp @@ -90,7 +90,7 @@ operation::ProgramWithCallbacks bilinear_multi_core(const Tensor &input, Tensor& uint32_t in_w = input.get_legacy_shape()[2]; uint32_t out_w =output.get_legacy_shape()[2]; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(input.device()->arch(), compute_kernel_config); auto shard_spec = input.shard_spec().value(); diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_h/reduce_op_multi_core_h.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_h/reduce_op_multi_core_h.cpp index 2719b0e1b56..3ad22e3eef1 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_h/reduce_op_multi_core_h.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_h/reduce_op_multi_core_h.cpp @@ -27,7 +27,7 @@ operation::ProgramWithCallbacks reduce_multi_core_h( uint32_t Ht = H / TILE_HEIGHT; uint32_t HtWt = Ht * Wt; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(a.device()->arch(), compute_kernel_config); tt_metal::Program program = tt_metal::CreateProgram(); diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_w/reduce_op_multi_core_w.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_w/reduce_op_multi_core_w.cpp index 8938454a3c4..9205f800f79 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_w/reduce_op_multi_core_w.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/multi_core_w/reduce_op_multi_core_w.cpp @@ -28,7 +28,7 @@ operation::ProgramWithCallbacks reduce_multi_core_w( uint32_t Wt = W / TILE_WIDTH; uint32_t Ht = H / TILE_HEIGHT; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(a.device()->arch(), compute_kernel_config); tt_metal::Program program = tt_metal::CreateProgram(); diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/single_core_hw/reduce_op_single_core_hw.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/single_core_hw/reduce_op_single_core_hw.cpp index 9ff0a768c26..f3cfe56730c 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/single_core_hw/reduce_op_single_core_hw.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/single_core_hw/reduce_op_single_core_hw.cpp @@ -23,7 +23,7 @@ operation::ProgramWithCallbacks reduce_single_core_hw( uint32_t W = shape[3], H = shape[2], NC = shape[1] * shape[0]; uint32_t HW = H * W; - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc] = + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(a.device()->arch(), compute_kernel_config); uint32_t Wt = W / TILE_WIDTH; From 13ab3ac85de2734f183048f8d64eef139cdb3310 Mon Sep 17 00:00:00 2001 From: Nenad Petrovic <109360062+npetrovic-tenstorrent@users.noreply.github.com> Date: Tue, 8 Oct 2024 18:18:03 +0200 Subject: [PATCH 04/58] New sweeps and backward ops (#13513) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * #11512: New softmax unit test * #11512: Add softplus sweep * #11512: Add softplus for WH only * #11512: Add softplus to run config * #11512: Added bw sweeps log and relu6 * #11512: Update git workflow --------- Co-authored-by: “Nenad <“npetrovic@tenstorrent.com”> --- .github/workflows/ttnn-run-sweeps.yaml | 3 + .../sweeps/eltwise/unary/softplus/softplus.py | 87 +++++++++++++++++ .../eltwise/unary_backward/log_bw/log_bw.py | 96 +++++++++++++++++++ .../unary_backward/relu6_bw/relu6_bw.py | 96 +++++++++++++++++++ 4 files changed, 282 insertions(+) create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/softplus/softplus.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_backward/log_bw/log_bw.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_backward/relu6_bw/relu6_bw.py diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml index 011376e876c..3e5b58bddc4 100644 --- a/.github/workflows/ttnn-run-sweeps.yaml +++ b/.github/workflows/ttnn-run-sweeps.yaml @@ -74,6 +74,9 @@ on: - eltwise.unary.sinh.sinh - eltwise.unary.relu_min.relu_min - eltwise.unary.relu_max.relu_max + - eltwise.unary.softplus.softplus + - eltwise.unary_backward.log_bw.log_bw + - eltwise.unary_backward.relu6_bw.relu6_bw - eltwise.binary.subtract.subtract - eltwise.binary.multiply.multiply - eltwise.binary.div.div diff --git a/tests/sweep_framework/sweeps/eltwise/unary/softplus/softplus.py b/tests/sweep_framework/sweeps/eltwise/unary/softplus/softplus.py new file mode 100644 index 00000000000..e0b0780bb0f --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/softplus/softplus.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 +low = 0 +high = 100 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "xfail": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 32), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +def mesh_device_fixture(): + device = ttnn.open_device(device_id=0) + assert ttnn.device.is_wormhole_b0(device), "This op is available for Wormhole_B0 only" + yield (device, "Wormhole_B0") + ttnn.close_device(device) + del device + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + beta = torch.tensor(1, dtype=torch.bfloat16).uniform_(low, high).item() + threshold = torch.tensor(1, dtype=torch.bfloat16).uniform_(low, high).item() + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + torch_output_tensor = torch.nn.functional.softplus(torch_input_tensor_a, beta=beta, threshold=threshold) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.softplus(input_tensor_a, beta=beta, threshold=threshold, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary_backward/log_bw/log_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_backward/log_bw/log_bw.py new file mode 100644 index 00000000000..2f42574a4e0 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_backward/log_bw/log_bw.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 16), + "grad_dtype": [ttnn.bfloat16], + "input_a_dtype": [ttnn.bfloat16], + "grad_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT], + "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + grad_dtype, + input_a_dtype, + grad_layout, + input_a_layout, + grad_memory_config, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_grad_tensor = gen_func_with_cast_tt(partial(torch_random, low=-10, high=10, dtype=torch.float32), grad_dtype)( + input_shape + ) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-10, high=10, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_a.requires_grad = True + torch_input_tensor_a.retain_grad() + + intermediate_result = torch.log(torch_input_tensor_a) + intermediate_result.backward(gradient=torch_grad_tensor) + torch_output_tensor = torch_input_tensor_a.grad + + grad_tensor = ttnn.from_torch( + torch_grad_tensor, + dtype=grad_dtype, + layout=grad_layout, + device=device, + memory_config=grad_memory_config, + ) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a.detach().clone(), + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.log_bw(grad_tensor, input_tensor_a, memory_config=output_memory_config)[0] + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary_backward/relu6_bw/relu6_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_backward/relu6_bw/relu6_bw.py new file mode 100644 index 00000000000..ff86458ab9b --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_backward/relu6_bw/relu6_bw.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 16), + "grad_dtype": [ttnn.bfloat16], + "input_a_dtype": [ttnn.bfloat16], + "grad_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT], + "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + grad_dtype, + input_a_dtype, + grad_layout, + input_a_layout, + grad_memory_config, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_grad_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), grad_dtype + )(input_shape) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_a.requires_grad = True + torch_input_tensor_a.retain_grad() + + intermediate_result = torch.nn.functional.relu6(torch_input_tensor_a) + intermediate_result.backward(gradient=torch_grad_tensor) + torch_output_tensor = torch_input_tensor_a.grad + + grad_tensor = ttnn.from_torch( + torch_grad_tensor, + dtype=grad_dtype, + layout=grad_layout, + device=device, + memory_config=grad_memory_config, + ) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a.detach().clone(), + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.relu6_bw(grad_tensor, input_tensor_a, memory_config=output_memory_config)[0] + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] From 7e3158788df07850fd1321e1be86f23b6a0263c9 Mon Sep 17 00:00:00 2001 From: Nathan Sidwell Date: Tue, 8 Oct 2024 12:26:56 -0400 Subject: [PATCH 05/58] #11791: Simplify startup section naming (#13578) --- tt_metal/hw/toolchain/erisc-b0-app-sections.ld | 1 - tt_metal/hw/toolchain/erisc-b0-kernel.ld | 5 ++--- tt_metal/hw/toolchain/sections.ld | 16 +++++++--------- tt_metal/hw/toolchain/tmu-crt0.S | 2 +- tt_metal/hw/toolchain/tmu-crt0k.S | 2 +- 5 files changed, 11 insertions(+), 15 deletions(-) diff --git a/tt_metal/hw/toolchain/erisc-b0-app-sections.ld b/tt_metal/hw/toolchain/erisc-b0-app-sections.ld index e9f17455c3c..6595233715f 100644 --- a/tt_metal/hw/toolchain/erisc-b0-app-sections.ld +++ b/tt_metal/hw/toolchain/erisc-b0-app-sections.ld @@ -12,7 +12,6 @@ OUTPUT_FORMAT("elf32-littleriscv", "elf32-littleriscv", "elf32-littleriscv") OUTPUT_ARCH(riscv) ENTRY(ApplicationHandler) -SEARCH_DIR("/opt/riscv32i/riscv32-unknown-elf/lib"); SECTIONS { /* Read-only sections, merged into text segment: */ diff --git a/tt_metal/hw/toolchain/erisc-b0-kernel.ld b/tt_metal/hw/toolchain/erisc-b0-kernel.ld index 37318f61ad1..b34355bd1c5 100644 --- a/tt_metal/hw/toolchain/erisc-b0-kernel.ld +++ b/tt_metal/hw/toolchain/erisc-b0-kernel.ld @@ -20,7 +20,6 @@ OUTPUT_FORMAT("elf32-littleriscv", "elf32-littleriscv", "elf32-littleriscv") OUTPUT_ARCH(riscv) ENTRY(_start) -SEARCH_DIR("/opt/riscv32i/riscv32-unknown-elf/lib"); SECTIONS { /* Read-only sections, merged into text segment: */ @@ -28,9 +27,9 @@ SECTIONS PROVIDE (__global_pointer$ = __firmware_global_pointer); PROVIDE (__erisc_jump_table = ORIGIN(ERISC_JUMPTABLE)); - .init __firmware_start : + .start __firmware_start : { - KEEP (*(SORT_NONE(.init))) + *(.start) } > REGION_APP_KERNEL_CODE code_l1 : { diff --git a/tt_metal/hw/toolchain/sections.ld b/tt_metal/hw/toolchain/sections.ld index efcbe4efcf5..fc441d139d4 100644 --- a/tt_metal/hw/toolchain/sections.ld +++ b/tt_metal/hw/toolchain/sections.ld @@ -39,20 +39,15 @@ OUTPUT_FORMAT("elf32-littleriscv", "elf32-littleriscv", "elf32-littleriscv") OUTPUT_ARCH(riscv) ENTRY(_start) -SEARCH_DIR("/opt/riscv32i/riscv32-unknown-elf/lib"); SECTIONS { - /* Read-only sections, merged into text segment: */ - .init TEXT_START : + .text TEXT_START : { /* Because TEXT_START might not be the start of a region, we need to force this section to be emitted so that following sections do not restart the region, if this one is empty. */ . = .; - KEEP (*(SORT_NONE(.init))) - } > REGION_CODE - .text : - { + *(.start) *(.text.unlikely .text.*_unlikely .text.unlikely.*) *(.text.exit .text.exit.*) *(.text.startup .text.startup.*) @@ -61,9 +56,12 @@ SECTIONS /* .gnu.warning sections are handled specially by elf32.em. */ *(.gnu.warning) } > REGION_CODE - .fini : + .init.fini : { - KEEP (*(SORT_NONE(.fini))) + /* We don't use .init/.fini (this isn't the '90s), make sure there aren't any. */ + KEEP (*(.init)) + KEEP (*(.fini)) + ASSERT(SIZEOF(.init.fini) == 0, ".init/.fini sections present"); } > REGION_CODE l1_data : diff --git a/tt_metal/hw/toolchain/tmu-crt0.S b/tt_metal/hw/toolchain/tmu-crt0.S index 279da79aba6..6ca0b611054 100644 --- a/tt_metal/hw/toolchain/tmu-crt0.S +++ b/tt_metal/hw/toolchain/tmu-crt0.S @@ -1,4 +1,4 @@ -.section .init +.section .start,"ax",@progbits .global _start .type _start, @function diff --git a/tt_metal/hw/toolchain/tmu-crt0k.S b/tt_metal/hw/toolchain/tmu-crt0k.S index 08a16257817..f5d4ec04215 100644 --- a/tt_metal/hw/toolchain/tmu-crt0k.S +++ b/tt_metal/hw/toolchain/tmu-crt0k.S @@ -1,4 +1,4 @@ -.section .init +.section .start,"ax",@progbits .global _start .type _start, @function From 4e5fb94c7b98020ca074241ff140e84e767f0f04 Mon Sep 17 00:00:00 2001 From: Samarth Agarwal Date: Tue, 8 Oct 2024 12:38:43 -0400 Subject: [PATCH 06/58] #13493: Update test_bw_and_latency to support more cases (#13495) * #0: Initial changes to test_bw_and_latency * #0: Changed log message * #0: Added bash script to run test_bw_and_latency * #0: Fixed issues with test and script --- .../dispatch/kernels/bw_and_latency.cpp | 15 +++- .../dispatch/run_bw_and_latency.sh | 80 +++++++++++++++++ .../dispatch/test_bw_and_latency.cpp | 86 ++++++++++++++++--- 3 files changed, 167 insertions(+), 14 deletions(-) create mode 100755 tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/run_bw_and_latency.sh diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/kernels/bw_and_latency.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/kernels/bw_and_latency.cpp index 563a77d51b0..731ac407e0f 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/kernels/bw_and_latency.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/kernels/bw_and_latency.cpp @@ -8,28 +8,41 @@ void kernel_main() { #else uint32_t page_size = get_arg_val(0); #endif + cb_reserve_back(0, PAGE_COUNT); uint32_t cb_addr = get_write_ptr(0); for (int i = 0; i < ITERATIONS; i++) { uint32_t read_ptr = cb_addr; + uint32_t write_ptr = cb_addr; for (int j = 0; j < PAGE_COUNT; j++) { + #if DRAM_BANKED uint64_t noc_addr = get_dram_noc_addr(j, page_size, 0); #else uint64_t noc_addr = NOC_XY_ADDR(NOC_X(NOC_ADDR_X), NOC_Y(NOC_ADDR_Y), NOC_MEM_ADDR); #endif -#if READ_ONE_PACKET + +#if ISSUE_MCAST + uint64_t dst_noc_multicast_addr = + get_noc_multicast_addr(NOC_ADDR_X, NOC_ADDR_Y, MCAST_NOC_END_ADDR_X, MCAST_NOC_END_ADDR_Y, NOC_MEM_ADDR); + noc_async_write_multicast(write_ptr, dst_noc_multicast_addr, page_size, NUM_MCAST_DESTS); +#elif READ_ONE_PACKET noc_async_read_one_packet(noc_addr, read_ptr, page_size); #else noc_async_read(noc_addr, read_ptr, page_size); #endif + #if LATENCY noc_async_read_barrier(); + noc_async_write_barrier(); #endif + read_ptr += page_size; + write_ptr += page_size; } } #if !LATENCY noc_async_read_barrier(); + noc_async_write_barrier(); #endif } diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/run_bw_and_latency.sh b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/run_bw_and_latency.sh new file mode 100755 index 00000000000..90a4972019b --- /dev/null +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/run_bw_and_latency.sh @@ -0,0 +1,80 @@ +#!/bin/bash + +if [ "$ARCH_NAME" = "grayskull" ]; then + echo "Configured core range for grayskull" + max_x=11 + max_y=8 +elif [ "$ARCH_NAME" = "wormhole_b0" ]; then + echo "Configured core range for wormhole_b0" + max_x=7 + max_y=6 +elif [ "$ARCH_NAME" = "blackhole" ]; then + echo "Configured core range for blackhole" + max_x=12 + max_y=9 +else + echo "Unknown arch: $ARCH_NAME" + exit 1 +fi + +function get_half_way_away_core_x() { + half_way_away_core_x=$(( ($1 + (($max_x + 1) / 2)) % ($max_x + 1) )) + echo $half_way_away_core_x +} + +function get_half_way_away_core_y() { + half_way_away_core_y=$(( ($1 + (($max_y + 1) / 2)) % ($max_y + 1) )) + echo $half_way_away_core_y +} + +function read_from_half_way_away_core() { + half_way_away_core_x=$(get_half_way_away_core_x $1) + half_way_away_core_y=$(get_half_way_away_core_y $2) + echo "./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 2 -rx $1 -ry $2 -sx $half_way_away_core_x -sy $half_way_away_core_y" + ./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 2 -rx $1 -ry $2 -sx $half_way_away_core_x -sy $half_way_away_core_y +} + +function mcast_write_to_half_way_away_core() { + half_way_away_core_x=$(get_half_way_away_core_x $1) + half_way_away_core_y=$(get_half_way_away_core_y $2) + echo "./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $1 -ry $2 -sx $half_way_away_core_x -sy $half_way_away_core_y -tx $half_way_away_core_x -ty $half_way_away_core_y" + ./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $1 -ry $2 -sx $half_way_away_core_x -sy $half_way_away_core_y -tx $half_way_away_core_x -ty $half_way_away_core_y +} + +function mcast_write_to_adjacent_core() { + adj_core_y=$(($2 + 1)) + if [ $adj_core_y -gt $max_y ]; then + adj_core_y=$(($2 - 1)) + fi + echo "./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $1 -ry $2 -sx $1 -sy $adj_core_y -tx $1 -ty $adj_core_y" + ./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $1 -ry $2 -sx $1 -sy $adj_core_y -tx $1 -ty $adj_core_y +} + +function mcast_write_from_core_after_curr_core_to_half_way_away_core() { + half_way_away_core_x=$(get_half_way_away_core_x $1) + half_way_away_core_y=$(get_half_way_away_core_y $2) + mcast_start_y=$(($2 + 1)) + echo "./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $1 -ry $2 -sx $1 -sy $mcast_start_y -tx $half_way_away_core_x -ty $half_way_away_core_y" + ./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $1 -ry $2 -sx $1 -sy $mcast_start_y -tx $half_way_away_core_x -ty $half_way_away_core_y +} + +for ((x=0; x<=max_x; x++)); do + for ((y=0; y<=max_y; y++)); do + read_from_half_way_away_core $x $y + mcast_write_to_half_way_away_core $x $y + mcast_write_to_adjacent_core $x $y + mcast_write_from_core_after_curr_core_to_half_way_away_core $x $y + + if [ $y -eq 0 ]; then + mcast_start_y=$(($y + 1)) + echo "./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $x -ry $y -sx 0 -sy $mcast_start_y -tx $max_x -ty $max_y" + ./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $x -ry $y -sx 0 -sy $mcast_start_y -tx $max_x -ty $max_y + fi + + if [ $y -eq $max_y ]; then + mcast_end_y=$(($y - 1)) + echo "./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $x -ry $y -sx 0 -sy 0 -tx $max_x -ty $mcast_end_y" + ./build/test/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency -m 6 -rx $x -ry $y -sx 0 -sy 0 -tx $max_x -ty $mcast_end_y + fi + done +done diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency.cpp index e442ae3f91a..e8999324dc3 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_bw_and_latency.cpp @@ -3,9 +3,13 @@ // SPDX-License-Identifier: Apache-2.0 #include +#include #include #include +#include +#include "core_coord.h" +#include "logger.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/detail/tt_metal.hpp" #include "tt_metal/llrt/rtoptions.hpp" @@ -22,7 +26,7 @@ constexpr uint32_t DEFAULT_BATCH_SIZE_K = 512; ////////////////////////////////////////////////////////////////////////////////////////// // Test dispatch program performance // -// Test read bw and latency from host/dram/l1 +// Test read/write bw and latency from host/dram/l1 ////////////////////////////////////////////////////////////////////////////////////////// using namespace tt; @@ -30,6 +34,7 @@ uint32_t iterations_g = DEFAULT_ITERATIONS; uint32_t warmup_iterations_g = DEFAULT_WARMUP_ITERATIONS; CoreRange worker_g = {{0, 0}, {0, 0}}; CoreCoord src_worker_g = {0, 0}; +CoreRange mcast_src_workers_g = {{0, 0}, {0, 0}}; uint32_t page_size_g; uint32_t page_count_g; uint32_t source_mem_g; @@ -53,15 +58,17 @@ void init(int argc, char **argv) { log_info(LogTest, " -i: iterations (default {})", DEFAULT_ITERATIONS); log_info(LogTest, " -bs: batch size in K of data to xfer in one iteration (default {}K)", DEFAULT_BATCH_SIZE_K); log_info(LogTest, " -p: page size (default {})", DEFAULT_PAGE_SIZE); - log_info(LogTest, " -m: source mem, 0:PCIe, 1:DRAM, 2:L1, 3:ALL_DRAMs, 4:HOST_READ, 5:HOST_WRITE (default 0:PCIe)"); + log_info(LogTest, " -m: source mem, 0:PCIe, 1:DRAM, 2:L1, 3:ALL_DRAMs, 4:HOST_READ, 5:HOST_WRITE, 6:MULTICAST_WRITE (default 0:PCIe)"); log_info(LogTest, " -l: measure latency (default is bandwidth)"); - log_info(LogTest, " -rx: X of core to issue read (default {})", 1); - log_info(LogTest, " -ry: Y of core to issue read (default {})", 0); + log_info(LogTest, " -rx: X of core to issue read or write (default {})", 1); + log_info(LogTest, " -ry: Y of core to issue read or write (default {})", 0); + log_info(LogTest, " -sx: when reading from L1, X of core to read from. when issuing a multicast write, X of start core to write to. (default {})", 0); + log_info(LogTest, " -sy: when reading from L1, Y of core to read from. when issuing a multicast write, Y of start core to write to. (default {})", 0); + log_info(LogTest, " -tx: when issuing a multicast write, X of end core to write to (default {})", 0); + log_info(LogTest, " -ty: when issuing a multicast write, Y of end core to write to (default {})", 0); log_info(LogTest, " -c: when reading from dram, DRAM channel (default 0)"); - log_info(LogTest, " -sx: when reading from L1, X of core to read from (default {})", 0); - log_info(LogTest, " -sy: when reading from L1, Y of core to read (default {})", 0); log_info(LogTest, " -f: time just the finish call (use w/ lazy mode) (default disabled)"); - log_info(LogTest, " -o: use read_one_packet API. restrices page size to 8K max (default {})", 0); + log_info(LogTest, " -o: use read_one_packet API. restricts page size to 8K max (default {})", 0); log_info(LogTest, " -z: enable dispatch lazy mode (default disabled)"); log_info(LogTest, " -hr: hammer write_reg while executing (for PCIe test)"); log_info(LogTest, " -hp: hammer hugepage PCIe memory while executing (for PCIe test)"); @@ -80,9 +87,11 @@ void init(int argc, char **argv) { hammer_pcie_type_g = test_args::get_command_option_uint32(input_args, "-hpt", 0); time_just_finish_g = test_args::has_command_option(input_args, "-f"); source_mem_g = test_args::get_command_option_uint32(input_args, "-m", 0); - dram_channel_g = test_args::get_command_option_uint32(input_args, "-c", 0); uint32_t src_core_x = test_args::get_command_option_uint32(input_args, "-sx", 0); uint32_t src_core_y = test_args::get_command_option_uint32(input_args, "-sy", 0); + uint32_t mcast_end_core_x = test_args::get_command_option_uint32(input_args, "-tx", 0); + uint32_t mcast_end_core_y = test_args::get_command_option_uint32(input_args, "-ty", 0); + dram_channel_g = test_args::get_command_option_uint32(input_args, "-c", 0); uint32_t size_bytes = test_args::get_command_option_uint32(input_args, "-bs", DEFAULT_BATCH_SIZE_K) * 1024; latency_g = test_args::has_command_option(input_args, "-l"); page_size_g = test_args::get_command_option_uint32(input_args, "-p", DEFAULT_PAGE_SIZE); @@ -96,6 +105,25 @@ void init(int argc, char **argv) { worker_g = CoreRange({core_x, core_y}, {core_x, core_y}); src_worker_g = {src_core_x, src_core_y}; + + if (source_mem_g == 6) + { + if (mcast_end_core_x < src_core_x || mcast_end_core_y < src_core_y) + { + log_info(LogTest, "X of end core must be >= X of start core, Y of end core must be >= Y of start core"); + exit(-1); + } + + mcast_src_workers_g = CoreRange({src_core_x, src_core_y}, {mcast_end_core_x, mcast_end_core_y}); + + if (mcast_src_workers_g.intersects(worker_g)) { + log_info( + LogTest, + "Multicast destination rectangle and core that issues the multicast cannot overlap - Multicast " + "destination rectangle: {} Master core: {}", mcast_src_workers_g.str(), worker_g.start_coord.str()); + exit(-1); + } + } } #define CACHE_LINE_SIZE 64 @@ -136,6 +164,10 @@ int main(int argc, char **argv) { uint32_t noc_addr_x, noc_addr_y; uint64_t noc_mem_addr = 0; uint32_t dram_banked = 0; + uint32_t issue_mcast = 0; + uint32_t num_mcast_dests = mcast_src_workers_g.size(); + uint32_t mcast_noc_addr_end_x = 0; + uint32_t mcast_noc_addr_end_y = 0; chip_id_t mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device->id()); uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device->id()); @@ -202,6 +234,18 @@ int main(int argc, char **argv) { noc_addr_y = w.y; } break; + case 6: + { + src_mem = "FROM_L1_TO_MCAST"; + issue_mcast = 1; + CoreCoord start = device->physical_core_from_logical_core(mcast_src_workers_g.start_coord, CoreType::WORKER); + CoreCoord end = device->physical_core_from_logical_core(mcast_src_workers_g.end_coord, CoreType::WORKER); + noc_addr_x = start.x; + noc_addr_y = start.y; + mcast_noc_addr_end_x = end.x; + mcast_noc_addr_end_y = end.y; + } + break; } std::map defines = { @@ -212,7 +256,11 @@ int main(int argc, char **argv) { {"NOC_ADDR_Y", std::to_string(noc_addr_y)}, {"NOC_MEM_ADDR", std::to_string(noc_mem_addr)}, {"READ_ONE_PACKET", std::to_string(read_one_packet_g)}, - {"DRAM_BANKED", std::to_string(dram_banked)} + {"DRAM_BANKED", std::to_string(dram_banked)}, + {"ISSUE_MCAST", std::to_string(issue_mcast)}, + {"NUM_MCAST_DESTS", std::to_string(num_mcast_dests)}, + {"MCAST_NOC_END_ADDR_X", std::to_string(mcast_noc_addr_end_x)}, + {"MCAST_NOC_END_ADDR_Y", std::to_string(mcast_noc_addr_end_y)} }; if (!page_size_as_runtime_arg_g) { defines.insert(pair("PAGE_SIZE", std::to_string(page_size_g))); @@ -243,11 +291,23 @@ int main(int argc, char **argv) { log_info(LogTest, "Reading: {} - core ({}, {})", src_mem, w.x, w.y); } else if (source_mem_g == 5) { log_info(LogTest, "Writing: {} - core ({}, {})", src_mem, w.x, w.y); + } else if (source_mem_g == 6) { + log_info(LogTest, "Writing: {} - core grid [({}, {}) - ({}, {})]", src_mem, noc_addr_x, noc_addr_y, mcast_noc_addr_end_x, mcast_noc_addr_end_y); } else { log_info(LogTest, "Reading: {} - core ({}, {})", src_mem, noc_addr_x, noc_addr_y); } - if (source_mem_g != 4) { - log_info(LogTest, "Using API: {}", read_one_packet_g ? "noc_async_read_one_packet" : "noc_async_read"); + if (source_mem_g < 4 || source_mem_g == 6) { + std::string api; + if (issue_mcast) { + api = "noc_async_write_multicast"; + } + else if (read_one_packet_g) { + api = "noc_async_read_one_packet"; + } + else { + api = "noc_async_read"; + } + log_info(LogTest, "Using API: {}", api); log_info(LogTest, "Lazy: {}", lazy_g); log_info(LogTest, "Page size ({}): {}", page_size_as_runtime_arg_g ? "runtime arg" : "compile time define", page_size_g); log_info(LogTest, "Size per iteration: {}", page_count_g * page_size_g); @@ -259,7 +319,7 @@ int main(int argc, char **argv) { vectorblank(page_size_g / sizeof(uint32_t)); std::chrono::duration elapsed_seconds; - if (source_mem_g < 4) { + if (source_mem_g < 4 || source_mem_g == 6) { // Cache stuff for (int i = 0; i < warmup_iterations_g; i++) { EnqueueProgram(cq, program, false); @@ -313,7 +373,7 @@ int main(int argc, char **argv) { Finish(cq); auto end = std::chrono::system_clock::now(); elapsed_seconds = (end-start); - } else { + } else if (source_mem_g == 4 || source_mem_g == 5) { vector vec; vec.resize(page_size_g / sizeof(uint32_t)); From 2a0724ca7ea7bcc4bb7d0325e5813b94b9f3e2b8 Mon Sep 17 00:00:00 2001 From: Bryan Wilder Field Lozano Date: Tue, 8 Oct 2024 12:52:10 -0400 Subject: [PATCH 07/58] #13352: Update build_with_profiler_opt.sh (#13590) --- scripts/build_scripts/build_with_profiler_opt.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/build_scripts/build_with_profiler_opt.sh b/scripts/build_scripts/build_with_profiler_opt.sh index 1c6135e28ad..83ab526d393 100755 --- a/scripts/build_scripts/build_with_profiler_opt.sh +++ b/scripts/build_scripts/build_with_profiler_opt.sh @@ -11,7 +11,7 @@ if [[ -z "$ARCH_NAME" ]]; then exit 1 fi -cmake -B build -G Ninja -DENABLE_TRACY=ON -DTT_METAL_BUILD_TESTS=ON -DTTNN_BUILD_TESTS=ON +cmake -B build -G Ninja -DENABLE_TRACY=ON -DTT_METAL_BUILD_TESTS=ON -DTTNN_BUILD_TESTS=ON -DBUILD_PROGRAMMING_EXAMPLES=ON if [[ $1 == "NO_CLEAN" ]]; then cmake --build build From 9a61bb11f2bbb6a8f16e7c5dff92e26292e0c1e4 Mon Sep 17 00:00:00 2001 From: Andrew Fuller Date: Tue, 8 Oct 2024 09:12:43 -0700 Subject: [PATCH 08/58] Error out with a clear message on a missing submodule. Single submodule is arbitrarily chosen. The typical case is that the developer missed cloning any submodules at all. We're drawing attention to the fact that this repo requires submodules. --- CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0ca911f87ae..521ec920bd6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,6 +1,11 @@ cmake_minimum_required(VERSION 3.16) cmake_policy(VERSION 3.16) +# Sanity check, forgetting to clone submodules is a common omission and results in a poor error message +if (NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/tt_metal/third_party/umd/CMakeLists.txt") + message(FATAL_ERROR "Missing submodules. Run: git submodule update --init --recursive") +endif() + ############################################ # Project setup ############################################ From 3d0acdd2a7f0d1f0b0cb207dec8f45b582a6fd2f Mon Sep 17 00:00:00 2001 From: Joseph Chu <122298491+cfjchu@users.noreply.github.com> Date: Tue, 8 Oct 2024 12:24:05 -0700 Subject: [PATCH 09/58] #13454: API refactor for Mesh::enable_program_cache (#13505) --- models/demos/t3000/falcon40b/demo/demo.py | 3 +-- .../tests/test_falcon_prefill_determinism.py | 9 +++------ .../llama2_70b/tests/test_llama_generation.py | 4 +--- .../t3000/llama2_70b/tests/test_llama_perf.py | 4 +--- .../llama2_70b/tests/test_llama_stress_test.py | 4 +--- .../demos/tg/llama3_70b/tests/test_llama_perf.py | 4 +--- .../tests/multi_chip/test_falcon_causallm.py | 3 +-- .../sweep_framework/sweeps/ccl/all_gather_n300.py | 3 --- tests/ttnn/unit_tests/test_multi_device_async.py | 10 ++++------ tests/ttnn/unit_tests/test_multi_device_events.py | 3 +-- tests/ttnn/unit_tests/test_multi_device_trace.py | 6 ++---- .../ttnn/unit_tests/test_multi_device_trace_TG.py | 6 ++---- .../ttnn/unit_tests/test_multi_device_trace_tgg.py | 6 ++---- tt_metal/distributed/mesh_device.cpp | 12 ++++++++++++ tt_metal/distributed/mesh_device.hpp | 2 ++ ttnn/cpp/ttnn/distributed/distributed_pybind.cpp | 14 +++++++++++++- 16 files changed, 47 insertions(+), 46 deletions(-) diff --git a/models/demos/t3000/falcon40b/demo/demo.py b/models/demos/t3000/falcon40b/demo/demo.py index e04e903ed34..9e3451dde37 100644 --- a/models/demos/t3000/falcon40b/demo/demo.py +++ b/models/demos/t3000/falcon40b/demo/demo.py @@ -527,8 +527,7 @@ def run_falcon_demo_kv( if not perf_mode: print_output_prompts(generated_ids, tokenizer) - for device in devices: - device.disable_and_clear_program_cache() + mesh_device.disable_and_clear_program_cache() generated_text = tokenizer.batch_decode(generated_ids.tolist()) diff --git a/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py b/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py index ba2d857dc5d..993437e0c65 100644 --- a/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py +++ b/models/demos/t3000/falcon40b/tests/test_falcon_prefill_determinism.py @@ -276,8 +276,7 @@ def test_falcon_prefill_end_to_end_determinism( input_shape = [batch, seq_len] model_config = get_model_config(model_config_str, "prefill", input_shape, num_devices) - devices = t3k_mesh_device.get_devices() - compute_grid_size = devices[0].compute_with_storage_grid_size() + compute_grid_size = t3k_mesh_device.compute_with_storage_grid_size() if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]: pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") @@ -286,8 +285,7 @@ def test_falcon_prefill_end_to_end_determinism( ) if enable_program_cache: - for device in devices: - device.enable_program_cache() + t3k_mesh_device.enable_program_cache() run_test_falcon_prefill_end_to_end_determinism( t3k_mesh_device, @@ -304,5 +302,4 @@ def test_falcon_prefill_end_to_end_determinism( ) if enable_program_cache: - for device in devices: - device.disable_and_clear_program_cache() + t3k_mesh_device.disable_and_clear_program_cache() diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py index b5b2286aa81..f5af555dc39 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_generation.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_generation.py @@ -152,9 +152,7 @@ def test_LlamaModel_inference( pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") t3k_mesh_device.enable_async(True) - for device_id in t3k_mesh_device.get_device_ids(): - device = t3k_mesh_device.get_device(device_id) - device.enable_program_cache() + t3k_mesh_device.enable_program_cache() args = construct_arg( implementation=implementation, diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py index 131f8abf965..de8fab5b8c2 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_perf.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf.py @@ -315,9 +315,7 @@ def test_Llama_perf_host( pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") t3k_mesh_device.enable_async(True) - for i in t3k_mesh_device.get_device_ids(): - device = t3k_mesh_device.get_device(i) - device.enable_program_cache() + t3k_mesh_device.enable_program_cache() disable_compilation_reports() run_test_LlamaModel_end_to_end( diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py b/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py index 621fcd2b3a7..a5b9edc7f81 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_stress_test.py @@ -147,9 +147,7 @@ def test_Llama_stress_test( if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]: pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run") - for i in t3k_mesh_device.get_device_ids(): - device = t3k_mesh_device.get_device(i) - device.enable_program_cache() + t3k_mesh_device.enable_program_cache() disable_compilation_reports() run_test_LlamaModel_stress_test( devices, diff --git a/models/demos/tg/llama3_70b/tests/test_llama_perf.py b/models/demos/tg/llama3_70b/tests/test_llama_perf.py index ce9a16095ce..3190abc90d3 100644 --- a/models/demos/tg/llama3_70b/tests/test_llama_perf.py +++ b/models/demos/tg/llama3_70b/tests/test_llama_perf.py @@ -197,9 +197,7 @@ def test_Llama_perf_host( check_mesh_device(mesh_device, model_config) mesh_device.enable_async(True) - - for device in mesh_device.get_devices(): - device.enable_program_cache() + mesh_device.enable_program_cache() disable_compilation_reports() run_test_LlamaModel_end_to_end( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py index c52d9d4fc28..1de4f9a058c 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_causallm.py @@ -296,8 +296,7 @@ def test_t3k_falcon_causal_lm_with_trace( num_loops, ): t3k_mesh_device.enable_async(enable_async) - for device in t3k_mesh_device.get_device_ids(): - t3k_mesh_device.get_device(device).enable_program_cache() + t3k_mesh_device.enable_program_cache() torch.manual_seed(0) batch = device_batch_size * t3k_mesh_device.get_num_devices() diff --git a/tests/sweep_framework/sweeps/ccl/all_gather_n300.py b/tests/sweep_framework/sweeps/ccl/all_gather_n300.py index 963e7f48a7c..bd56a03a62b 100644 --- a/tests/sweep_framework/sweeps/ccl/all_gather_n300.py +++ b/tests/sweep_framework/sweeps/ccl/all_gather_n300.py @@ -105,9 +105,6 @@ def run( logger.info(f"Input shape: {input_shape}") logger.info(f"dim: {dim}") - # for device in devices: - # device.disable_and_clear_program_cache() - input_tensor = torch.rand(input_shape).bfloat16() input_tensors = torch.chunk(input_tensor, num_devices, dim) diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py index 62be5ba5b63..5a8890c497e 100644 --- a/tests/ttnn/unit_tests/test_multi_device_async.py +++ b/tests/ttnn/unit_tests/test_multi_device_async.py @@ -149,9 +149,8 @@ def test_multi_device_unary_binary_op_chain(pcie_mesh_device, program_cache, sha from ttnn import ShardTensorToMesh, ConcatMeshToTensor pcie_mesh_device.enable_async(True) - for device in pcie_mesh_device.get_device_ids(): - if program_cache: - pcie_mesh_device.get_device(device).enable_program_cache() + if program_cache: + pcie_mesh_device.enable_program_cache() torch_silu = torch.nn.SiLU() for i in range(50): @@ -190,9 +189,8 @@ def test_multi_device_data_parallel_op_chain(pcie_mesh_device, program_cache, in from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh pcie_mesh_device.enable_async(True) - for device in pcie_mesh_device.get_device_ids(): - if program_cache: - pcie_mesh_device.get_device(device).enable_program_cache() + if program_cache: + pcie_mesh_device.enable_program_cache() torch_silu = torch.nn.SiLU() torch_mish = torch.nn.Mish() diff --git a/tests/ttnn/unit_tests/test_multi_device_events.py b/tests/ttnn/unit_tests/test_multi_device_events.py index c83d5c693a2..0217fe9f33f 100644 --- a/tests/ttnn/unit_tests/test_multi_device_events.py +++ b/tests/ttnn/unit_tests/test_multi_device_events.py @@ -21,8 +21,7 @@ def test_multi_device_events(t3k_mesh_device, shape): # Enable Program Cache and Async Mode t3k_mesh_device.enable_async(True) - for device_id in t3k_mesh_device.get_device_ids(): - t3k_mesh_device.get_device(device_id).enable_program_cache() + t3k_mesh_device.enable_program_cache() # Preallocate activation tensors. input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, t3k_mesh_device) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py index 1fc07590d90..2e81db7b248 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace.py @@ -28,8 +28,7 @@ def test_multi_device_single_trace(t3k_mesh_device, shape, use_all_gather, enabl # Trace requires program cache to be enabled t3k_mesh_device.enable_async(enable_async) - for device_id in t3k_mesh_device.get_device_ids(): - t3k_mesh_device.get_device(device_id).enable_program_cache() + t3k_mesh_device.enable_program_cache() # Preallocate activation tensors. These will be used when capturing and executing the trace input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, t3k_mesh_device) @@ -142,8 +141,7 @@ def test_multi_device_multi_trace(t3k_mesh_device, shape, use_all_gather, enable # Trace requires program cache to be enabled t3k_mesh_device.enable_async(enable_async) - for device_id in t3k_mesh_device.get_device_ids(): - t3k_mesh_device.get_device(device_id).enable_program_cache() + t3k_mesh_device.enable_program_cache() # Preallocate activation tensors. These will be used when capturing and executing the trace input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, t3k_mesh_device) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py index 7836c9de402..60c5f57d613 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace_TG.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace_TG.py @@ -27,8 +27,7 @@ def test_multi_device_single_trace(mesh_device, shape, enable_async, enable_mult pytest.skip("Test is only valid on Galaxy") # Trace requires program cache to be enabled mesh_device.enable_async(True) - for device_id in mesh_device.get_device_ids(): - mesh_device.get_device(device_id).enable_program_cache() + mesh_device.enable_program_cache() # Preallocate activation tensors. These will be used when capturing and executing the trace input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, mesh_device) @@ -129,8 +128,7 @@ def test_multi_device_multi_trace(mesh_device, shape, enable_async, enable_multi # Trace requires program cache to be enabled mesh_device.enable_async(True) - for device_id in mesh_device.get_device_ids(): - mesh_device.get_device(device_id).enable_program_cache() + mesh_device.enable_program_cache() # Preallocate activation tensors. These will be used when capturing and executing the trace input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, mesh_device) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py index ddb354dc365..9eb27afe2e1 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace_tgg.py @@ -27,8 +27,7 @@ def test_multi_device_single_trace(mesh_device, shape, enable_async, enable_mult pytest.skip("Test is only valid on TGG") # Trace requires program cache to be enabled mesh_device.enable_async(True) - for device_id in mesh_device.get_device_ids(): - mesh_device.get_device(device_id).enable_program_cache() + mesh_device.enable_program_cache() # Preallocate activation tensors. These will be used when capturing and executing the trace input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, mesh_device) @@ -128,8 +127,7 @@ def test_multi_device_multi_trace(mesh_device, shape, enable_async, enable_multi # Trace requires program cache to be enabled mesh_device.enable_async(True) - for device_id in mesh_device.get_device_ids(): - mesh_device.get_device(device_id).enable_program_cache() + mesh_device.enable_program_cache() # Preallocate activation tensors. These will be used when capturing and executing the trace input_0_dev = ttnn.allocate_tensor_on_device(ttnn.Shape(shape), ttnn.bfloat16, ttnn.TILE_LAYOUT, mesh_device) diff --git a/tt_metal/distributed/mesh_device.cpp b/tt_metal/distributed/mesh_device.cpp index b97fb7b6cca..ec521f9ce40 100644 --- a/tt_metal/distributed/mesh_device.cpp +++ b/tt_metal/distributed/mesh_device.cpp @@ -408,6 +408,18 @@ void MeshDevice::enable_async(bool enable) { } } +void MeshDevice::enable_program_cache() { + for (auto device : this->devices) { + device->enable_program_cache(); + } +} + +void MeshDevice::disable_and_clear_program_cache() { + for (auto device : this->devices) { + device->disable_and_clear_program_cache(); + } +} + std::vector get_t3k_physical_device_ids_ring() { auto& instance = SystemMesh::instance(); auto num_devices = instance.get_num_devices(); diff --git a/tt_metal/distributed/mesh_device.hpp b/tt_metal/distributed/mesh_device.hpp index 6589b3d0fce..7237c8c0158 100644 --- a/tt_metal/distributed/mesh_device.hpp +++ b/tt_metal/distributed/mesh_device.hpp @@ -153,6 +153,8 @@ class MeshDevice : public std::enable_shared_from_this { tt::ARCH arch() const; void enable_async(bool enable); + void enable_program_cache(); + void disable_and_clear_program_cache(); void close_devices(); std::shared_ptr get_view() const; diff --git a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp index 83e71b0a01b..53caf4276ca 100644 --- a/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp +++ b/ttnn/cpp/ttnn/distributed/distributed_pybind.cpp @@ -111,7 +111,19 @@ void py_module(py::module& module) { Args: enable (bool): True to enable async mode, False to disable it. - )doc") + )doc") + .def( + "enable_program_cache", + &MeshDevice::enable_program_cache, + R"doc( + Enable program cache across all devices in the mesh. + )doc") + .def( + "disable_and_clear_program_cache", + &MeshDevice::disable_and_clear_program_cache, + R"doc( + Disable program cache across all devices in the mesh. + )doc") .def_property_readonly("shape", &MeshDevice::shape, R"doc( Get the shape of the device mesh. From f0c6b2da31ff156733ed6871ee95326723c2bfb9 Mon Sep 17 00:00:00 2001 From: Brian Liu Date: Mon, 7 Oct 2024 17:11:55 +0000 Subject: [PATCH 10/58] #13127: Switch to ttnn::SimpleShape for create_device_tensor - create_device_tensor now takes logical and padded ttnn::SimpleShape * For now, create_device_tensor is also overloaded to take ttnn::Shape * Eventually, create_device_tensor should only take logical shape - Add tensor constructor for ttnn::SimpleShape where padding isn't used * Eventually remove once ttnn::SimpleShape and ttnn::Shape are the same - Switch to ttnn::SimpleShape for compute_strides - Switch to ttnn::SimpleShape for compute_buffer_size - Switch compute_volume calls to ttnn::SimpleShape .volume() where possible * Eventually we can remove compute_volume for LegacyShape once shapes are all ttnn::SimpleShape - Switch tensor strides() to return ttnn::SimpleShape - Add check for rank for to_array_4D() for LegacyShape - Update ttnn unit tests to use ttnn::SimpleShape where possible --- .../gtests/tensor/test_create_tensor.cpp | 5 +-- .../unit_tests/gtests/test_async_runtime.cpp | 8 ++-- .../unit_tests/gtests/test_ccl_on_galaxy.cpp | 4 +- .../gtests/test_multi_cq_multi_dev.cpp | 6 +-- .../gtests/test_multiprod_queue.cpp | 4 +- .../gtests/test_repeat_interleave.cpp | 2 +- ttnn/cpp/ttnn/async_runtime.cpp | 12 +++--- ttnn/cpp/ttnn/async_runtime.hpp | 2 +- ttnn/cpp/ttnn/operations/numpy/functions.hpp | 8 ++-- ttnn/cpp/ttnn/tensor/tensor.cpp | 39 +++++++++++-------- ttnn/cpp/ttnn/tensor/tensor.hpp | 13 ++++--- ttnn/cpp/ttnn/tensor/tensor_impl.cpp | 24 ++++++------ ttnn/cpp/ttnn/tensor/tensor_impl.hpp | 8 ++-- ttnn/cpp/ttnn/tensor/tensor_ops.cpp | 2 +- ttnn/cpp/ttnn/tensor/tensor_utils.cpp | 28 ++++++------- ttnn/cpp/ttnn/tensor/tensor_utils.hpp | 12 +++--- ttnn/cpp/ttnn/tensor/types.hpp | 1 + 17 files changed, 93 insertions(+), 85 deletions(-) diff --git a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp index 88bb90b22c6..654e9bd5d54 100644 --- a/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp +++ b/tests/ttnn/unit_tests/gtests/tensor/test_create_tensor.cpp @@ -32,12 +32,11 @@ void run_create_tensor_test(tt::tt_metal::Device* device, ttnn::SimpleShape inpu host_data[i] = 1; } - ttnn::Shape shape(input_shape.as_vector()); - auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, shape, dtype, Layout::TILE, mem_cfg); + auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, input_shape, dtype, Layout::TILE, mem_cfg); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; - Tensor input_tensor = Tensor(input_storage, shape, dtype, Layout::TILE); + Tensor input_tensor = Tensor(input_storage, input_shape, dtype, Layout::TILE); tt::log_debug("input_data: \n {}", input_tensor.write_to_string()); ttnn::write_buffer(io_cq, input_tensor, {host_data}); diff --git a/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp b/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp index 3b6daa1f0cc..5734bdc8924 100644 --- a/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp +++ b/tests/ttnn/unit_tests/gtests/test_async_runtime.cpp @@ -54,8 +54,8 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncPreallocatedOutputs) { auto workload_event = std::make_shared(); // Running sum-reduce with preallocated output // Preallocate Input and Output Tensors on Device - auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, input_shape, DataType::BFLOAT16, Layout::TILE, mem_cfg); - auto output_buffer = ttnn::allocate_buffer_on_device(output_buf_size_datums * datum_size_bytes, device, np_out.get_shape(), DataType::BFLOAT16, Layout::TILE, mem_cfg); + auto input_buffer = ttnn::allocate_buffer_on_device(input_buf_size_datums * datum_size_bytes, device, input_shape.padded_shape(), DataType::BFLOAT16, Layout::TILE, mem_cfg); + auto output_buffer = ttnn::allocate_buffer_on_device(output_buf_size_datums * datum_size_bytes, device, np_out.get_padded_shape(), DataType::BFLOAT16, Layout::TILE, mem_cfg); auto input_storage = tt::tt_metal::DeviceStorage{input_buffer}; auto output_storage = tt::tt_metal::DeviceStorage{output_buffer}; Tensor input_tensor = Tensor(input_storage, input_shape, DataType::BFLOAT16, Layout::TILE); @@ -105,7 +105,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeAllocatedBuffers) { std::vector inputs = {4, 9, 16, 25, 36, 64}; uint32_t io_cq = 1; uint32_t workload_dispatch_cq = 0; - ttnn::Shape shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, 1024, 1024})); + ttnn::SimpleShape shape{1, 1, 1024, 1024}; auto host_data = std::shared_ptr(new bfloat16[buf_size_datums]); auto readback_data = std::shared_ptr(new bfloat16[buf_size_datums]); @@ -158,7 +158,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestAsyncRuntimeBufferDestructor) { uint32_t buf_size_datums = 1024 * 1024; uint32_t datum_size_bytes = 2; - ttnn::Shape shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, 1024, 1024})); + ttnn::SimpleShape shape{1, 1, 1024, 1024}; // Inside the loop, initialize a buffer with limited lifetime. // This will asynchronously allocate the buffer, wait for the allocation to complete (address to be assigned to the buffer), destroy the buffer (which will asynchronously // deallocate the buffer) in a loop diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index 18e635aa4f1..027537ae3a5 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -117,7 +117,7 @@ TEST(GalaxyTests, TestAllGatherDeadlock) { .memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::DRAM, .shard_spec = std::nullopt}; - ttnn::Shape shape = ttnn::Shape(LegacyShape({1, 1, 32, 16384})); + ttnn::SimpleShape shape{1, 1, 32, 16384}; const uint32_t buf_size_datums = 32 * 16384; const uint32_t datum_size_bytes = 2; auto host_data = std::shared_ptr(new bfloat16[buf_size_datums]); @@ -210,7 +210,7 @@ TEST(GalaxyTests, TestReduceScatterDeadlock) { .memory_layout = TensorMemoryLayout::INTERLEAVED, .buffer_type = BufferType::DRAM, .shard_spec = std::nullopt}; - ttnn::Shape shape = ttnn::Shape(LegacyShape({1, 2, 256, static_cast(256 * ring_devices.size())})); + ttnn::SimpleShape shape{1, 2, 256, static_cast(256 * ring_devices.size())}; const uint32_t buf_size_datums = 2 * 256 * 256 * ring_devices.size(); const uint32_t datum_size_bytes = 2; // Output of reduce scatter is input_numel / num_devices_used_in_scatter_op diff --git a/tests/ttnn/unit_tests/gtests/test_multi_cq_multi_dev.cpp b/tests/ttnn/unit_tests/gtests/test_multi_cq_multi_dev.cpp index e06824058b6..52f4320fba0 100644 --- a/tests/ttnn/unit_tests/gtests/test_multi_cq_multi_dev.cpp +++ b/tests/ttnn/unit_tests/gtests/test_multi_cq_multi_dev.cpp @@ -44,7 +44,7 @@ TEST_F(MultiCommandQueueT3KFixture, Test2CQMultiDeviceProgramsOnCQ1) { .buffer_type = BufferType::DRAM, .shard_spec = std::nullopt}; - ttnn::Shape shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 3, 2048, 2048})); + ttnn::SimpleShape shape{1, 3, 2048, 2048}; uint32_t buf_size_datums = 2048 * 2048 * 3; uint32_t datum_size_bytes = 2; auto host_data = std::shared_ptr(new bfloat16[buf_size_datums]); @@ -94,7 +94,7 @@ TEST_F(MultiCommandQueueT3KFixture, Test2CQMultiDeviceProgramsOnCQ0) { .buffer_type = BufferType::DRAM, .shard_spec = std::nullopt}; - ttnn::Shape shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 3, 2048, 2048})); + ttnn::SimpleShape shape{1, 3, 2048, 2048}; uint32_t buf_size_datums = 2048 * 2048 * 3; uint32_t datum_size_bytes = 2; auto host_data = std::shared_ptr(new bfloat16[buf_size_datums]); @@ -145,7 +145,7 @@ TEST_F(MultiCommandQueueT3KFixture, Test2CQMultiDeviceWithCQ1Only) { .buffer_type = BufferType::DRAM, .shard_spec = std::nullopt}; - ttnn::Shape shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 3, 2048, 2048})); + ttnn::SimpleShape shape{1, 3, 2048, 2048}; uint32_t buf_size_datums = 2048 * 2048 * 3; uint32_t datum_size_bytes = 2; auto host_data = std::shared_ptr(new bfloat16[buf_size_datums]); diff --git a/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp b/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp index d6590d9a395..20e3350dc38 100644 --- a/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp +++ b/tests/ttnn/unit_tests/gtests/test_multiprod_queue.cpp @@ -39,7 +39,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestMultiProducerLockBasedQueue) { uint32_t tensor_buf_size = 1024 * 1024; uint32_t datum_size_bytes = 2; - ttnn::Shape tensor_shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, 1024, 1024})); + ttnn::SimpleShape tensor_shape{1, 1, 1024, 1024}; auto t0_host_data = std::shared_ptr(new bfloat16[tensor_buf_size]); auto t0_readback_data = std::shared_ptr(new bfloat16[tensor_buf_size]); auto t1_host_data = std::shared_ptr(new bfloat16[tensor_buf_size]); @@ -117,7 +117,7 @@ TEST_F(MultiCommandQueueSingleDeviceFixture, TestMultiAppThreadSync) { std::shared_ptr write_event = std::make_shared(); std::shared_ptr read_event = std::make_shared(); - ttnn::Shape tensor_shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, 1024, 1024})); + ttnn::SimpleShape tensor_shape{1, 1, 1024, 1024}; auto host_data = std::shared_ptr(new bfloat16[tensor_buf_size]); auto allocated_buffer = ttnn::allocate_buffer_on_device(tensor_buf_size * datum_size_bytes, device, tensor_shape, DataType::BFLOAT16, Layout::TILE, mem_cfg); auto allocated_storage = tt::tt_metal::DeviceStorage{allocated_buffer}; diff --git a/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp b/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp index 884ee2475e3..1dee81c29e0 100644 --- a/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp +++ b/tests/ttnn/unit_tests/gtests/test_repeat_interleave.cpp @@ -30,7 +30,7 @@ void run_repeat_interleave_test(tt::tt_metal::Device* device, const uint32_t rep const uint32_t input_buf_size_datums = 32 * 32; const uint32_t output_buf_size_datums = input_buf_size_datums * repeats; const uint32_t datum_size_bytes = 2; - ttnn::Shape input_shape = ttnn::Shape(tt::tt_metal::LegacyShape({1, 1, 32, 32})); + ttnn::SimpleShape input_shape{1, 1, 32, 32}; auto host_data = std::shared_ptr(new uint16_t[input_buf_size_datums]); auto readback_data = std::shared_ptr(new uint16_t[output_buf_size_datums]); diff --git a/ttnn/cpp/ttnn/async_runtime.cpp b/ttnn/cpp/ttnn/async_runtime.cpp index 2a8bc818df6..29845f66316 100644 --- a/ttnn/cpp/ttnn/async_runtime.cpp +++ b/ttnn/cpp/ttnn/async_runtime.cpp @@ -14,12 +14,12 @@ using queue_id = uint8_t; DeviceBuffer allocate_interleaved_buffer_on_device( size_t buffer_size_bytes, Device* device, - const Shape& shape, + const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, const std::optional& tile) { - uint32_t page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value, tile); + uint32_t page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape, tile); return std::make_shared(device, buffer_size_bytes, page_size, memory_config.buffer_type); } @@ -31,19 +31,19 @@ DeviceBuffer allocate_contiguous_buffer_on_device( DeviceBuffer allocate_sharded_buffer_on_device( size_t buffer_size_bytes, Device* device, - const Shape& shape, + const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const ShardSpecBuffer& shard_params, const MemoryConfig& memory_config, const std::optional& tile) { tt::tt_metal::tensor_impl::validate_sharded_buffer_allocation( - shape.value, layout, data_type, shard_params, memory_config, tile); + shape, layout, data_type, shard_params, memory_config, tile); const auto& page_shape = shard_params.page_shape; uint32_t size_of_element = tt::tt_metal::tensor_impl::element_size_bytes(data_type); uint32_t page_size = page_shape[0] * page_shape[1] * size_of_element; if (layout == Layout::TILE) { - page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape.value, tile); + page_size = tt::tt_metal::tensor_impl::get_page_size(data_type, layout, buffer_size_bytes, shape, tile); } return std::make_shared( @@ -53,7 +53,7 @@ DeviceBuffer allocate_sharded_buffer_on_device( DeviceBuffer allocate_buffer_on_device( size_t buffer_size_bytes, types::Device* device, - const Shape& shape, + const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, diff --git a/ttnn/cpp/ttnn/async_runtime.hpp b/ttnn/cpp/ttnn/async_runtime.hpp index 672a214f210..2be1296bc56 100644 --- a/ttnn/cpp/ttnn/async_runtime.hpp +++ b/ttnn/cpp/ttnn/async_runtime.hpp @@ -12,7 +12,7 @@ namespace ttnn { using DeviceBuffer = std::shared_ptr; using queue_id = uint8_t; - DeviceBuffer allocate_buffer_on_device(size_t buffer_size_bytes, types::Device* device, const Shape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, const std::optional& shard_spec = std::nullopt, const std::optional& tile = std::nullopt); + DeviceBuffer allocate_buffer_on_device(size_t buffer_size_bytes, types::Device* device, const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, const std::optional& shard_spec = std::nullopt, const std::optional& tile = std::nullopt); void write_buffer(queue_id cq_id, Tensor& dst, std::vector> src, const std::optional transfer_size = std::nullopt); diff --git a/ttnn/cpp/ttnn/operations/numpy/functions.hpp b/ttnn/cpp/ttnn/operations/numpy/functions.hpp index 0762f3a9f42..62c2dc058cc 100644 --- a/ttnn/cpp/ttnn/operations/numpy/functions.hpp +++ b/ttnn/cpp/ttnn/operations/numpy/functions.hpp @@ -242,7 +242,7 @@ static Tensor arange( owned_buffer[index++] = static_cast(value); } } - auto output = Tensor(OwnedStorage{owned_buffer}, {1, 1, 1, static_cast(size)}, data_type, layout); + auto output = Tensor(OwnedStorage{owned_buffer}, ttnn::SimpleShape{1, 1, 1, static_cast(size)}, data_type, layout); if (device != nullptr) { output = output.to(device, output_mem_config); } @@ -454,7 +454,7 @@ static Tensor fill_first_val_into_tensor( tt::tt_metal::tensor_impl::read_data_from_device_buffer(device_buffer, data_vec); } auto input_buffer = owned_buffer::create(std::move(data_vec)); - const tt::tt_metal::LegacyShape input_tensor_strides = input_tensor.strides(); + const ttnn::SimpleShape input_tensor_strides = input_tensor.strides(); for (uint32_t i = 0; i < physical_volume; i++) { owned_buffer[i] = input_buffer[0]; } @@ -488,7 +488,7 @@ static Tensor prod_result_computation_GS( tt::tt_metal::tensor_impl::read_data_from_device_buffer(device_buffer, data_vec); } auto input_buffer = owned_buffer::create(std::move(data_vec)); - const tt::tt_metal::LegacyShape input_tensor_strides = input_tensor.strides(); + const ttnn::SimpleShape input_tensor_strides = input_tensor.strides(); auto result = static_cast(1.0f); for (uint32_t i = s_a[0] - 1; i < s_a[0]; i++) { for (int32_t j = s_a[1] - 1; j < s_a[1]; j++) { @@ -537,7 +537,7 @@ static Tensor prod_result_computation_WH_B0( tt::tt_metal::tensor_impl::read_data_from_device_buffer(device_buffer, data_vec); } auto input_buffer = owned_buffer::create(std::move(data_vec)); - const tt::tt_metal::LegacyShape input_tensor_strides = input_tensor.strides(); + const ttnn::SimpleShape input_tensor_strides = input_tensor.strides(); auto result = static_cast(1.0f); // need to access the last 4 rows and alternating columns of index 17 ,19, 21, 23, 25, 27, 29, 31 for (uint32_t i = s_a[0] - 1; i < s_a[0]; i++) { diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index 95fe891ceb6..4f0fe5d95e2 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -93,7 +93,7 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L } else if constexpr (std::is_same_v) { TT_ASSERT(storage.buffer->device() != nullptr); workers = {storage.buffer->device()}; - tensor_impl::validate_on_device_dtype_and_layout(storage.buffer->device(), shape.value, dtype, layout); + tensor_impl::validate_on_device_dtype_and_layout(storage.buffer->device(), shape.padded_shape(), dtype, layout); // Increment main thread ref count for all tensors on device this->tensor_attributes->increment_main_thread_ref_count(this->workers.at(0)); // This tensor is being created from scratch in a worker. Track this and allow it to be explicitly @@ -111,7 +111,7 @@ Tensor::Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, L auto buffer = storage.get_buffer_for_device_id(device_id); TT_ASSERT(buffer->device() != nullptr); TT_ASSERT(buffer->device()->id() == device_id); - tensor_impl::validate_on_device_dtype_and_layout(buffer->device(), shape.value, dtype, layout); + tensor_impl::validate_on_device_dtype_and_layout(buffer->device(), shape.padded_shape(), dtype, layout); workers.push_back(buffer->device()); } // Increment main thread ref count for all tensors on cluster @@ -224,6 +224,8 @@ Tensor::~Tensor() { tensor_attributes.reset(); } +Tensor::Tensor(const Storage storage, const ttnn::SimpleShape& shape, DataType dtype, Layout layout, const std::optional& tile) : Tensor(storage, ttnn::Shape(shape.as_vector()), dtype, layout, tile) {} + void Tensor::deallocate(bool force) { ZoneScopedN("TensorDeallocate"); // GraphTracker::instance().track_function_start("Tensor::deallocate", *this, force); @@ -629,7 +631,7 @@ StorageType Tensor::storage_type() const { this->get_storage()); } -const tt::tt_metal::LegacyShape Tensor::strides() const { return tt::tt_metal::LegacyShape(tt::tt_metal::compute_strides(this->get_legacy_shape())); } +const ttnn::SimpleShape Tensor::strides() const { return ttnn::SimpleShape(tt::tt_metal::compute_strides(this->get_padded_shape())); } uint32_t Tensor::volume() const { return tt::tt_metal::compute_volume(this->get_legacy_shape()); } @@ -653,19 +655,20 @@ tt::tt_metal::Padding Tensor::get_padding() const { } Tensor create_device_tensor( - const tt::tt_metal::LegacyShape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional& tile) { + const ttnn::SimpleShape& logical_shape, const ttnn::SimpleShape& padded_shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional& tile) { ZoneScoped; - GraphTracker::instance().track_function_start("tt::tt_metal::create_device_tensor", shape, data_type, layout, device, memory_config); + GraphTracker::instance().track_function_start("tt::tt_metal::create_device_tensor", padded_shape, data_type, layout, device, memory_config); + if (memory_config.is_sharded()) { TT_ASSERT(memory_config.shard_spec.has_value()); auto& shard_spec = memory_config.shard_spec.value(); auto& shard_shape = shard_spec.shape; - auto width = shape[-1]; + auto width = padded_shape[-1]; auto other_dims = 1; - for (int i = 0; i < shape.rank() - 1; i++) { - other_dims *= shape[i]; + for (int i = 0; i < padded_shape.rank() - 1; i++) { + other_dims *= padded_shape[i]; } auto element_size = tensor_impl::element_size_bytes(data_type); @@ -673,26 +676,30 @@ Tensor create_device_tensor( std::array tensor2d_size = {other_dims / page_shape[0], width / page_shape[1]}; ShardSpecBuffer shard_spec_buffer(shard_spec, page_shape, tensor2d_size); size_t packed_size_in_bytes = - tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type)); + tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(padded_shape, data_type)); auto device_buffer = tensor_impl::allocate_buffer_on_device( - packed_size_in_bytes, device, shape, data_type, layout, memory_config, shard_spec_buffer, tile); + packed_size_in_bytes, device, padded_shape, data_type, layout, memory_config, shard_spec_buffer, tile); - auto output = Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, tile); + auto output = Tensor(DeviceStorage{device_buffer}, ttnn::Shape(logical_shape.as_vector(), padded_shape.as_vector()), data_type, layout, tile); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; } else { size_t packed_size_in_bytes = - tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(shape, data_type)); + tensor_impl::packed_buffer_size_bytes_wrapper(data_type, compute_buffer_size(padded_shape, data_type)); auto device_buffer = tensor_impl::allocate_buffer_on_device( - packed_size_in_bytes, device, shape, data_type, layout, memory_config, std::nullopt, tile); - auto output = Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, tile); + packed_size_in_bytes, device, padded_shape, data_type, layout, memory_config, std::nullopt, tile); + auto output = Tensor(DeviceStorage{device_buffer}, ttnn::Shape(logical_shape.as_vector(), padded_shape.as_vector()), data_type, layout, tile); output = tt::tt_metal::set_tensor_id(output); GraphTracker::instance().track_function_end(output); return output; } } +Tensor create_device_tensor( + const ttnn::Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional& tile) { + return create_device_tensor(shape.logical_shape(), shape.padded_shape(), data_type, layout, device, memory_config, tile); +} namespace detail { template @@ -818,7 +825,7 @@ Tensor allocate_tensor_on_device( Tensor device_tensor = Tensor({device}); uint32_t device_tensor_ref_count = device_tensor.tensor_attributes->record_main_thread_ref_count(); device->push_work([shape, data_type, layout, device, memory_config, tile, device_tensor]() mutable { - auto local_tensor = create_device_tensor(shape.value, data_type, layout, device, memory_config, tile); + auto local_tensor = create_device_tensor(shape, data_type, layout, device, memory_config, tile); device_tensor.populate_buffers_and_metadata(local_tensor); }); device_tensor.tensor_attributes->update_main_thread_ref_count(device, device_tensor_ref_count); @@ -842,7 +849,7 @@ Tensor allocate_tensor_on_device( for (int worker_index = 0; worker_index < num_workers; ++worker_index) { auto& worker = workers[worker_index]; worker->push_work([shape, data_type, layout, worker, memory_config, tile, device_tensor, worker_index]() mutable { - auto local_tensor = create_device_tensor(shape.value, data_type, layout, worker, memory_config, tile); + auto local_tensor = create_device_tensor(shape, data_type, layout, worker, memory_config, tile); insert_buffer_and_shape_for_device(worker, local_tensor, device_tensor, worker_index); uint32_t num_workers_completed = (device_tensor.tensor_attributes->num_workers_completed)++; diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index 311c3d60ef6..e23832be836 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -89,6 +89,7 @@ struct Tensor { deallocate_through_destructor(false) {} Tensor(const Storage storage, const ttnn::Shape shape, DataType dtype, Layout layout, const std::optional& tile = std::nullopt); + Tensor(const Storage storage, const ttnn::SimpleShape& shape, DataType dtype, Layout layout, const std::optional& tile = std::nullopt); // Constructor to initialize unpopulated tensor with workers and storage specified. Use this when creating tensor // handles in async mode. @@ -211,7 +212,7 @@ struct Tensor { // Extra Helper Functions // ====================================================================================== StorageType storage_type() const; - const tt::tt_metal::LegacyShape strides() const; + const ttnn::SimpleShape strides() const; uint32_t volume() const; // todo: rename volume to get_volume to indicate that its blocking @@ -293,22 +294,22 @@ struct Tensor { }; Tensor create_device_tensor( - const tt::tt_metal::LegacyShape &shape, + const ttnn::SimpleShape &logical_shape, + const ttnn::SimpleShape &padded_shape, DataType dtype, Layout layout, Device *device, const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, const std::optional& tile = std::nullopt); -static Tensor create_device_tensor( +// TODO: Remove once ALL ops switch over to return ttnn::SimpleShape in compute_output_shapes +Tensor create_device_tensor( const ttnn::Shape &shape, DataType dtype, Layout layout, Device *device, const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, - const std::optional& tile = std::nullopt) { - return create_device_tensor(shape.value, dtype, layout, device, memory_config, tile); -} + const std::optional& tile = std::nullopt); // template // void *get_host_buffer(const Tensor &tensor); diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp index 5dc5560a5d6..d129bc9dc97 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.cpp @@ -43,7 +43,7 @@ uint32_t element_size_bytes(DataType dtype) { } } -uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const tt::tt_metal::LegacyShape& shape, const std::optional& tile) { +uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const ttnn::SimpleShape& shape, const std::optional& tile) { uint32_t W = shape[-1]; uint32_t page_size = 0; const auto tile_HW = tile.has_value() ? tile->get_tile_hw() : constants::TILE_HW; @@ -104,7 +104,7 @@ std::array get_sharded_page_shape(Layout layout, DataType dtype, st } void validate_sharded_buffer_allocation( - const tt::tt_metal::LegacyShape& shape, + const ttnn::SimpleShape& shape, Layout layout, DataType data_type, const ShardSpecBuffer& shard_params, @@ -115,7 +115,7 @@ void validate_sharded_buffer_allocation( uint32_t num_cores = shard_spec.num_cores(); - uint32_t total_height = tt_metal::compute_volume(shape) / shape[-1]; + uint32_t total_height = shape.volume() / shape[-1]; uint32_t total_width = shape[-1]; if (memory_config.memory_layout == TensorMemoryLayout::HEIGHT_SHARDED) { TT_ASSERT( @@ -184,7 +184,7 @@ namespace detail { DeviceBuffer allocate_interleaved_buffer_on_device( size_t buffer_size_bytes, Device* device, - const tt::tt_metal::LegacyShape& shape, + const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, @@ -201,14 +201,14 @@ DeviceBuffer allocate_contiguous_buffer_on_device( DeviceBuffer allocate_sharded_buffer_on_device( size_t buffer_size_bytes, Device* device, - const tt::tt_metal::LegacyShape& shape, + const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const ShardSpecBuffer& shard_params, const MemoryConfig& memory_config, const std::optional& tile) { validate_sharded_buffer_allocation(shape, layout, data_type, shard_params, memory_config, tile); - const auto& page_shape = shard_params.page_shape; + const auto& page_shape = ttnn::SimpleShape(shard_params.page_shape); uint32_t page_size = get_page_size(data_type, layout, buffer_size_bytes, page_shape, tile); return std::make_shared( @@ -220,7 +220,7 @@ DeviceBuffer allocate_sharded_buffer_on_device( DeviceBuffer allocate_buffer_on_device( size_t buffer_size_bytes, Device* device, - const tt::tt_metal::LegacyShape& shape, + const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, @@ -238,7 +238,7 @@ DeviceBuffer allocate_buffer_on_device( } } -void validate_on_device_dtype_and_layout(Device* device, const tt::tt_metal::LegacyShape& shape, DataType dtype, Layout layout) { +void validate_on_device_dtype_and_layout(Device* device, const ttnn::SimpleShape& shape, DataType dtype, Layout layout) { // TODO: Get supported layout and dtypes from device auto supported_dtype = [&dtype]() { TT_ASSERT( @@ -799,7 +799,7 @@ template typename BufferType> DeviceBuffer initialize_data_on_device( BufferType& data_to_write, Device* device, - const tt::tt_metal::LegacyShape& shape, + const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, @@ -826,7 +826,7 @@ template DeviceBuffer to_device_buffer( const Storage& storage, Device* device, - const tt::tt_metal::LegacyShape& shape, + const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, @@ -884,7 +884,7 @@ Tensor to_device( TT_ASSERT(target_device != nullptr && "Need target device in order to move tensor to device!"); TT_ASSERT(tensor.is_allocated() && "Need data to exist in order to move it to device"); - auto shape = tensor.get_legacy_shape(); + auto shape = tensor.get_padded_shape(); auto data_type = tensor.get_dtype(); auto layout = tensor.get_layout(); auto tile = tensor.get_tile(); @@ -905,7 +905,7 @@ Tensor to_device( auto device_buffer = tensor_impl::to_device_buffer( tensor.get_storage(), target_device, shape, data_type, layout, memory_config, shard_spec_buffer_opt, tile, queue); - return Tensor(DeviceStorage{device_buffer}, shape, data_type, layout, tile); + return Tensor(DeviceStorage{device_buffer}, tensor.get_shape(), data_type, layout, tile); } template Tensor to_device( diff --git a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp index ab3b17f714b..92e7bf101e3 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_impl.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_impl.hpp @@ -234,9 +234,9 @@ inline std::vector convert_layout_tile_to_row_major(const tt::tt_metal::Legac // ====================================================================================== // Validators // ====================================================================================== -void validate_on_device_dtype_and_layout(Device* device, const tt::tt_metal::LegacyShape& shape, DataType dtype, Layout layout); +void validate_on_device_dtype_and_layout(Device* device, const ttnn::SimpleShape& shape, DataType dtype, Layout layout); void validate_sharded_buffer_allocation( - const tt::tt_metal::LegacyShape& shape, + const ttnn::SimpleShape& shape, Layout layout, DataType data_type, const ShardSpecBuffer& shard_params, @@ -252,12 +252,12 @@ void validate_sharded_buffer_allocation( // Data reader, writer, and initializers // ====================================================================================== -uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const tt::tt_metal::LegacyShape& shape, const std::optional& tile = std::nullopt); +uint32_t get_page_size(DataType dtype, Layout layout, uint32_t total_size_bytes, const ttnn::SimpleShape& shape, const std::optional& tile = std::nullopt); DeviceBuffer allocate_buffer_on_device( size_t buffer_size_bytes, Device* device, - const tt::tt_metal::LegacyShape& shape, + const ttnn::SimpleShape& shape, DataType data_type, Layout layout, const MemoryConfig& memory_config, diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp index 25e2cc60210..932e28087c2 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.cpp @@ -56,7 +56,7 @@ Tensor tensor_to(const Tensor& input_tensor, Device* target_device, const Memory } else { tensor_impl::validate_on_device_dtype_and_layout( target_device, - async_safe_tensor.get_legacy_shape(), + async_safe_tensor.get_padded_shape(), async_safe_tensor.get_dtype(), async_safe_tensor.get_layout()); auto local_tensor = diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp index 6a2c7574230..e6ea11fb1bc 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.cpp @@ -28,8 +28,8 @@ Tensor to_weight_special_padding_tile_layout( assert(in1_block_h_datums >= w_shape[1] * w_shape[3]); uint32_t block_height_padding = in1_block_h_datums - (w_shape[1] * w_shape[3]); auto weight_matrix_rows = ((w_shape[1] * w_shape[3]) + block_height_padding) * w_shape[2]; - tt::tt_metal::LegacyShape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; - auto output_buffer = owned_buffer::create(compute_volume(output_shape)); + ttnn::SimpleShape output_shape{1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = owned_buffer::create(output_shape.volume()); for (auto r = 0; r < w_shape[2]; r++) { for (auto s = 0; s < w_shape[3]; s++) { for (auto c = 0; c < w_shape[1]; c++) { @@ -112,8 +112,8 @@ Tensor to_weight_tile_layout( weight_matrix_rows = (uint32_t)std::ceil((double)weight_matrix_rows / (double)in1_block_h_datums) * in1_block_h_datums; } - tt::tt_metal::LegacyShape output_shape = {1, 1, weight_matrix_rows, weight_matrix_cols}; - auto output_buffer = owned_buffer::create(compute_volume(output_shape)); + ttnn::SimpleShape output_shape{1, 1, weight_matrix_rows, weight_matrix_cols}; + auto output_buffer = owned_buffer::create(output_shape.volume()); for (auto r = 0; r < w_shape[2]; r++) { for (auto s = 0; s < w_shape[3]; s++) { for (auto c = 0; c < w_shape[1]; c++) { @@ -243,11 +243,11 @@ Helper function to aid in converting grouped weight tensor to ungrouped weight t template static Tensor conv_group_weight_zero_pad_helper( Tensor& conv_weight_tensor, - tt::tt_metal::LegacyShape& original_weight_shape, - tt::tt_metal::LegacyShape& output_weight_shape, + const ttnn::SimpleShape& original_weight_shape, + const ttnn::SimpleShape& output_weight_shape, uint32_t num_groups, DataType output_dtype) { - owned_buffer::Buffer output_buffer = owned_buffer::create(compute_volume(output_weight_shape)); + owned_buffer::Buffer output_buffer = owned_buffer::create(output_weight_shape.volume()); auto conv_weight_tensor_buffer = borrowed_buffer::get_as(conv_weight_tensor); for (int curr_batch_idx = 0; curr_batch_idx < original_weight_shape[0]; curr_batch_idx++) { @@ -289,10 +289,10 @@ Helper function to aid in converting depthwise weight tensor to broadcasted weig template static Tensor conv_depthwise_weight_bcast_helper( Tensor& conv_weight_tensor, - tt::tt_metal::LegacyShape& original_weight_shape, - tt::tt_metal::LegacyShape& output_weight_shape, + const ttnn::SimpleShape& original_weight_shape, + const ttnn::SimpleShape& output_weight_shape, DataType output_dtype) { - owned_buffer::Buffer output_buffer = owned_buffer::create(compute_volume(output_weight_shape)); + owned_buffer::Buffer output_buffer = owned_buffer::create(output_weight_shape.volume()); auto conv_weight_tensor_buffer = borrowed_buffer::get_as(conv_weight_tensor); // Copy the original weight tensor to the output tensor for (int i = 0; i < output_weight_shape[0]; i++) { @@ -330,12 +330,12 @@ Tensor convert_conv_weight_tensor_to_grouped_layout( // Define output tensor shape. This is going to be channel dimension of weight tensor * num_groups - this value // should match number of input channels being convolved with the weight tensor auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape(); - tt::tt_metal::LegacyShape original_conv_weight_tensor_shape = { + ttnn::SimpleShape original_conv_weight_tensor_shape{ original_conv_weight_tensor_shape_test[0], original_conv_weight_tensor_shape_test[1], original_conv_weight_tensor_shape_test[2], original_conv_weight_tensor_shape_test[3]}; - tt::tt_metal::LegacyShape output_conv_weight_tensor_shape = { + ttnn::SimpleShape output_conv_weight_tensor_shape{ original_conv_weight_tensor_shape[0], original_conv_weight_tensor_shape[1] * num_groups, original_conv_weight_tensor_shape[2], @@ -402,12 +402,12 @@ Tensor convert_conv_weight_tensor_to_depthwise_layout( "Convolution weights should be in row major layout for repeating the required dimensions"); auto original_conv_weight_tensor_shape_test = conv_weight_tensor.get_shape(); uint32_t num_input_channels_to_repeat = act_block_h_ntiles * constants::TILE_HEIGHT; - tt::tt_metal::LegacyShape original_conv_weight_tensor_shape = { + ttnn::SimpleShape original_conv_weight_tensor_shape{ original_conv_weight_tensor_shape_test[0], original_conv_weight_tensor_shape_test[1], original_conv_weight_tensor_shape_test[2], original_conv_weight_tensor_shape_test[3]}; - tt::tt_metal::LegacyShape output_conv_weight_tensor_shape = { + ttnn::SimpleShape output_conv_weight_tensor_shape{ original_conv_weight_tensor_shape[0], num_input_channels_to_repeat, original_conv_weight_tensor_shape[2], diff --git a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp index 26041c26367..3f4e309c307 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_utils.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_utils.hpp @@ -38,19 +38,20 @@ const tt::tt_metal::LegacyShape infer_dims_for_reshape(int N, int C, int H, int const tt::tt_metal::LegacyShape infer_dims_for_reshape_RM(int N, int C, int H, int W, uint32_t old_volume); +// TODO: Remove this once we switch to SimpleShape .volume() static std::size_t compute_volume(const tt::tt_metal::LegacyShape& shape) { size_t volume = 1; - for (auto index = 0; index < shape.size(); index++) { + for (auto index = 0; index < shape.rank(); index++) { volume *= shape[index]; } return volume; } -static std::vector compute_strides(const tt::tt_metal::LegacyShape& shape) { +static std::vector compute_strides(const ttnn::SimpleShape& shape) { if (shape.rank() == 0) return {}; - auto num_elements = compute_volume(shape); + auto num_elements = shape.volume(); std::vector strides; for (std::int32_t index = 0; index < shape.rank(); index++) { if (shape[index] == 0) { @@ -73,9 +74,8 @@ static int compute_flat_indices(const vector& indices, const vector -static std::size_t compute_buffer_size(const T& shape, DataType data_type) { - const size_t volume = compute_volume(shape); +static std::size_t compute_buffer_size(const ttnn::SimpleShape& shape, DataType data_type) { + const size_t volume = shape.volume(); if (data_type == DataType::BFLOAT8_B) { TT_ASSERT(volume % constants::TILE_HW == 0); const auto bfloat8_b_volume = volume / constants::TILE_HW * constants::BFLOAT8_B_TILE_HW; diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp index 05c58fa3008..baffe41d56c 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -303,6 +303,7 @@ class LegacyShape { friend std::ostream &operator<<(std::ostream &os, const LegacyShape &shape); Array4D to_array_4D() const { + TT_FATAL(rank() == 4, "to_array_4D is only valid for 4D shapes! Called for {}.", *this); Array4D ret_array; for (int i = 0; i < rank(); i++) { ret_array[i] = this->operator[](i); From 8683a6509fc19f4227c9ca96aa7117f92b373f3b Mon Sep 17 00:00:00 2001 From: Aditya Saigal <129097327+tt-asaigal@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:59:11 -0400 Subject: [PATCH 11/58] #0: Cache in_worker_thread boolean inside InWorkerThread (#13595) - Repeated set lookups/recomputations caused significant degradation in host performance --- tt_metal/impl/device/device_pool.hpp | 4 ---- tt_metal/tt_metal.cpp | 10 ++++++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tt_metal/impl/device/device_pool.hpp b/tt_metal/impl/device/device_pool.hpp index e667efe9e3e..3e5ac48b793 100644 --- a/tt_metal/impl/device/device_pool.hpp +++ b/tt_metal/impl/device/device_pool.hpp @@ -24,10 +24,6 @@ class DevicePool { return *_inst; } - static bool is_instantiated() { - return (_inst != nullptr); - } - static void initialize( std::vector device_ids, const uint8_t num_hw_cqs, diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 0cdc17d0050..2b28881d46e 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -377,10 +377,16 @@ void CloseDevices(std::map devices) { } bool InWorkerThread() { - bool in_worker_thread = false; - if (tt::DevicePool::is_instantiated()) { + // These are values are cached per thread. in_worker_thread is a 1:1 function of the thread_id. + // Therefore it does not need to be recomputed or looked up using the worker_thread_ids each time. + // This is a performance optimization, since looking up the thread id inside worker_thread_ids for + // each function call significantly degrades runtime perf. + thread_local static bool in_worker_thread = false; + thread_local static bool is_thread_status_checked = false; + if (not is_thread_status_checked) { auto worker_thread_ids = tt::DevicePool::instance().get_worker_thread_ids(); in_worker_thread = worker_thread_ids.find(std::this_thread::get_id()) != worker_thread_ids.end(); + is_thread_status_checked = true; } return in_worker_thread; } From b64b445f1c17de277ce2299c13ee975fdeef25d5 Mon Sep 17 00:00:00 2001 From: Aditya Saigal <129097327+tt-asaigal@users.noreply.github.com> Date: Tue, 8 Oct 2024 18:52:05 -0400 Subject: [PATCH 12/58] #0: Temporarily disable function inlining on prefetcher when Watcher is enabled (#13597) - This is temporary to ensure that debug tools can be used while meeting code space constraints --- tt_metal/impl/device/device.cpp | 24 ++++++++++++++++++++---- tt_metal/impl/device/device.hpp | 2 +- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 197247a1f46..80b0bff5c8a 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -688,7 +688,8 @@ void Device::configure_kernel_variant( NOC upstream_noc_index, NOC downstream_noc_index, bool is_active_eth_core, - bool send_to_brisc) { + bool send_to_brisc, + bool force_watcher_no_inline) { const auto& grid_size = this->grid_size(); @@ -711,6 +712,9 @@ void Device::configure_kernel_variant( {"DOWNSTREAM_SLAVE_NOC_Y", std::to_string(NOC_0_Y(downstream_noc_index, grid_size.y, downstream_slave_physical_core.y))}, {"FD_CORE_TYPE", std::to_string(programmable_core_type_index)}, }; + if (force_watcher_no_inline) { + defines.at("WATCHER_NOINLINE") = std::to_string(force_watcher_no_inline); + } if (llrt::OptionsG.watcher_dispatch_disabled()) { defines["FORCE_WATCHER_OFF"] = "1"; } @@ -2185,7 +2189,11 @@ void Device::compile_command_queue_programs() { std::map {}, my_noc_index, my_noc_index, - my_noc_index + my_noc_index, + false, + false, + // TEMP: Disable function inlining on Prefetcher when watcher is enabled but no_inline is not specified to respect code space + tt::llrt::OptionsG.get_watcher_enabled() && (not tt::llrt::OptionsG.get_watcher_noinline()) ); auto [tensix_num_worker_cores, tensix_worker_physical_grid] = get_physical_worker_grid_config(this->id(), num_hw_cqs, dispatch_core_type); @@ -2337,7 +2345,11 @@ void Device::compile_command_queue_programs() { std::map {}, my_noc_index, my_noc_index, - my_noc_index + my_noc_index, + false, + false, + // TEMP: Disable function inlining on Prefetcher when watcher is enabled but no_inline is not specified to respect code space + tt::llrt::OptionsG.get_watcher_enabled() && (not tt::llrt::OptionsG.get_watcher_noinline()) ); cq_id = (cq_id + 1) % num_hw_cqs; } @@ -2514,7 +2526,11 @@ void Device::compile_command_queue_programs() { std::map {}, my_noc_index, my_noc_index, - my_noc_index + my_noc_index, + false, + false, + // TEMP: Disable function inlining on Prefetcher when watcher is enabled but no_inline is not specified to respect code space + tt::llrt::OptionsG.get_watcher_enabled() && (not tt::llrt::OptionsG.get_watcher_noinline()) ); cq_id = (cq_id + 1) % num_hw_cqs; } diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index e46fd7487e9..278c10a39f0 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -234,7 +234,7 @@ class Device { void init_command_queue_device(); void initialize_synchronous_sw_cmd_queue(); void configure_kernel_variant(Program& program, string path, std::vector compile_args, CoreCoord kernel_core, CoreCoord Kernel_physical_core, - CoreType dispatch_core_type, CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, CoreCoord downstream_slave_physical_core, std::map defines_in, NOC my_noc_index, NOC upstream_noc_index, NOC downstream_noc_index, bool is_active_eth_core = false, bool send_to_brisc = false); + CoreType dispatch_core_type, CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, CoreCoord downstream_slave_physical_core, std::map defines_in, NOC my_noc_index, NOC upstream_noc_index, NOC downstream_noc_index, bool is_active_eth_core = false, bool send_to_brisc = false, bool force_watcher_no_inline = false); void compile_command_queue_programs(); void configure_command_queue_programs(); void clear_l1_state(); From 58aad728fd80cd657534b084f03badf382052ddd Mon Sep 17 00:00:00 2001 From: Eyon Land <41128502+eyonland@users.noreply.github.com> Date: Tue, 8 Oct 2024 18:07:34 -0500 Subject: [PATCH 13/58] #13370: Introduce inline namespace v0 (#13440) Co-authored-by: Patrick Roberts --- tt_metal/detail/reports/memory_reporter.hpp | 8 +- tt_metal/detail/tt_metal.hpp | 6 +- tt_metal/graph/graph_tracking.hpp | 4 + tt_metal/host_api.hpp | 2 + tt_metal/impl/buffers/buffer.hpp | 22 +++- tt_metal/impl/buffers/circular_buffer.hpp | 10 +- .../impl/buffers/circular_buffer_types.hpp | 2 + tt_metal/impl/debug/dprint_server.hpp | 4 +- tt_metal/impl/device/device.hpp | 13 +- tt_metal/impl/dispatch/command_queue.cpp | 120 +++++++++--------- tt_metal/impl/dispatch/command_queue.hpp | 13 +- tt_metal/impl/event/event.hpp | 7 +- tt_metal/impl/kernels/kernel.hpp | 8 ++ tt_metal/impl/program/program.hpp | 18 ++- tt_metal/impl/trace/trace.hpp | 2 + tt_metal/tools/profiler/tt_metal_profiler.cpp | 3 + tt_metal/tt_metal.cpp | 19 +-- .../common/types/ccl_types_args_emitters.hpp | 7 +- .../host/reduce_scatter_worker_builder.hpp | 6 +- ttnn/cpp/ttnn/tensor/tensor_ops.hpp | 9 +- 20 files changed, 179 insertions(+), 104 deletions(-) diff --git a/tt_metal/detail/reports/memory_reporter.hpp b/tt_metal/detail/reports/memory_reporter.hpp index 290bd1f9849..e5138f02a35 100644 --- a/tt_metal/detail/reports/memory_reporter.hpp +++ b/tt_metal/detail/reports/memory_reporter.hpp @@ -9,9 +9,12 @@ #include #include namespace tt::tt_metal { +inline namespace v0 { class Program; class Device; + +} // namespace v0 namespace detail { /** @@ -74,6 +77,5 @@ class MemoryReporter { std::ofstream program_detailed_memory_usage_report_; }; -} // namespace detail - -} // namespace tt::tt_metal +} // namespace detail +} // namespace tt::tt_metal diff --git a/tt_metal/detail/tt_metal.hpp b/tt_metal/detail/tt_metal.hpp index 336b812b1ed..59857f68541 100644 --- a/tt_metal/detail/tt_metal.hpp +++ b/tt_metal/detail/tt_metal.hpp @@ -13,9 +13,11 @@ #include "tt_metal/impl/dispatch/dispatch_core_manager.hpp" namespace tt::tt_metal { +inline namespace v0 { class Program; class Buffer; class Device; +} // namespace v0 namespace detail { @@ -278,5 +280,5 @@ namespace tt::tt_metal { void AllocateBuffer(Buffer* buffer, bool bottom_up); void DeallocateBuffer(Buffer *buffer); - } -} + } // namespace detail +} // namespace tt::tt_metal diff --git a/tt_metal/graph/graph_tracking.hpp b/tt_metal/graph/graph_tracking.hpp index 0c29c537002..dcea7b8dcd9 100644 --- a/tt_metal/graph/graph_tracking.hpp +++ b/tt_metal/graph/graph_tracking.hpp @@ -13,8 +13,12 @@ #include "tt_metal/impl/buffers/buffer.hpp" namespace tt::tt_metal { +inline namespace v0 { class Program; + +} // namespace v0 + class IGraphProcessor{ public: enum class RunMode { diff --git a/tt_metal/host_api.hpp b/tt_metal/host_api.hpp index d2b4d32bce1..3fa93334462 100644 --- a/tt_metal/host_api.hpp +++ b/tt_metal/host_api.hpp @@ -30,6 +30,7 @@ class CoreRangeSet; namespace tt { namespace tt_metal { +inline namespace v0 { class Program; class Device; @@ -657,6 +658,7 @@ bool EventQuery(const std::shared_ptr &event); */ void Synchronize(Device *device, const std::optional cq_id = std::nullopt); +} // namespace v0 } // namespace tt_metal } // namespace tt diff --git a/tt_metal/impl/buffers/buffer.hpp b/tt_metal/impl/buffers/buffer.hpp index c5731e7e137..c77cb98d189 100644 --- a/tt_metal/impl/buffers/buffer.hpp +++ b/tt_metal/impl/buffers/buffer.hpp @@ -22,12 +22,13 @@ #include "tt_metal/tt_stl/reflection.hpp" #include "llrt/hal.hpp" -namespace tt { - -namespace tt_metal { +namespace tt::tt_metal { +inline namespace v0 { class Device; +} // namespace v0 + struct ShardSpec { /* The individual cores the shard grid is mapped to */ CoreRangeSet grid; @@ -101,6 +102,8 @@ struct ShardSpecBuffer { } }; +inline namespace v0 { + struct BufferConfig { Device *device; DeviceAddr size; // Size in bytes @@ -124,6 +127,8 @@ struct ShardedBufferConfig { bool allocate = true; }; +} // namespace v0 + bool is_sharded(const TensorMemoryLayout &layout); struct BufferPageMapping { @@ -140,6 +145,8 @@ struct BufferPageMapping { std::vector> core_shard_shape_; }; +inline namespace v0 { + class Buffer { public: Buffer() : @@ -278,6 +285,8 @@ class Buffer { std::optional bottom_up_; }; +} // namespace v0 + BufferPageMapping generate_buffer_page_mapping(const Buffer &buffer); namespace detail { @@ -310,6 +319,8 @@ class buffer_map_t { extern buffer_map_t BUFFER_MAP; } // namespace detail +inline namespace v0 { + using HostDataType = std::variant< const std::shared_ptr>, const std::shared_ptr>, @@ -319,9 +330,8 @@ using HostDataType = std::variant< const std::shared_ptr>, const void *>; -} // namespace tt_metal - -} // namespace tt +} // namespace v0 +} // namespace tt::tt_metal namespace tt::stl::json { template <> diff --git a/tt_metal/impl/buffers/circular_buffer.hpp b/tt_metal/impl/buffers/circular_buffer.hpp index 0e305aeb76c..cac5ad99918 100644 --- a/tt_metal/impl/buffers/circular_buffer.hpp +++ b/tt_metal/impl/buffers/circular_buffer.hpp @@ -8,9 +8,8 @@ #include "common/tt_backend_api_types.hpp" #include "tt_metal/impl/buffers/circular_buffer_types.hpp" -namespace tt { - -namespace tt_metal { +namespace tt::tt_metal { +inline namespace v0 { class CircularBuffer { public: @@ -66,6 +65,5 @@ class CircularBuffer { // add a callback to invalidate circular buffer allocation }; -} // namespace tt_metal - -} // namespace tt +} // namespace v0 +} // namespace tt::tt_metal diff --git a/tt_metal/impl/buffers/circular_buffer_types.hpp b/tt_metal/impl/buffers/circular_buffer_types.hpp index b27fc39ce61..a70d633b4b9 100644 --- a/tt_metal/impl/buffers/circular_buffer_types.hpp +++ b/tt_metal/impl/buffers/circular_buffer_types.hpp @@ -17,6 +17,7 @@ #include "tt_metal/impl/tile/tile.hpp" namespace tt::tt_metal { +inline namespace v0 { using CBHandle = uintptr_t; @@ -151,4 +152,5 @@ class CircularBufferConfig { std::optional max_size_ = std::nullopt; }; +} // namespace v0 } // namespace tt::tt_metal diff --git a/tt_metal/impl/debug/dprint_server.hpp b/tt_metal/impl/debug/dprint_server.hpp index be5bbab5353..535a52a30e1 100644 --- a/tt_metal/impl/debug/dprint_server.hpp +++ b/tt_metal/impl/debug/dprint_server.hpp @@ -11,8 +11,10 @@ namespace tt { namespace tt_metal { +inline namespace v0 { class Device; -} +} // namespace v0 +} // namespace tt_metal /* @brief Attaches a device to be monitored by the print server. If no devices were present on the diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index 278c10a39f0..92dbdb38bb8 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -23,14 +23,19 @@ namespace tt { namespace tt_metal { - // Fwd declares enum class BufferType; + +inline namespace v0 { + class Buffer; class Program; +class CommandQueue; + +} // namespace v0 + class JitBuildEnv; class HWCommandQueue; -class CommandQueue; class TraceBuffer; namespace detail { @@ -54,6 +59,8 @@ static constexpr float INF_GS = 1.6948e38; static constexpr float INF_WHB0 = 1.7014e+38; static constexpr float INF_BH = INF_WHB0; +inline namespace v0 { + // A physical PCIexpress Tenstorrent device class Device { public: @@ -334,6 +341,8 @@ class Device { std::unordered_map> trace_buffer_pool_; }; +} // namespace v0 + inline HalProgrammableCoreType Device::get_programmable_core_type(CoreCoord phys_core) const { HalProgrammableCoreType programmable_core_type = HalProgrammableCoreType::TENSIX; diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index c1137b1a571..0e94bcf7afd 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -2840,6 +2840,8 @@ void EnqueueDeallocateBuffer( }); } +inline namespace v0 { + void EnqueueReadBuffer( CommandQueue& cq, std::variant, std::shared_ptr> buffer, @@ -2889,22 +2891,6 @@ void EnqueueReadBuffer( .type = EnqueueCommandType::ENQUEUE_READ_BUFFER, .blocking = blocking, .buffer = buffer, .dst = dst}); } -void EnqueueReadBufferImpl( - CommandQueue& cq, - std::variant, std::shared_ptr> buffer, - void* dst, - bool blocking) { - std::visit( - [&cq, dst, blocking](auto&& b) { - using T = std::decay_t; - if constexpr ( - std::is_same_v> || std::is_same_v>) { - cq.hw_command_queue().enqueue_read_buffer(b, dst, blocking); - } - }, - buffer); -} - void EnqueueWriteBuffer( CommandQueue& cq, std::variant, std::shared_ptr> buffer, @@ -2915,14 +2901,6 @@ void EnqueueWriteBuffer( .type = EnqueueCommandType::ENQUEUE_WRITE_BUFFER, .blocking = blocking, .buffer = buffer, .src = src}); } -void EnqueueWriteBufferImpl( - CommandQueue& cq, - std::variant, std::shared_ptr> buffer, - HostDataType src, - bool blocking) { - cq.hw_command_queue().enqueue_write_buffer(buffer, src, blocking); -} - void EnqueueProgram( CommandQueue& cq, Program& program, bool blocking) { detail::DispatchStateCheck(true); @@ -2930,21 +2908,6 @@ void EnqueueProgram( CommandInterface{.type = EnqueueCommandType::ENQUEUE_PROGRAM, .blocking = blocking, .program = &program}); } -void EnqueueProgramImpl( - CommandQueue& cq, Program& program, bool blocking) { - ZoneScoped; - - Device* device = cq.device(); - detail::CompileProgram(device, program); - program.allocate_circular_buffers(device); - detail::ValidateCircularBufferRegion(program, device); - cq.hw_command_queue().enqueue_program(program, blocking); - // Program relinquishes ownership of all global buffers its using, once its been enqueued. Avoid mem - // leaks on device. - program.release_buffers(); - -} - void EnqueueRecordEvent(CommandQueue& cq, const std::shared_ptr& event) { detail::DispatchStateCheck(true); cq.run_command(CommandInterface{ @@ -2954,10 +2917,6 @@ void EnqueueRecordEvent(CommandQueue& cq, const std::shared_ptr& event) { }); } -void EnqueueRecordEventImpl(CommandQueue& cq, const std::shared_ptr& event) { - cq.hw_command_queue().enqueue_record_event(event); -} - void EnqueueWaitForEvent(CommandQueue& cq, const std::shared_ptr& event) { detail::DispatchStateCheck(true); cq.run_command(CommandInterface{ @@ -2967,19 +2926,6 @@ void EnqueueWaitForEvent(CommandQueue& cq, const std::shared_ptr& event) }); } -void EnqueueWaitForEventImpl(CommandQueue& cq, const std::shared_ptr& event) { - event->wait_until_ready(); // Block until event populated. Worker thread. - log_trace( - tt::LogMetal, - "EnqueueWaitForEvent() issued on Event(device_id: {} cq_id: {} event_id: {}) from device_id: {} cq_id: {}", - event->device->id(), - event->cq_id, - event->event_id, - cq.device()->id(), - cq.id()); - cq.hw_command_queue().enqueue_wait_for_event(event); -} - void EventSynchronize(const std::shared_ptr& event) { detail::DispatchStateCheck(true); event->wait_until_ready(); // Block until event populated. Parent thread. @@ -3028,8 +2974,6 @@ void Finish(CommandQueue& cq) { tt::watcher_get_log_file_name()); } -void FinishImpl(CommandQueue& cq) { cq.hw_command_queue().finish(); } - void EnqueueTrace(CommandQueue& cq, uint32_t trace_id, bool blocking) { detail::DispatchStateCheck(true); TT_FATAL(cq.device()->get_trace(trace_id) != nullptr, "Trace instance {} must exist on device", trace_id); @@ -3037,6 +2981,66 @@ void EnqueueTrace(CommandQueue& cq, uint32_t trace_id, bool blocking) { CommandInterface{.type = EnqueueCommandType::ENQUEUE_TRACE, .blocking = blocking, .trace_id = trace_id}); } +} // namespace v0 + +void EnqueueReadBufferImpl( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + void* dst, + bool blocking) { + std::visit( + [&cq, dst, blocking](auto&& b) { + using T = std::decay_t; + if constexpr ( + std::is_same_v> || std::is_same_v>) { + cq.hw_command_queue().enqueue_read_buffer(b, dst, blocking); + } + }, + buffer); +} + +void EnqueueWriteBufferImpl( + CommandQueue& cq, + std::variant, std::shared_ptr> buffer, + HostDataType src, + bool blocking) { + cq.hw_command_queue().enqueue_write_buffer(buffer, src, blocking); +} + +void EnqueueProgramImpl( + CommandQueue& cq, Program& program, bool blocking) { + ZoneScoped; + + Device* device = cq.device(); + detail::CompileProgram(device, program); + program.allocate_circular_buffers(device); + detail::ValidateCircularBufferRegion(program, device); + cq.hw_command_queue().enqueue_program(program, blocking); + // Program relinquishes ownership of all global buffers its using, once its been enqueued. Avoid mem + // leaks on device. + program.release_buffers(); + +} + +void EnqueueRecordEventImpl(CommandQueue& cq, const std::shared_ptr& event) { + cq.hw_command_queue().enqueue_record_event(event); +} + +void EnqueueWaitForEventImpl(CommandQueue& cq, const std::shared_ptr& event) { + event->wait_until_ready(); // Block until event populated. Worker thread. + log_trace( + tt::LogMetal, + "EnqueueWaitForEvent() issued on Event(device_id: {} cq_id: {} event_id: {}) from device_id: {} cq_id: {}", + event->device->id(), + event->cq_id, + event->event_id, + cq.device()->id(), + cq.id()); + cq.hw_command_queue().enqueue_wait_for_event(event); +} + +void FinishImpl(CommandQueue& cq) { cq.hw_command_queue().finish(); } + void EnqueueTraceImpl(CommandQueue& cq, uint32_t trace_id, bool blocking) { cq.hw_command_queue().enqueue_trace(trace_id, blocking); } diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp index 8024b937d1b..cbaf4c9b089 100644 --- a/tt_metal/impl/dispatch/command_queue.hpp +++ b/tt_metal/impl/dispatch/command_queue.hpp @@ -21,11 +21,15 @@ #include "tt_metal/impl/trace/trace_buffer.hpp" namespace tt::tt_metal { +inline namespace v0 { +class CommandQueue; class Event; class Trace; using RuntimeArgs = std::vector>; +} // namespace v0 + // Only contains the types of commands which are enqueued onto the device enum class EnqueueCommandType { ENQUEUE_READ_BUFFER, @@ -47,7 +51,6 @@ enum class EnqueueCommandType { string EnqueueCommandTypeToString(EnqueueCommandType ctype); -class CommandQueue; class CommandInterface; using WorkerQueue = LockFreeQueue; @@ -591,8 +594,8 @@ class HWCommandQueue { friend void EnqueueWaitForEventImpl(CommandQueue& cq, const std::shared_ptr& event); friend void FinishImpl(CommandQueue& cq); friend void EnqueueRecordEvent(CommandQueue& cq, const std::shared_ptr& event); - friend class CommandQueue; - friend class Device; + friend CommandQueue; + friend Device; }; // Common interface for all command queue types @@ -610,6 +613,8 @@ struct CommandInterface { std::optional trace_id; }; +inline namespace v0 { + class CommandQueue { friend class Device; friend class Trace; @@ -697,6 +702,8 @@ class CommandQueue { inline static uint32_t num_passthrough_cqs = 0; }; +} // namespace v0 + // Primitives used to place host only operations on the SW Command Queue. // These are used in functions exposed through tt_metal.hpp or host_api.hpp void EnqueueAllocateBuffer(CommandQueue& cq, Buffer* buffer, bool bottom_up, bool blocking); diff --git a/tt_metal/impl/event/event.hpp b/tt_metal/impl/event/event.hpp index 803b447bd53..17503153a30 100644 --- a/tt_metal/impl/event/event.hpp +++ b/tt_metal/impl/event/event.hpp @@ -8,8 +8,8 @@ #include #include "tt_metal/common/assert.hpp" #include "tt_metal/common/logger.hpp" -namespace tt::tt_metal -{ +namespace tt::tt_metal { +inline namespace v0 { class Device; struct Event { @@ -31,4 +31,5 @@ namespace tt::tt_metal TT_ASSERT(cq_id != -1, "Event must have initialized cq_id"); } }; -} +} // namespace v0 +} // namespace tt::tt_metal diff --git a/tt_metal/impl/kernels/kernel.hpp b/tt_metal/impl/kernels/kernel.hpp index b1f53671b57..db113b226f3 100644 --- a/tt_metal/impl/kernels/kernel.hpp +++ b/tt_metal/impl/kernels/kernel.hpp @@ -19,8 +19,12 @@ namespace tt { namespace tt_metal { +inline namespace v0 { class Device; + +} // namespace v0 + constexpr uint32_t max_runtime_args = 256; constexpr uint32_t idle_eth_max_runtime_args = eth_l1_mem::address_map::ERISC_L1_KERNEL_CONFIG_SIZE / sizeof(uint32_t); @@ -47,6 +51,8 @@ struct KernelSource { } }; +inline namespace v0 { + class Kernel : public JitBuildSettings { public: Kernel( @@ -236,6 +242,8 @@ class ComputeKernel : public Kernel { std::string config_hash() const override; }; +} // namespace v0 + std::ostream& operator<<(std::ostream& os, const DataMovementProcessor& processor); struct KernelDefinesHash { diff --git a/tt_metal/impl/program/program.hpp b/tt_metal/impl/program/program.hpp index 679d6fb4d35..b3eff1841ce 100644 --- a/tt_metal/impl/program/program.hpp +++ b/tt_metal/impl/program/program.hpp @@ -19,13 +19,20 @@ namespace tt { namespace tt_metal { // Fwd declares +inline namespace v0 { + class Buffer; class Kernel; class CircularBuffer; class Device; class Program; -class JitBuildOptions; class CircularBufferConfig; + +} // namespace v0 + +class EnqueueProgramCommand; +class HWCommandQueue; +class JitBuildOptions; namespace detail{ void ValidateCircularBufferRegion(const Program &program, const Device *device); KernelHandle AddKernel (Program &program, std::shared_ptr kernel, const HalProgrammableCoreType core_type); @@ -70,9 +77,9 @@ struct ProgramConfig { uint32_t cb_size; }; -class Program { - friend class KernelGroup; +inline namespace v0 { +class Program { public: Program(); @@ -254,10 +261,11 @@ class Program { bool runs_on_noc_unicast_only_cores(); bool runs_on_noc_multicast_only_cores(); - friend class HWCommandQueue; - friend class EnqueueProgramCommand; + friend HWCommandQueue; + friend EnqueueProgramCommand; }; +} // namespace v0 } // namespace tt_metal } // namespace tt diff --git a/tt_metal/impl/trace/trace.hpp b/tt_metal/impl/trace/trace.hpp index b3e561ce80f..ca77c375181 100644 --- a/tt_metal/impl/trace/trace.hpp +++ b/tt_metal/impl/trace/trace.hpp @@ -15,6 +15,7 @@ #include "tt_metal/impl/trace/trace_buffer.hpp" namespace tt::tt_metal { +inline namespace v0 { class Trace { private: @@ -31,4 +32,5 @@ class Trace { static std::shared_ptr create_empty_trace_buffer(); }; +} // namespace v0 } // namespace tt::tt_metal diff --git a/tt_metal/tools/profiler/tt_metal_profiler.cpp b/tt_metal/tools/profiler/tt_metal_profiler.cpp index 959469e89e1..bf8fcf952c5 100644 --- a/tt_metal/tools/profiler/tt_metal_profiler.cpp +++ b/tt_metal/tools/profiler/tt_metal_profiler.cpp @@ -21,6 +21,7 @@ namespace tt { namespace tt_metal { +inline namespace v0 { void DumpDeviceProfileResults(Device* device, const Program& program) { #if defined(TRACY_ENABLE) @@ -46,6 +47,8 @@ void DumpDeviceProfileResults(Device* device, const Program& program) { #endif } +} // namespace v0 + namespace detail { std::map tt_metal_device_profiler_map; diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 2b28881d46e..42def67b5ad 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -157,14 +157,14 @@ std::optional get_semaphore_id(const Program &program, const CoreRange return semaphore_id; } -inline void SetRuntimeArgs( +inline void SetRuntimeArgsImpl( const Program &program, KernelHandle kernel_id, const CoreCoord &c, const std::vector &runtime_args) { if (runtime_args.size() != 0) { detail::GetKernel(program, kernel_id)->set_runtime_args(c, runtime_args); } } -inline void SetRuntimeArgs( +inline void SetRuntimeArgsImpl( const Program &program, KernelHandle kernel_id, const CoreRange &core_range, @@ -179,7 +179,7 @@ inline void SetRuntimeArgs( } } -inline void SetRuntimeArgs( +inline void SetRuntimeArgsImpl( const Program &program, KernelHandle kernel_id, const CoreRangeSet &core_range_set, @@ -196,7 +196,7 @@ inline void SetRuntimeArgs( } } -inline void SetRuntimeArgs( +inline void SetRuntimeArgsImpl( CommandQueue &cq, const std::shared_ptr kernel, const std::variant &core_spec, @@ -227,7 +227,7 @@ inline void SetRuntimeArgs( core_spec); } -inline void SetRuntimeArgs( +inline void SetRuntimeArgsImpl( CommandQueue &cq, const std::shared_ptr kernel, const std::vector &core_spec, @@ -877,6 +877,8 @@ void DeallocateBuffer(Buffer *buffer) { } // namespace detail +inline namespace v0 { + size_t GetNumAvailableDevices() { return tt::Cluster::instance().number_of_user_devices(); } @@ -1119,7 +1121,7 @@ void SetRuntimeArgs( not CommandQueue::async_mode_set(), "This variant of SetRuntimeArgs can only be called when Asynchronous SW Command Queues are disabled for Fast " "Dispatch."); - std::visit([&](auto &&core_spec) { SetRuntimeArgs(program, kernel_id, core_spec, runtime_args); }, core_spec); + std::visit([&](auto &&core_spec) { SetRuntimeArgsImpl(program, kernel_id, core_spec, runtime_args); }, core_spec); } void SetRuntimeArgs( @@ -1147,7 +1149,7 @@ void SetRuntimeArgs( const std::variant &core_spec, std::shared_ptr runtime_args) { detail::DispatchStateCheck(not device->using_slow_dispatch()); - SetRuntimeArgs(device->command_queue(), kernel, core_spec, runtime_args, false); + SetRuntimeArgsImpl(device->command_queue(), kernel, core_spec, runtime_args, false); } void SetRuntimeArgs( @@ -1161,7 +1163,7 @@ void SetRuntimeArgs( core_spec.size(), runtime_args.size()); detail::DispatchStateCheck(not device->using_slow_dispatch()); - SetRuntimeArgs(device->command_queue(), kernel, core_spec, runtime_args, false); + SetRuntimeArgsImpl(device->command_queue(), kernel, core_spec, runtime_args, false); } void SetCommonRuntimeArgs(const Program &program, KernelHandle kernel_id, const std::vector &runtime_args) { @@ -1222,6 +1224,7 @@ void Synchronize(Device *device, const std::optional cq_id) { } } +} // namespace v0 } // namespace tt_metal } // namespace tt diff --git a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp index 59abce85eee..534a90f0dac 100644 --- a/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/common/types/ccl_types_args_emitters.hpp @@ -11,9 +11,12 @@ namespace tt { namespace tt_metal { class Tensor; + +inline namespace v0 { class Device; -} // namespace tt_metal -} // namespace tt +} // namespace v0 +} // namespace tt_metal +} // namespace tt namespace ttnn { namespace ccl { diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.hpp index f853856ae32..0008def47b9 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/host/reduce_scatter_worker_builder.hpp @@ -12,12 +12,14 @@ namespace tt { namespace tt_metal { +inline namespace v0 { // Forward declarations class Device; -} // namespace tt_metal -} // namespace tt +} // namespace v0 +} // namespace tt_metal +} // namespace tt namespace ttnn { namespace ccl { diff --git a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp index 6de950409e5..56113a8db25 100644 --- a/ttnn/cpp/ttnn/tensor/tensor_ops.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor_ops.hpp @@ -7,11 +7,14 @@ namespace tt::tt_metal { struct Tensor; -class CommandQueue; struct MemoryConfig; -class Device; class MeshDevice; -} + +inline namespace v0 { +class CommandQueue; +class Device; +} // namespace v0 +} // namespace tt::tt_metal namespace tt::tt_metal::tensor_ops { From c844ed3d903956a911bd12cafa678647efe31ae3 Mon Sep 17 00:00:00 2001 From: Aditya Saigal <129097327+tt-asaigal@users.noreply.github.com> Date: Tue, 8 Oct 2024 19:10:35 -0400 Subject: [PATCH 14/58] #0: Add functional tests for 2CQ ResNet-50 (#13561) --- .github/workflows/tg-frequent-tests-impl.yaml | 2 +- .../tests/test_resnet50_performant.py | 66 +++++++++++++++++++ 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tg-frequent-tests-impl.yaml b/.github/workflows/tg-frequent-tests-impl.yaml index 57be999d057..aaf5208327b 100644 --- a/.github/workflows/tg-frequent-tests-impl.yaml +++ b/.github/workflows/tg-frequent-tests-impl.yaml @@ -36,7 +36,7 @@ jobs: run: tar -xvf ttm_${{ matrix.test-group.arch }}.tar - uses: ./.github/actions/install-python-deps - name: Run frequent regression tests - timeout-minutes: 60 + timeout-minutes: 90 run: | source ${{ github.workspace }}/python_env/bin/activate cd $TT_METAL_HOME diff --git a/models/demos/tg/resnet50/tests/test_resnet50_performant.py b/models/demos/tg/resnet50/tests/test_resnet50_performant.py index c3f10fe9272..6a82bbab68c 100644 --- a/models/demos/tg/resnet50/tests/test_resnet50_performant.py +++ b/models/demos/tg/resnet50/tests/test_resnet50_performant.py @@ -77,3 +77,69 @@ def test_run_resnet50_trace_inference( math_fidelity, model_location_generator, ) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_command_queues": 2}], indirect=True) +@pytest.mark.parametrize( + "device_batch_size, act_dtype, weight_dtype, math_fidelity", + ((16, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi),), +) +@pytest.mark.parametrize("enable_async_mode", [True], indirect=True) +@pytest.mark.parametrize( + "mesh_device", + ((8, 4),), + indirect=True, +) +def test_run_resnet50_2cqs_inference( + mesh_device, + use_program_cache, + device_batch_size, + act_dtype, + weight_dtype, + math_fidelity, + enable_async_mode, + model_location_generator, +): + run_resnet50_2cqs_inference( + mesh_device, + device_batch_size, + act_dtype, + weight_dtype, + math_fidelity, + model_location_generator, + ) + + +@run_for_wormhole_b0() +@pytest.mark.parametrize( + "device_params", [{"l1_small_size": 24576, "trace_region_size": 800768, "num_command_queues": 2}], indirect=True +) +@pytest.mark.parametrize( + "device_batch_size, act_dtype, weight_dtype, math_fidelity", + ((16, ttnn.bfloat8_b, ttnn.bfloat8_b, ttnn.MathFidelity.LoFi),), +) +@pytest.mark.parametrize("enable_async_mode", [True], indirect=True) +@pytest.mark.parametrize( + "mesh_device", + ((8, 4),), + indirect=True, +) +def test_run_resnet50_trace_2cqs_inference( + mesh_device, + use_program_cache, + device_batch_size, + act_dtype, + weight_dtype, + math_fidelity, + enable_async_mode, + model_location_generator, +): + run_resnet50_trace_2cqs_inference( + mesh_device, + device_batch_size, + act_dtype, + weight_dtype, + math_fidelity, + model_location_generator, + ) From 2564aebed2fd298c66378a050bd961e9165fd3af Mon Sep 17 00:00:00 2001 From: Aditya Saigal <129097327+tt-asaigal@users.noreply.github.com> Date: Tue, 8 Oct 2024 19:26:46 -0400 Subject: [PATCH 15/58] #0: Update ResNet-50 Performance in README (#13611) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a96ce18abc0..416d466ba23 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ | [ResNet-50 (224x224)](./models/demos/grayskull/resnet50) | 20 | [e150](https://tenstorrent.com/hardware/grayskull) | 5,100 | 10,000 | | | [ResNet-50 (224x224)](./models/demos/wormhole/resnet50) | 16 | [n150](https://tenstorrent.com/hardware/wormhole) | 4,100 | 7,000 | | | [ResNet-50 (224x224) (data parallel)](./models/demos/t3000/resnet50) | 128 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 32,250 | 56,000 | | -| [ResNet-50 (224x224) (data parallel)](./models/demos/tg/resnet50) | 512 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 66,150 | 224,000 | | +| [ResNet-50 (224x224) (data parallel)](./models/demos/tg/resnet50) | 512 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 95,900 | 224,000 | | | [ResNet-50 (224x224) (data parallel)](./models/demos/tgg/resnet50) | 1024 | [Two Galaxies](https://tenstorrent.com/hardware/galaxy) | 128,800 | 448,000 | | | [ViT](./models/demos/grayskull/vit) | 9 | [e150](https://tenstorrent.com/hardware/grayskull) | 1,360 | 2,000 | | | [ViT](./models/demos/wormhole/vit) | 8 | [n150](https://tenstorrent.com/hardware/wormhole) | 912 | 1,600 | | From 9e9dc003e0a3938c75f934ce149ab0ffa6ad4619 Mon Sep 17 00:00:00 2001 From: Yu Gao <145494740+yugaoTT@users.noreply.github.com> Date: Tue, 8 Oct 2024 19:56:54 -0400 Subject: [PATCH 16/58] Add dynamic NoC support for GS, WH, and BH (#13376) * #0: add noc modes to kernel config * #0: remove all hardcoded noc_index * #0: add u-bench for read DRAM and write to remote L1 * #0: u-bench code clean up + add constexpr to cmd_bufs * #0: rename to DM_DEDICATED_NOC and DM_DYNAMIC_NOC * #0: fix soc descriptor for moving FD cores * #0: add fix to fast div, change back soc desc * #0: fix dram read l1 write for GS * #0: #0: fix dram read l1 write for BH * #0: minor fix after rebase * #0: re-calculate ret addr for atomic cmd * #0: code clean up + add dynamic noc for eth * #0: add read dram write l1 u-bench to CI * #0: reduce code size, remove dynamic noc for eth, rename dynamic_noc_init * #0: reduce code size by not using extern for noc_index when compile kernel * #0: remove NOC_MODE FW define, move NOC_MODE outside of noc_nonblocking_api * #0: reduce code size by remove else condition for noc_local_state_init --- tests/scripts/run_moreh_microbenchmark.sh | 8 +- tests/scripts/test_moreh_microbenchmark.py | 67 ++ .../kernels/reader_dram.cpp | 117 +++ .../kernels/writer_l1.cpp | 50 + .../test_dram_read_l1_write.cpp | 959 ++++++++++++++++++ .../perf_microbenchmark/CMakeLists.txt | 1 + tt_metal/hw/firmware/src/brisc.cc | 18 +- tt_metal/hw/firmware/src/brisck.cc | 8 +- tt_metal/hw/firmware/src/erisck.cc | 2 - tt_metal/hw/firmware/src/idle_erisck.cc | 3 +- tt_metal/hw/firmware/src/ncrisck.cc | 8 +- .../hw/inc/blackhole/noc_nonblocking_api.h | 62 +- tt_metal/hw/inc/dataflow_api.h | 492 ++++----- tt_metal/hw/inc/debug/sanitize_noc.h | 1 - tt_metal/hw/inc/dev_msgs.h | 13 +- .../hw/inc/grayskull/noc_nonblocking_api.h | 60 +- tt_metal/hw/inc/mod_div_lib.h | 22 +- tt_metal/hw/inc/risc_common.h | 2 + .../hw/inc/wormhole/noc_nonblocking_api.h | 64 +- tt_metal/impl/allocator/allocator.cpp | 2 +- tt_metal/impl/kernels/data_types.hpp | 5 + tt_metal/impl/kernels/kernel.cpp | 5 + tt_metal/impl/kernels/kernel_types.hpp | 3 + tt_metal/impl/program/program.cpp | 9 + 24 files changed, 1721 insertions(+), 260 deletions(-) create mode 100644 tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/reader_dram.cpp create mode 100644 tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/writer_l1.cpp create mode 100644 tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/test_dram_read_l1_write.cpp diff --git a/tests/scripts/run_moreh_microbenchmark.sh b/tests/scripts/run_moreh_microbenchmark.sh index de20e4221ac..cdccd2f8302 100755 --- a/tests/scripts/run_moreh_microbenchmark.sh +++ b/tests/scripts/run_moreh_microbenchmark.sh @@ -33,8 +33,12 @@ run_profiling_test() { pytest --capture=tee-sys $TT_METAL_HOME/tests/scripts/test_moreh_microbenchmark.py::test_matmul_l1 -k $ARCH_NAME if [[ "$ARCH_NAME" == "wormhole_b0" ]]; then - pytest --capture=tee-sys $TT_METAL_HOME/tests/scripts/test_moreh_microbenchmark.py::test_matmul_single_core_sharded -k $ARCH_NAME - pytest --capture=tee-sys $TT_METAL_HOME/tests/scripts/test_moreh_microbenchmark.py::test_dram_read_12_core -k $ARCH_NAME + pytest --capture=tee-sys $TT_METAL_HOME/tests/scripts/test_moreh_microbenchmark.py::test_matmul_single_core_sharded -k $ARCH_NAME + pytest --capture=tee-sys $TT_METAL_HOME/tests/scripts/test_moreh_microbenchmark.py::test_dram_read_12_core -k $ARCH_NAME + fi + # bypass wh_b0 for now until we can move FD cores to last col + if [[ "$ARCH_NAME" != "wormhole_b0" ]]; then + pytest --capture=tee-sys $TT_METAL_HOME/tests/scripts/test_moreh_microbenchmark.py::test_dram_read_l1_write_core -k $ARCH_NAME fi } diff --git a/tests/scripts/test_moreh_microbenchmark.py b/tests/scripts/test_moreh_microbenchmark.py index 3d9d8a50782..dc1e3b9b4c9 100755 --- a/tests/scripts/test_moreh_microbenchmark.py +++ b/tests/scripts/test_moreh_microbenchmark.py @@ -265,6 +265,28 @@ def run_dram_read_cmd(k, n, num_blocks, df, num_banks, bank_start_id): run_moreh_single_test("DRAM BW test multi-core", command) +def run_dram_read_l1_write_cmd(k, n, num_blocks, df, num_banks, bank_start_id): + command = ( + "TT_METAL_DEVICE_PROFILER=1 ./build/test/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/test_dram_read_l1_write " + + " --k " + + str(k) + + " --n " + + str(n) + + " --num-blocks " + + str(num_blocks) + + " --num-tests " + + str(1) + + " --data-type " + + str(df) + + " --num-banks " + + str(num_banks) + + " --bank-start-id " + + str(bank_start_id) + + " --bypass-check " + ) + run_moreh_single_test("DRAM BW test multi-core", command) + + # noc def test_noc_local(r=9, c=12, nt=256, cb=1): command = ( @@ -672,6 +694,51 @@ def test_dram_read_12_core(arch, freq, test_vector, num_tests, nblock, data_form assert bw_bound <= throughput +@pytest.mark.parametrize( + "arch, freq, test_vector, num_tests, nblock, data_format, num_banks, bank_start_id", + [ + ("grayskull", 1202, np.array([32768 * 2, 8 * 128]), 1, 64, 1, 8, 0), + ("wormhole_b0", 1000, np.array([32768 * 2, 12 * 128]), 1, 64, 1, 12, 0), + ("blackhole", 800, np.array([32768 * 8, 8 * 128]), 1, 256, 1, 8, 0), + ], +) +def test_dram_read_l1_write_core(arch, freq, test_vector, num_tests, nblock, data_format, num_banks, bank_start_id): + data = [] + cycle_list = [] + time_list = [] + throughput_list = [] + for _ in range(num_tests): + k = int(test_vector[0]) + n = int(test_vector[1]) + if data_format == 0: + input_size = k * n * 1088 // 1024 + elif data_format == 1: + input_size = k * n * 2048 // 1024 + run_dram_read_l1_write_cmd(k, n, nblock, data_format, num_banks, bank_start_id) + cycle = profile_results_kernel_duration() + time = cycle / freq / 1000.0 / 1000.0 + throughput = input_size / cycle * freq / 1000.0 + cycle_list.append(cycle) + time_list.append(time) + throughput_list.append(throughput) + cycle = sum(cycle_list) / len(cycle_list) + time = sum(time_list) / len(time_list) + throughput = sum(throughput_list) / len(throughput_list) + logger.info("DRAM read cycle: " + str(cycle)) + logger.info("DRAM read time: " + str(time)) + logger.info("DRAM read throughput: " + str(throughput)) + data.append([throughput]) + # check within range + dev_freq = get_device_freq() + if arch == "grayskull": + bw_bound = 100.0 + elif arch == "wormhole_b0": + bw_bound = 260.0 + elif arch == "blackhole": + bw_bound = 340.0 + assert bw_bound <= throughput + + @pytest.mark.parametrize( "arch, freq, r, c, test_vector_global, test_vector_local", [ diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/reader_dram.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/reader_dram.cpp new file mode 100644 index 00000000000..48c659c54ce --- /dev/null +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/reader_dram.cpp @@ -0,0 +1,117 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +#include "debug/dprint.h" + +template +FORCE_INLINE +void noc_async_read_tile_dram_sharded(uint32_t src_addr, uint32_t dest_addr, uint32_t bank_id = 0, const uint32_t vc = 0) { + uint32_t src_addr_; + uint32_t src_noc_xy; + + src_addr_ = src_addr + bank_base_address; + src_addr_ += bank_to_dram_offset[bank_id]; + src_noc_xy = dram_bank_to_noc_xy[noc_index][bank_id]; + + WAYPOINT("NRTW"); + DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc_index, get_noc_addr_helper(src_noc_xy, src_addr_), dest_addr, page_size); + while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + WAYPOINT("NRTD"); + + if constexpr(use_vc) { + uint32_t noc_rd_cmd_field = NOC_CMD_CPY | NOC_CMD_RD | NOC_CMD_RESP_MARKED | NOC_CMD_VC_STATIC | NOC_CMD_STATIC_VC(vc); + NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CTRL, noc_rd_cmd_field); + } + + NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_RET_ADDR_LO, dest_addr); + NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_LO, src_addr_); // (uint32_t)src_addr + NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_COORDINATE, src_noc_xy); // src_addr >> 32 + NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_AT_LEN_BE, page_size); // len_bytes + NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_reads_num_issued[noc_index] += 1; +} + +void kernel_main() { + constexpr uint32_t input_addr = get_compile_time_arg_val(0); + constexpr uint32_t input_start_tile_id = get_compile_time_arg_val(1); + constexpr uint32_t num_blocks = get_compile_time_arg_val(2); + constexpr uint32_t num_pages = get_compile_time_arg_val(3); + constexpr uint32_t block_num_tiles = get_compile_time_arg_val(4); + constexpr uint32_t page_size = get_compile_time_arg_val(5); + + constexpr uint32_t block_size_bytes = page_size * num_pages; + + const uint32_t bank_id = get_arg_val(0); + const uint32_t vc = get_arg_val(1); + + constexpr uint32_t cb_id = 0; + + uint32_t src_base_addr = noc_async_read_tile_dram_sharded_set_state(input_addr, bank_id, vc); + uint32_t src_read_addr = 0; + +#ifdef ARCH_GRAYSKULL + for (uint32_t block = 0; block < num_blocks; ++block) { + // Operand 1 + cb_reserve_back(cb_id, block_num_tiles); + auto l1_write_addr = get_write_ptr(cb_id); + + for (uint32_t h = 0; h < num_pages; ++h) { + noc_async_read_tile_dram_sharded_with_state(src_base_addr, src_read_addr, l1_write_addr); + src_read_addr += page_size; + l1_write_addr += page_size; + } + + noc_async_read_barrier(); + cb_push_back(cb_id, block_num_tiles); + } +#else + constexpr uint32_t total_num_blocks_in_buffer = 3; + constexpr uint32_t total_num_trid = 4; + uint32_t num_free_blocks_in_buffer = total_num_blocks_in_buffer; + uint32_t curr_block_trid = 1; + uint32_t block_trid_to_wait = 1; + + cb_reserve_back(cb_id, block_num_tiles); + uint32_t l1_write_addr_offset = 0; + uint32_t l1_write_addr_start = get_write_ptr(cb_id); + uint32_t l1_write_addr = l1_write_addr_start; + for (uint32_t block = 0; block < num_blocks; ++block) { + noc_async_read_tile_dram_sharded_set_trid(curr_block_trid); + + for (uint32_t h = 0; h < num_pages; ++h) { + noc_async_read_tile_dram_sharded_with_state_with_trid( + src_base_addr, src_read_addr, l1_write_addr, curr_block_trid); + src_read_addr += page_size; + l1_write_addr += page_size; + } + + if (num_free_blocks_in_buffer == 2) { + noc_async_read_barrier_with_trid(block_trid_to_wait); + cb_push_back(cb_id, block_num_tiles); + // wait for next block trid + block_trid_to_wait = block_trid_to_wait == 3 ? 1 : (block_trid_to_wait + 1); + // reserve for next block + cb_reserve_back(cb_id, block_num_tiles * 2); + } else { + num_free_blocks_in_buffer -= 1; + } + + if (curr_block_trid == total_num_blocks_in_buffer) { + l1_write_addr_offset = 0; + curr_block_trid = 1; + } else { + l1_write_addr_offset += block_size_bytes; + curr_block_trid += 1; + } + l1_write_addr = l1_write_addr_start + l1_write_addr_offset; + } + // last block to wait + noc_async_read_barrier_with_trid(block_trid_to_wait); + cb_push_back(cb_id, block_num_tiles); +#endif +} diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/writer_l1.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/writer_l1.cpp new file mode 100644 index 00000000000..3184c98f187 --- /dev/null +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/writer_l1.cpp @@ -0,0 +1,50 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +#include "debug/dprint.h" + + +void kernel_main() { + constexpr uint32_t num_blocks = get_compile_time_arg_val(0); + constexpr uint32_t num_pages = get_compile_time_arg_val(1); + constexpr uint32_t block_num_tiles = get_compile_time_arg_val(2); + constexpr uint32_t page_size = get_compile_time_arg_val(3); + constexpr uint32_t noc = get_compile_time_arg_val(4); + + const uint32_t vc = get_arg_val(0); + const uint32_t noc_x = get_arg_val(1); + const uint32_t noc_y = get_arg_val(2); + + constexpr uint32_t cb_id = 0; + + uint32_t l1_write_addr = get_write_ptr(cb_id); + const uint64_t l1_noc_write_addr = get_noc_addr(noc_x, noc_y, l1_write_addr, noc); + + noc_async_write_one_packet_set_state(l1_noc_write_addr, page_size, noc, vc); + + for (uint32_t block = 0; block < num_blocks; ++block) { + + auto remote_l1_write_addr = l1_noc_write_addr; + + cb_wait_front(cb_id, block_num_tiles); + auto l1_read_addr = get_read_ptr(cb_id); + + for (uint32_t h = 0; h < num_pages; ++h) { + noc_async_write_one_packet_with_state(l1_read_addr, remote_l1_write_addr, noc); + l1_read_addr += page_size; + remote_l1_write_addr += page_size; + } + + noc_async_write_barrier(noc); + + cb_pop_front(cb_id, block_num_tiles); + + } + + +} diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/test_dram_read_l1_write.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/test_dram_read_l1_write.cpp new file mode 100644 index 00000000000..cf618979b32 --- /dev/null +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/test_dram_read_l1_write.cpp @@ -0,0 +1,959 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "common/bfloat8.hpp" +#include "common/bfloat16.hpp" +#include "common/tt_backend_api_types.hpp" +#include "tt_metal/detail/tt_metal.hpp" +#include "tt_metal/detail/util.hpp" +#include "tt_metal/host_api.hpp" +#include "tt_metal/tt_metal/perf_microbenchmark/common/util.hpp" +#include "tt_metal/common/work_split.hpp" +#include + +using namespace tt; +using std::chrono::duration_cast; +using std::chrono::microseconds; + +//////////////////////////////////////////////////////////////////////////////// +// A tensix core that's next to a DRAM bank reads from the bank, and writes to +// the neighbour receiver tensix core. It creates a bfloat16/bfloat8_b format +// DRAM buffer of a given input size, and write it to the DRAM banks in the round +// robin style. +// +// Disclaimer: +// - This benchmark is designed to support an input size larger than 4GB. But +// current tt-metal does not seem to support buffer allocation larger than 4GB +// yet. +// - Also, detail::ReadFromBuffer API used in DRAM write test may take a long time if +// the input size is large. +// +// Usage example: +// ./test_dram_offchip +// --k +// --n +// --num-blocks +// --k +// --k +// --num-tests +// --data-type +// --num-banks +// --bank-start-id +// --bypass-check (set to bypass checking performance criteria fulfillment) +//////////////////////////////////////////////////////////////////////////////// + + + +template +std::vector slice_vec(std::vector const &v, int m, int n) { + auto first = v.cbegin() + m; + auto last = v.cbegin() + n + 1; + + std::vector vec(first, last); + return vec; +} + +void get_max_page_size_and_num_pages(uint32_t num_tiles, uint32_t tile_size, uint32_t& page_size, uint32_t& num_pages) { + uint64_t total_size = static_cast(num_tiles) * tile_size; + + page_size = (8192 / tile_size) * tile_size; + while (total_size % page_size != 0 && page_size >= tile_size) { + page_size -= tile_size; + } + num_pages = total_size / page_size; +} + +std::tuple create_program( + tt_metal::Device *device, + const CoreRangeSet &all_dram_reader_cores, + const CoreRangeSet &all_l1_receiver_cores, + const uint32_t &single_tile_size, + const tt::DataFormat &tile_format, + uint32_t num_tiles_cb, + uint32_t num_tiles_per_core, + uint32_t k, + uint32_t n, + uint32_t num_blocks, + uint32_t num_banks, + std::vectorall_dram_reader_cores_ordered, + std::vectorall_l1_writer_cores_ordered, + uint32_t bank_start_id, + const uint32_t &input_buffer_addr) { + tt_metal::Program program = tt_metal::Program(); + + uint32_t start_tile_id = 0; + uint32_t kt = k / 32; + uint32_t nt = n / 32; + uint32_t block_h = kt / num_blocks; + uint32_t block_w = nt / num_banks; + uint32_t block_num_tiles = block_h * block_w; + + // DRAM reader CB + uint32_t reader_cb_index = 0; + uint32_t reader_cb_size = block_h * block_w * single_tile_size * 3; + uint32_t page_size, num_pages; + get_max_page_size_and_num_pages(block_num_tiles, single_tile_size, page_size, num_pages); + + uint32_t reader_cb_addr = device->get_base_allocator_addr(HalMemType::L1); + tt_metal::CircularBufferConfig reader_cb_config = + tt_metal::CircularBufferConfig(reader_cb_size, {{reader_cb_index, tile_format}}) + .set_page_size(reader_cb_index, single_tile_size); + auto reader_cb = tt_metal::CreateCircularBuffer(program, all_dram_reader_cores, reader_cb_config); + + std::vector reader_compile_time_args = { + (std::uint32_t) input_buffer_addr, + (std::uint32_t) start_tile_id, + (std::uint32_t) num_blocks, + (std::uint32_t) num_pages, + (std::uint32_t) block_num_tiles, + (std::uint32_t) page_size, + (std::uint32_t) tt_metal::NOC::RISCV_0_default + }; + + auto reader_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/reader_dram.cpp", + all_dram_reader_cores, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_0, + .noc = tt_metal::NOC::RISCV_0_default, + .noc_mode = tt_metal::NOC_MODE::DM_DYNAMIC_NOC, + .compile_args = reader_compile_time_args}); + + std::vector writer_compile_time_args = { + (std::uint32_t) num_blocks, + (std::uint32_t) num_pages, + (std::uint32_t) block_num_tiles, + (std::uint32_t) page_size, + (std::uint32_t) tt_metal::NOC::RISCV_0_default + }; + + auto writer_kernel = tt_metal::CreateKernel( + program, + "tests/tt_metal/tt_metal/perf_microbenchmark/9_dram_adjacent_read_remote_l1_write/kernels/writer_l1.cpp", + all_dram_reader_cores, + tt_metal::DataMovementConfig{ + .processor = tt_metal::DataMovementProcessor::RISCV_1, + .noc = tt_metal::NOC::RISCV_1_default, + .noc_mode = tt_metal::NOC_MODE::DM_DYNAMIC_NOC, + .compile_args = writer_compile_time_args}); + + std::vector bank_ids; + for (int i=0; i < all_dram_reader_cores_ordered.size(); i++) { + auto core = all_dram_reader_cores_ordered[i]; + uint32_t bank_id = i + bank_start_id; + uint32_t vc = bank_id & 0x1; + + bank_ids.push_back(bank_id); + + for (int j=0; j reader_rt_args = { + (std::uint32_t) bank_id, + (std::uint32_t) vc + }; + + log_info("core: {}, vc: {}", core, vc); + + tt_metal::SetRuntimeArgs(program, reader_kernel, core, reader_rt_args); + + auto writer_core = all_l1_writer_cores_ordered[i]; + auto writer_core_phy = device->worker_core_from_logical_core(writer_core); + + std::vector writer_rt_args = { + (std::uint32_t) (vc + 2) & 0x3, + (std::uint32_t) writer_core_phy.x, + (std::uint32_t) writer_core_phy.y + }; + + tt_metal::SetRuntimeArgs(program, writer_kernel, core, writer_rt_args); + } + return {std::move(program), reader_kernel, reader_cb_addr}; +} + + +bool validation( + tt_metal::Device *device, + tt_metal::Buffer &input_buffer, + std::vector &input_vec, + const uint32_t &num_cores, + std::vector &all_cores, + const uint32_t &num_tiles_per_core, + const uint32_t &cb_addr, + const uint32_t &single_tile_size, + uint32_t num_tiles_cb, + uint32_t df, + uint32_t num_banks, + uint32_t num_blocks, + uint32_t block_h, + uint32_t block_w, + uint32_t num_datum_per_slice) { + + uint32_t core_id = 0; + for (auto core: all_cores) { + std::vector result_vec; + tt_metal::detail::ReadFromDeviceL1( + device, core, cb_addr, num_tiles_cb * single_tile_size, result_vec); + + uint32_t num_datum_per_block = block_h * block_w * num_datum_per_slice; + uint32_t tensor_slice_stride = core_id * num_datum_per_slice; + uint32_t last_block_offset = (num_blocks - 1) * num_datum_per_block * num_banks; + uint32_t start_index = tensor_slice_stride + last_block_offset; + uint32_t num_slices = block_h * block_w; + + if (df == 0) { + auto result_bfp8 = unpack_bfp8_tiles_into_float_vec(result_vec, true, true); + auto input_bfp8 = unpack_bfp8_tiles_into_float_vec(input_vec, true, true); + + for (uint32_t i=0; i < num_slices; ++i) { + uint32_t input_step = start_index + i * num_datum_per_slice * num_banks; + std::vector input_slice(input_bfp8.begin() + input_step, input_bfp8.begin() + input_step + num_datum_per_slice); + uint32_t result_step = i * num_datum_per_slice; + std::vector result_slice(result_bfp8.begin() + result_step, result_bfp8.begin() + result_step + num_datum_per_slice); + + if (input_slice != result_slice) { + return false; + } + } + + } else { + auto result_bf16 = unpack_uint32_vec_into_bfloat16_vec(result_vec); + auto input_bf16 = unpack_uint32_vec_into_bfloat16_vec(input_vec); + + for (uint32_t i=0; i < num_slices; ++i) { + uint32_t input_step = start_index + i * num_datum_per_slice * num_banks; + std::vector input_slice(input_bf16.begin() + input_step, input_bf16.begin() + input_step + num_datum_per_slice); + uint32_t result_step = i * num_datum_per_slice; + std::vector result_slice(result_bf16.begin() + result_step, result_bf16.begin() + result_step + num_datum_per_slice); + + if (input_slice != result_slice) { + return false; + } + } + } + core_id ++; + } + return true; +} + +uint32_t get_dram_bandwidth(tt::ARCH arch) { + constexpr uint32_t GS_DRAM_BANDWIDTH_GB_PER_SEC = 100; + constexpr uint32_t WH_DRAM_BANDWIDTH_GB_PER_SEC = 384; + + uint32_t dram_bandwidth_gb_per_sec = 0; + if (arch == tt::ARCH::WORMHOLE || arch == tt::ARCH::WORMHOLE_B0) { + dram_bandwidth_gb_per_sec = WH_DRAM_BANDWIDTH_GB_PER_SEC; + } else if (arch == tt::ARCH::GRAYSKULL) { + dram_bandwidth_gb_per_sec = GS_DRAM_BANDWIDTH_GB_PER_SEC; + } + return dram_bandwidth_gb_per_sec; +} + + +void get_dram_reader_core_coords_blackhole( + tt_metal::Device* device, CoreRangeSet& all_cores, std::vector& all_cores_ordered) { + + const metal_SocDescriptor& soc_d = tt::Cluster::instance().get_soc_desc(device->id()); + uint32_t full_grid_size_x = soc_d.grid_size.x; + + // get all the logical coord + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + // get dram banks and coords + uint32_t num_banks = device->num_dram_channels(); + uint32_t max_bank_id = num_banks - 1; + std::vector dram_coord_phy; + for (int i = 0; i < num_banks; ++i) { + dram_coord_phy.push_back(device->dram_core_from_dram_channel(i)); + } + + // get worker logical coords + std::vector all_worker_cores_logical; + for (int i = 0; i < num_cores_x; ++i) { + for (int j = 0; j < num_cores_y; ++j) { + all_worker_cores_logical.push_back(CoreCoord(i, j)); + } + } + + // get x coords of the workers + std::vector all_worker_cores_x_physical; + for (int i = 0; i < num_cores_x; ++i) { + auto core_phy = device->worker_core_from_logical_core(CoreCoord(i, 0)); + all_worker_cores_x_physical.push_back(core_phy.x); + } + + // get the harvested rows, we treat dram and eth cores as harvested as well + std::vector harvested_cols; + for (int i = 0; i < full_grid_size_x; ++i) { + auto x = i; + + if (std::find(all_worker_cores_x_physical.begin(), all_worker_cores_x_physical.end(), x) == + all_worker_cores_x_physical.end()) { + harvested_cols.push_back(x); + } + } + + // get the ajacent cores of DRAM banks + std::vector adj_core_physical; + for (int i = 0; i < num_banks; ++i) { + auto dram_core = dram_coord_phy[i]; + uint32_t adj_core_x = dram_core.x + 1; + uint32_t adj_core_y = dram_core.y; + adj_core_physical.push_back(CoreCoord(adj_core_x, adj_core_y)); + } + + // move worker if they are in the harvested cols + for (auto& coord : adj_core_physical) { + auto x = coord.x; + + // if row is harvested, move core down by 1 + while (std::find(harvested_cols.begin(), harvested_cols.end(), x) != harvested_cols.end() and x < (full_grid_size_x - 1)) { + x += 1; + } + + coord.x = x; + } + + // find the logical coord from physical coord + std::vector adj_core_logical_realloc; + for (int i = 0; i < adj_core_physical.size(); ++i) { + for (int j = 0; j < all_worker_cores_logical.size(); ++j) { + auto core = device->worker_core_from_logical_core(all_worker_cores_logical[j]); + if (adj_core_physical[i] == core) { + adj_core_logical_realloc.push_back(all_worker_cores_logical[j]); + } + } + } + + // create sets + std::set all_cores_set; + for (int i = 0; i < num_banks; ++i) { + all_cores_set.insert(CoreRange(adj_core_logical_realloc[i])); + } + all_cores = CoreRangeSet(all_cores_set); + all_cores_ordered = adj_core_logical_realloc; +} + + +void get_l1_writer_core_coords_blackhole( + tt_metal::Device* device, std::vector& all_dram_reader_cores, CoreRangeSet& all_cores, std::vector& all_cores_ordered) { + + const metal_SocDescriptor& soc_d = tt::Cluster::instance().get_soc_desc(device->id()); + uint32_t full_grid_size_x = soc_d.grid_size.x; + + // get all the logical coord + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + // get worker logical coords + std::vector all_worker_cores_logical; + for (int i = 0; i < num_cores_x; ++i) { + for (int j = 0; j < num_cores_y; ++j) { + all_worker_cores_logical.push_back(CoreCoord(i, j)); + } + } + + // get x coords of the workers + std::vector all_worker_cores_x_physical; + for (int i = 0; i < num_cores_x; ++i) { + auto core_phy = device->worker_core_from_logical_core(CoreCoord(i, 0)); + all_worker_cores_x_physical.push_back(core_phy.x); + } + + // get the harvested rows, we treat dram and eth cores as harvested as well + std::vector harvested_cols; + for (int i = 0; i < full_grid_size_x; ++i) { + auto x = i; + + if (std::find(all_worker_cores_x_physical.begin(), all_worker_cores_x_physical.end(), x) == + all_worker_cores_x_physical.end()) { + harvested_cols.push_back(x); + } + } + + // get the ajacent cores of DRAM readers, for grayskull the l1 writers are below DRAM readers + std::vector adj_core_physical; + for (int i = 0; i < all_dram_reader_cores.size(); ++i) { + auto dram_reader_core = all_dram_reader_cores[i]; + auto dram_reader_core_phy = device->worker_core_from_logical_core(dram_reader_core); + uint32_t adj_core_x = dram_reader_core_phy.x + 1; + uint32_t adj_core_y = dram_reader_core_phy.y; + adj_core_physical.push_back(CoreCoord(adj_core_x, adj_core_y)); + } + + // move worker if they are in the harvested rows + for (auto& coord : adj_core_physical) { + auto x = coord.x; + + // if row is harvested, move core down by 1 + while (std::find(harvested_cols.begin(), harvested_cols.end(), x) != harvested_cols.end() and x < (full_grid_size_x - 1)) { + x += 1; + } + + coord.x = x; + } + + // find the logical coord from physical coord + std::vector adj_core_logical_realloc; + for (int i = 0; i < adj_core_physical.size(); ++i) { + for (int j = 0; j < all_worker_cores_logical.size(); ++j) { + auto core = device->worker_core_from_logical_core(all_worker_cores_logical[j]); + if (adj_core_physical[i] == core) { + adj_core_logical_realloc.push_back(all_worker_cores_logical[j]); + } + } + } + + // create sets + std::set all_cores_set; + for (int i = 0; i < adj_core_logical_realloc.size(); ++i) { + all_cores_set.insert(CoreRange(adj_core_logical_realloc[i])); + } + all_cores = CoreRangeSet(all_cores_set); + all_cores_ordered = adj_core_logical_realloc; +} + +void get_dram_reader_core_coords_grayskull( + tt_metal::Device* device, CoreRangeSet& all_cores, std::vector& all_cores_ordered) { + + const metal_SocDescriptor& soc_d = tt::Cluster::instance().get_soc_desc(device->id()); + uint32_t full_grid_size_y = soc_d.grid_size.y; + + // get all the logical coord + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + // get dram banks and coords + uint32_t num_banks = device->num_dram_channels(); + uint32_t max_bank_id = num_banks - 1; + std::vector dram_coord_phy; + for (int i = 0; i < num_banks; ++i) { + dram_coord_phy.push_back(device->dram_core_from_dram_channel(i)); + } + + // get worker logical coords + std::vector all_worker_cores_logical; + for (int i = 0; i < num_cores_x; ++i) { + for (int j = 0; j < num_cores_y; ++j) { + all_worker_cores_logical.push_back(CoreCoord(i, j)); + } + } + + // get y coords of the workers + std::vector all_worker_cores_y_physical; + for (int i = 0; i < num_cores_y; ++i) { + auto core_phy = device->worker_core_from_logical_core(CoreCoord(0, i)); + all_worker_cores_y_physical.push_back(core_phy.y); + } + + // get the harvested rows, we treat dram and eth cores as harvested as well + std::vector harvested_rows; + for (int i = 0; i < full_grid_size_y; ++i) { + auto y = i; + + if (std::find(all_worker_cores_y_physical.begin(), all_worker_cores_y_physical.end(), y) == + all_worker_cores_y_physical.end()) { + harvested_rows.push_back(y); + } + } + + // get the ajacent cores of DRAM banks + std::vector adj_core_physical; + for (int i = 0; i < num_banks; ++i) { + auto dram_core = dram_coord_phy[i]; + uint32_t adj_core_x = dram_core.x; + uint32_t adj_core_y = dram_core.y + 1; + adj_core_physical.push_back(CoreCoord(adj_core_x, adj_core_y)); + } + + // move worker if they are in the harvested rows + for (auto& coord : adj_core_physical) { + auto y = coord.y; + + // if row is harvested, move core down by 1 + while (std::find(harvested_rows.begin(), harvested_rows.end(), y) != harvested_rows.end() and y < (full_grid_size_y - 1)) { + y += 1; + } + + coord.y = y; + } + + // find the logical coord from physical coord + std::vector adj_core_logical_realloc; + for (int i = 0; i < adj_core_physical.size(); ++i) { + for (int j = 0; j < all_worker_cores_logical.size(); ++j) { + auto core = device->worker_core_from_logical_core(all_worker_cores_logical[j]); + if (adj_core_physical[i] == core) { + adj_core_logical_realloc.push_back(all_worker_cores_logical[j]); + } + } + } + + // create sets + std::set all_cores_set; + for (int i = 0; i < num_banks; ++i) { + all_cores_set.insert(CoreRange(adj_core_logical_realloc[i])); + } + all_cores = CoreRangeSet(all_cores_set); + all_cores_ordered = adj_core_logical_realloc; +} + +void get_l1_writer_core_coords_grayskull( + tt_metal::Device* device, std::vector& all_dram_reader_cores, CoreRangeSet& all_cores, std::vector& all_cores_ordered) { + + const metal_SocDescriptor& soc_d = tt::Cluster::instance().get_soc_desc(device->id()); + uint32_t full_grid_size_y = soc_d.grid_size.y; + + // get all the logical coord + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + // get worker logical coords + std::vector all_worker_cores_logical; + for (int i = 0; i < num_cores_x; ++i) { + for (int j = 0; j < num_cores_y; ++j) { + all_worker_cores_logical.push_back(CoreCoord(i, j)); + } + } + + // get y coords of the workers + std::vector all_worker_cores_y_physical; + for (int i = 0; i < num_cores_y; ++i) { + auto core_phy = device->worker_core_from_logical_core(CoreCoord(0, i)); + all_worker_cores_y_physical.push_back(core_phy.y); + } + + // get the harvested rows, we treat dram and eth cores as harvested as well + std::vector harvested_rows; + for (int i = 0; i < full_grid_size_y; ++i) { + auto y = i; + + if (std::find(all_worker_cores_y_physical.begin(), all_worker_cores_y_physical.end(), y) == + all_worker_cores_y_physical.end()) { + harvested_rows.push_back(y); + } + } + + // get the ajacent cores of DRAM readers, for grayskull the l1 writers are below DRAM readers + std::vector adj_core_physical; + for (int i = 0; i < all_dram_reader_cores.size(); ++i) { + auto dram_reader_core = all_dram_reader_cores[i]; + auto dram_reader_core_phy = device->worker_core_from_logical_core(dram_reader_core); + uint32_t adj_core_x = dram_reader_core_phy.x; + uint32_t adj_core_y = dram_reader_core_phy.y + 1; + adj_core_physical.push_back(CoreCoord(adj_core_x, adj_core_y)); + } + + // move worker if they are in the harvested rows + for (auto& coord : adj_core_physical) { + auto y = coord.y; + + // if row is harvested, move core down by 1 + while (std::find(harvested_rows.begin(), harvested_rows.end(), y) != harvested_rows.end() and y < (full_grid_size_y - 1)) { + y += 1; + } + + coord.y = y; + } + + // find the logical coord from physical coord + std::vector adj_core_logical_realloc; + for (int i = 0; i < adj_core_physical.size(); ++i) { + for (int j = 0; j < all_worker_cores_logical.size(); ++j) { + auto core = device->worker_core_from_logical_core(all_worker_cores_logical[j]); + if (adj_core_physical[i] == core) { + adj_core_logical_realloc.push_back(all_worker_cores_logical[j]); + } + } + } + + // create sets + std::set all_cores_set; + for (int i = 0; i < adj_core_logical_realloc.size(); ++i) { + all_cores_set.insert(CoreRange(adj_core_logical_realloc[i])); + } + all_cores = CoreRangeSet(all_cores_set); + all_cores_ordered = adj_core_logical_realloc; +} + +void get_dram_reader_core_coords_wormhole_b0( + tt_metal::Device* device, CoreRangeSet& all_cores, std::vector& all_cores_ordered) { + + // get all the logical coord + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + // get dram banks and coords + uint32_t num_banks = device->num_dram_channels(); + uint32_t max_bank_id = num_banks - 1; + std::vector dram_coord_phy; dram_coord_phy.reserve(num_banks); + for (int i = 0; i < num_banks; ++i) { + dram_coord_phy.push_back(device->dram_core_from_dram_channel(i)); + } + + // get worker logical coords + std::vector all_worker_cores_logical; all_worker_cores_logical.reserve(num_cores_x * num_cores_y); + for (int i = 0; i < num_cores_x; ++i) { + for (int j = 0; j < num_cores_y; ++j) { + all_worker_cores_logical.push_back(CoreCoord(i, j)); + } + } + + // get the ajacent cores of DRAM banks + std::vector adj_core_physical; adj_core_physical.reserve(num_banks); + for (int i = 0; i < num_banks; ++i) { + auto dram_core = dram_coord_phy[i]; + uint32_t adj_core_x = dram_core.x + 1; + uint32_t adj_core_y = dram_core.y; + adj_core_physical.push_back(CoreCoord(adj_core_x, adj_core_y)); + } + + // find the logical coord from physical coord + std::vector adj_core_logical; adj_core_logical.reserve(num_banks); + for (int i = 0; i < adj_core_physical.size(); ++i) { + for (int j = 0; j < all_worker_cores_logical.size(); ++j) { + auto core = device->worker_core_from_logical_core(all_worker_cores_logical[j]); + if (adj_core_physical[i] == core) { + adj_core_logical.push_back(all_worker_cores_logical[j]); + } + } + } + + // create sets + std::set all_cores_set; + for (int i = 0; i < num_banks; ++i) { + all_cores_set.insert(CoreRange(adj_core_logical[i])); + } + all_cores = CoreRangeSet(all_cores_set); + all_cores_ordered = adj_core_logical; +} + + +void get_l1_writer_core_coords_wormhole_b0( + tt_metal::Device* device, std::vector& all_dram_reader_cores, CoreRangeSet& all_cores, std::vector& all_cores_ordered) { + + // get all the logical coord + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + + // get worker logical coords + std::vector all_worker_cores_logical; + for (int i = 0; i < num_cores_x; ++i) { + for (int j = 0; j < num_cores_y; ++j) { + all_worker_cores_logical.push_back(CoreCoord(i, j)); + } + } + + // get the ajacent cores of DRAM readers, for wormhole the l1 writers are on the left or right DRAM readers + std::vector adj_core_physical; + for (int i = 0; i < all_dram_reader_cores.size(); ++i) { + auto dram_reader_core = all_dram_reader_cores[i]; + auto dram_reader_core_phy = device->worker_core_from_logical_core(dram_reader_core); + uint32_t adj_core_x = dram_reader_core_phy.x + 1; + uint32_t adj_core_y = dram_reader_core_phy.y; + adj_core_physical.push_back(CoreCoord(adj_core_x, adj_core_y)); + } + + // find the logical coord from physical coord + std::vector adj_core_logical_realloc; + for (int i = 0; i < adj_core_physical.size(); ++i) { + for (int j = 0; j < all_worker_cores_logical.size(); ++j) { + auto core = device->worker_core_from_logical_core(all_worker_cores_logical[j]); + if (adj_core_physical[i] == core) { + adj_core_logical_realloc.push_back(all_worker_cores_logical[j]); + } + } + } + + // create sets + std::set all_cores_set; + for (int i = 0; i < adj_core_logical_realloc.size(); ++i) { + all_cores_set.insert(CoreRange(adj_core_logical_realloc[i])); + } + all_cores = CoreRangeSet(all_cores_set); + all_cores_ordered = adj_core_logical_realloc; +} + +int main(int argc, char **argv) { + if (getenv("TT_METAL_SLOW_DISPATCH_MODE") != nullptr) { + log_error("Test not supported w/ slow dispatch, exiting"); + } + + bool pass = true; + bool use_device_profiler = false; + bool bypass_check = false; + uint32_t df = 0; + std::vector dram_bandwidth; + uint32_t num_tests = 1; + uint32_t num_blocks = 8; + uint64_t k = 8192, n = 128; + uint32_t dram_bandwidth_spec = 0; + uint32_t num_banks = 1; + uint32_t bank_start_id = 1; + + log_info("start DRAM benchmark"); + + try { + //////////////////////////////////////////////////////////////////////////// + // Initial Runtime Args Parse + //////////////////////////////////////////////////////////////////////////// + std::vector input_args(argv, argv + argc); + try { + std::tie(k, input_args) = + test_args::get_command_option_uint64_and_remaining_args(input_args, "--k", 8192); + + std::tie(n, input_args) = + test_args::get_command_option_uint64_and_remaining_args(input_args, "--n", 12*128); + + std::tie(num_blocks, input_args) = + test_args::get_command_option_uint64_and_remaining_args(input_args, "--num-blocks", 8); + + std::tie(num_tests, input_args) = + test_args::get_command_option_uint32_and_remaining_args(input_args, "--num-tests", 1); + + std::tie(use_device_profiler, input_args) = + test_args::has_command_option_and_remaining_args(input_args, "--use-device-profiler"); + + std::tie(bypass_check, input_args) = + test_args::has_command_option_and_remaining_args(input_args, "--bypass-check"); + + std::tie(df, input_args) = + test_args::get_command_option_uint32_and_remaining_args(input_args, "--data-type", 0); + + std::tie(num_banks, input_args) = + test_args::get_command_option_uint32_and_remaining_args(input_args, "--num-banks", 12); + + std::tie(bank_start_id, input_args) = + test_args::get_command_option_uint32_and_remaining_args(input_args, "--bank-start-id", 0); + + test_args::validate_remaining_args(input_args); + } catch (const std::exception &e) { + log_error(tt::LogTest, "Command line arguments found exception", e.what()); + TT_ASSERT(false); + } + + if (use_device_profiler) { +#if !defined(TRACY_ENABLE) + log_error( + LogTest, + "Metal library and test code should be build with " + "profiler option using ./scripts/build_scripts/build_with_profiler_opt.sh"); +#endif + auto device_profiler = getenv("TT_METAL_DEVICE_PROFILER"); + TT_FATAL( + device_profiler, + "Before running the program, do one of the following in a shell: " + "either export the environment variable by executing export TT_METAL_DEVICE_PROFILER=1, " + "or run the program with TT_METAL_DEVICE_PROFILER=1 prefixed to the command"); + } + + //////////////////////////////////////////////////////////////////////////// + // Parameters Setup + //////////////////////////////////////////////////////////////////////////// + uint32_t input_size = 0; + tt::DataFormat tile_format = tt::DataFormat::Bfp8_b; + if (df == 0) { + input_size = k * n * 1088 / 1024; + tile_format = tt::DataFormat::Bfp8_b; + } else if (df == 1) { + input_size = k * n * 2; + tile_format = tt::DataFormat::Float16_b; + } else { + TT_THROW("Input data format {} is invalid. Please change.", df); + } + uint32_t kt = k / 32; + uint32_t nt = n / 32; + uint32_t block_h = kt / num_blocks; + uint32_t block_w = nt / num_banks; + uint32_t num_datum_per_slice = 32 * 32; + + uint32_t single_tile_size = tt_metal::detail::TileSize(tile_format); + if (input_size % single_tile_size != 0) { + auto align_to_single_tile = [=](uint64_t value) -> uint64_t { + return ((value + (single_tile_size - 1)) / single_tile_size) * single_tile_size; + }; + + auto input_size_aligned = align_to_single_tile(input_size); + log_info(LogTest, "input size {} is aligned to {} bytes", input_size, input_size_aligned); + input_size = input_size_aligned; + } + //////////////////////////////////////////////////////////////////////////// + // Device Setup + //////////////////////////////////////////////////////////////////////////// + int device_id = 0; + tt_metal::Device *device = tt_metal::CreateDevice(device_id); + dram_bandwidth_spec = get_dram_bandwidth(device->arch()); + + TT_ASSERT(device->arch() == ARCH::WORMHOLE_B0, "device must be wh_b0"); + + int clock_freq_mhz = get_tt_npu_clock(device); + + uint32_t num_tiles = static_cast((input_size + single_tile_size - 1) / single_tile_size); + uint32_t num_cores = num_banks; // number of DRAM banks + + CoreRangeSet all_dram_reader_cores = CoreRangeSet{{}}; + std::vector all_dram_reader_cores_ordered; + CoreRangeSet all_l1_receiver_cores = CoreRangeSet{{}}; + std::vector all_l1_writer_cores_ordered; + if (device->arch() == tt::ARCH::BLACKHOLE) { + get_dram_reader_core_coords_blackhole(device, all_dram_reader_cores, all_dram_reader_cores_ordered); + get_l1_writer_core_coords_blackhole(device, all_dram_reader_cores_ordered, all_l1_receiver_cores, all_l1_writer_cores_ordered); + } else if (device->arch() == tt::ARCH::WORMHOLE_B0) { + get_dram_reader_core_coords_wormhole_b0(device, all_dram_reader_cores, all_dram_reader_cores_ordered); + get_l1_writer_core_coords_wormhole_b0(device, all_dram_reader_cores_ordered, all_l1_receiver_cores, all_l1_writer_cores_ordered); + } else { + get_dram_reader_core_coords_grayskull(device, all_dram_reader_cores, all_dram_reader_cores_ordered); + get_l1_writer_core_coords_grayskull(device, all_dram_reader_cores_ordered, all_l1_receiver_cores, all_l1_writer_cores_ordered); + } + + uint32_t num_tiles_per_core = num_tiles / num_cores; + uint32_t num_tiles_cb = num_tiles_per_core / num_blocks; + + log_info("all_dram_reader_cores"); + for (auto core: all_dram_reader_cores_ordered) { + auto phys_core = device->worker_core_from_logical_core(core); + log_info("logical core: {}, physical core: {}", core, phys_core); + } + log_info("all_l1_writer_cores"); + for (auto core: all_l1_writer_cores_ordered) { + auto phys_core = device->worker_core_from_logical_core(core); + log_info("logical core: {}, physical core: {}", core, phys_core); + } + + log_info( + LogTest, + "Measuring DRAM bandwidth for input_size = {} bytes ({:.3f} MB, " + "{} tiles), using {} cores", + input_size, + static_cast(input_size) / 1024 / 1024, + num_tiles, + num_cores); + + //////////////////////////////////////////////////////////////////////////// + // Input Setup + //////////////////////////////////////////////////////////////////////////// + std::vector input_vec; + if (tile_format == tt::DataFormat::Bfp8_b) { + // input_vec = create_constant_vector_of_bfp8( + // input_size, 100, true); + input_vec = create_random_vector_of_bfp8( + input_size, true, 100, 1234); + } else { + // input_vec = create_constant_vector_of_bfloat16( + // input_size * total_banks / num_banks, 100); + input_vec = create_random_vector_of_bfloat16( + input_size, 100, 1234); + } + + tt_metal::Buffer input_buffer( + device, input_vec.size() * sizeof(uint32_t), single_tile_size, tt_metal::BufferType::DRAM); + + //////////////////////////////////////////////////////////////////////////// + // Application Setup + //////////////////////////////////////////////////////////////////////////// + auto [program, kernel, output_cb_addr] = create_program(device, all_dram_reader_cores, all_l1_receiver_cores, single_tile_size, tile_format, num_tiles_cb, num_tiles_per_core, k, n, num_blocks, num_banks, all_dram_reader_cores_ordered, all_l1_writer_cores_ordered, bank_start_id, input_buffer.address()); + + //////////////////////////////////////////////////////////////////////////// + // Copy Input To DRAM or L1 + //////////////////////////////////////////////////////////////////////////// + tt_metal::detail::WriteToBuffer(input_buffer, input_vec); + + //////////////////////////////////////////////////////////////////////////// + // Execution Application + //////////////////////////////////////////////////////////////////////////// + tt_metal::detail::CompileProgram(device, program); + + log_info(LogTest, "Num tests {}", num_tests); + for (uint32_t i = 0; i < num_tests; ++i) { + auto t_begin = std::chrono::steady_clock::now(); + EnqueueProgram(device->command_queue(), program, false); + Finish(device->command_queue()); + tt_metal::DumpDeviceProfileResults(device, program); + auto t_end = std::chrono::steady_clock::now(); + auto elapsed_us = duration_cast(t_end - t_begin).count(); + dram_bandwidth.push_back((input_size / 1024.0 / 1024.0 / 1024.0) / (elapsed_us / 1000.0 / 1000.0)); + log_info( + LogTest, + "Time elapsed for DRAM accesses: {:.3f}ms ({:.3f}GB/s)", + elapsed_us / 1000.0, + dram_bandwidth[i]); + } + + //////////////////////////////////////////////////////////////////////////// + // Validation & Teardown + //////////////////////////////////////////////////////////////////////////// + + pass = validation( + device, + input_buffer, + input_vec, + num_cores, + all_l1_writer_cores_ordered, + num_tiles_per_core, + output_cb_addr, + single_tile_size, + num_tiles_cb, + df, + num_banks, + num_blocks, + block_h, + block_w, + num_datum_per_slice); + + pass &= tt_metal::CloseDevice(device); + } catch (const std::exception &e) { + pass = false; + // Capture the exception error message + log_error(LogTest, "{}", e.what()); + // Capture system call errors that may have returned from driver/kernel + log_error(LogTest, "System error message: {}", std::strerror(errno)); + } + + // Determine if it passes performance goal + auto avg_dram_bandwidth = calculate_average(dram_bandwidth); + if (pass && bypass_check == false) { + // goal is 90% of peak DRAM bandwidth performance + double target_bandwidth = static_cast(dram_bandwidth_spec) * 0.9; + if (avg_dram_bandwidth < target_bandwidth) { + pass = false; + log_error( + LogTest, + "The DRAM bandwidth does not meet the criteria. " + "Current: {:.3f}GB/s, goal: {:.3f}GB/s", + avg_dram_bandwidth, + target_bandwidth); + } + } + + if (pass) { + log_info(LogTest, "Test Passed"); + } else { + log_error(LogTest, "Test Failed"); + } + + return 0; +} diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/CMakeLists.txt b/tests/tt_metal/tt_metal/perf_microbenchmark/CMakeLists.txt index c855cac5c49..94875c6114f 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/CMakeLists.txt @@ -37,6 +37,7 @@ set(PERF_MICROBENCH_TESTS_SRCS 6_dram_offchip/test_dram_offchip.cpp 7_kernel_launch/test_kernel_launch.cpp 8_dram_adjacent_core_read/test_dram_read.cpp + 9_dram_adjacent_read_remote_l1_write/test_dram_read_l1_write.cpp ) foreach (TEST_SRC ${PERF_MICROBENCH_TESTS_SRCS}) diff --git a/tt_metal/hw/firmware/src/brisc.cc b/tt_metal/hw/firmware/src/brisc.cc index 03e0aa8c9ac..a3b22ccfdf1 100644 --- a/tt_metal/hw/firmware/src/brisc.cc +++ b/tt_metal/hw/firmware/src/brisc.cc @@ -341,7 +341,6 @@ int main() { noc_index = 0; risc_init(); device_setup(); - noc_init(MEM_NOC_ATOMIC_RET_VAL_ADDR); // Set ncrisc's resume address to 0 so we know when ncrisc has overwritten it mailboxes->ncrisc_halt.resume_addr = 0; @@ -356,6 +355,8 @@ int main() { mailboxes->go_message.signal = RUN_MSG_DONE; + uint8_t noc_mode; + uint8_t prev_noc_mode = DM_INVALID_NOC; while (1) { init_sync_registers(); reset_ncrisc_with_iram(); @@ -410,6 +411,17 @@ int main() { run_triscs(enables); noc_index = launch_msg_address->kernel_config.brisc_noc_id; + noc_mode = launch_msg_address->kernel_config.brisc_noc_mode; + + // re-initialize the NoCs + if (prev_noc_mode != noc_mode) { + if (noc_mode == DM_DEDICATED_NOC) { + noc_init(MEM_NOC_ATOMIC_RET_VAL_ADDR); + } else { + dynamic_noc_init(); + } + } + prev_noc_mode = noc_mode; uint32_t kernel_config_base = firmware_config_init(mailboxes, ProgrammableCoreType::TENSIX, DISPATCH_CLASS_TENSIX_DM0); uint32_t tt_l1_ptr *cb_l1_base = (uint32_t tt_l1_ptr *)(kernel_config_base + @@ -425,7 +437,9 @@ int main() { RECORD_STACK_USAGE(); } else { // This was not initialized in kernel_init - noc_local_state_init(noc_index); + if (noc_mode == DM_DEDICATED_NOC) { + noc_local_state_init(noc_index); + } } WAYPOINT("D"); diff --git a/tt_metal/hw/firmware/src/brisck.cc b/tt_metal/hw/firmware/src/brisck.cc index fa9a3ca51ab..7b01d4ba354 100644 --- a/tt_metal/hw/firmware/src/brisck.cc +++ b/tt_metal/hw/firmware/src/brisck.cc @@ -18,7 +18,6 @@ #include "tools/profiler/kernel_profiler.hpp" #include -uint8_t noc_index = NOC_INDEX; extern uint32_t __kernel_init_local_l1_base[]; void kernel_launch() { @@ -31,7 +30,12 @@ void kernel_launch() { #else firmware_kernel_common_init((void tt_l1_ptr *)(__kernel_init_local_l1_base)); - noc_local_state_init(noc_index); + if constexpr (NOC_MODE == DM_DEDICATED_NOC) { + noc_local_state_init(NOC_INDEX); + } else { + noc_local_state_init(NOC_0); + noc_local_state_init(NOC_1); + } { DeviceZoneScopedMainChildN("BRISC-KERNEL"); diff --git a/tt_metal/hw/firmware/src/erisck.cc b/tt_metal/hw/firmware/src/erisck.cc index d3f63dab862..d6e916728f5 100644 --- a/tt_metal/hw/firmware/src/erisck.cc +++ b/tt_metal/hw/firmware/src/erisck.cc @@ -15,13 +15,11 @@ #include "noc_nonblocking_api.h" #include "stream_io_map.h" #include "tdma_xmov.h" -#include "dataflow_api.h" #include "debug/dprint.h" #include "tools/profiler/kernel_profiler.hpp" #include -uint8_t noc_index = NOC_INDEX; CBInterface cb_interface[NUM_CIRCULAR_BUFFERS]; diff --git a/tt_metal/hw/firmware/src/idle_erisck.cc b/tt_metal/hw/firmware/src/idle_erisck.cc index b05833770b5..99f000c3de6 100644 --- a/tt_metal/hw/firmware/src/idle_erisck.cc +++ b/tt_metal/hw/firmware/src/idle_erisck.cc @@ -21,14 +21,13 @@ #include -uint8_t noc_index = NOC_INDEX; extern uint32_t __kernel_init_local_l1_base[]; void kernel_launch() { DeviceZoneScopedMainChildN("ERISC-KERNEL"); firmware_kernel_common_init((void tt_l1_ptr *)__kernel_init_local_l1_base); - noc_local_state_init(noc_index); + noc_local_state_init(NOC_INDEX); kernel_main(); } diff --git a/tt_metal/hw/firmware/src/ncrisck.cc b/tt_metal/hw/firmware/src/ncrisck.cc index c86776754f9..f59e2ce313e 100644 --- a/tt_metal/hw/firmware/src/ncrisck.cc +++ b/tt_metal/hw/firmware/src/ncrisck.cc @@ -19,7 +19,6 @@ #include "kernel_includes.hpp" -uint8_t noc_index = NOC_INDEX; uint32_t noc_reads_num_issued[NUM_NOCS]; uint32_t noc_nonposted_writes_num_issued[NUM_NOCS]; @@ -44,7 +43,12 @@ void kernel_launch() { firmware_kernel_common_init((void tt_l1_ptr *)(MEM_NCRISC_INIT_IRAM_L1_BASE + (uint32_t)__kernel_init_local_l1_base - MEM_NCRISC_IRAM_BASE)); #endif - noc_local_state_init(noc_index); + if constexpr (NOC_MODE == DM_DEDICATED_NOC) { + noc_local_state_init(NOC_INDEX); + } else { + noc_local_state_init(NOC_0); + noc_local_state_init(NOC_1); + } kernel_main(); #endif diff --git a/tt_metal/hw/inc/blackhole/noc_nonblocking_api.h b/tt_metal/hw/inc/blackhole/noc_nonblocking_api.h index ccb4c7fa167..8ec6b3921bd 100644 --- a/tt_metal/hw/inc/blackhole/noc_nonblocking_api.h +++ b/tt_metal/hw/inc/blackhole/noc_nonblocking_api.h @@ -7,13 +7,29 @@ #include #include "noc_parameters.h" +#include "dev_msgs.h" //// /*TODO: RT review this file, currently using wormhole b0 copy, check if any changes needed for BH*/ -const uint32_t NCRISC_WR_CMD_BUF = 0; // for large writes -const uint32_t NCRISC_RD_CMD_BUF = 1; // for all reads -const uint32_t NCRISC_WR_REG_CMD_BUF = 2; // for small writes (e.g., registers, semaphores) -const uint32_t NCRISC_AT_CMD_BUF = 3; // for atomics +constexpr uint32_t DYNAMIC_NOC_NCRISC_WR_CMD_BUF = 2; // all writes share cmd buf +constexpr uint32_t DYNAMIC_NOC_NCRISC_WR_REG_CMD_BUF = 2; +constexpr uint32_t DYNAMIC_NOC_NCRISC_AT_CMD_BUF = 2; +constexpr uint32_t DYNAMIC_NOC_NCRISC_RD_CMD_BUF = 3; + +constexpr uint32_t DYNAMIC_NOC_BRISC_WR_CMD_BUF = 0; // all writes share cmd buf +constexpr uint32_t DYNAMIC_NOC_BRISC_WR_REG_CMD_BUF = 0; +constexpr uint32_t DYNAMIC_NOC_BRISC_AT_CMD_BUF = 0; +constexpr uint32_t DYNAMIC_NOC_BRISC_RD_CMD_BUF = 1; + +constexpr uint32_t NCRISC_WR_CMD_BUF = 0; // for large writes +constexpr uint32_t NCRISC_RD_CMD_BUF = 1; // for all reads +constexpr uint32_t NCRISC_WR_REG_CMD_BUF = 2; // for small writes (e.g., registers, semaphores) +constexpr uint32_t NCRISC_AT_CMD_BUF = 3; // for atomics + +constexpr uint32_t BRISC_WR_CMD_BUF = 0; // for large writes +constexpr uint32_t BRISC_RD_CMD_BUF = 1; // for all reads +constexpr uint32_t BRISC_WR_REG_CMD_BUF = 2; // for small writes (e.g., registers, semaphores) +constexpr uint32_t BRISC_AT_CMD_BUF = 3; // for atomics // BH has 64 bit address space but pipegen was not updated to support this so WH scheme of encoding addresses is used (36 bits of address followed by coordinates) // This means that lo and mid registers need to have the address portion while the coordinates go into hi register @@ -191,6 +207,32 @@ inline __attribute__((always_inline)) void noc_init(uint32_t atomic_ret_val) { } } +inline __attribute__((always_inline)) void dynamic_noc_init() { +#pragma GCC unroll 0 + for (int noc = 0; noc < NUM_NOCS; noc++) { + uint32_t noc_id_reg = NOC_CMD_BUF_READ_REG(noc, 0, NOC_NODE_ID); + uint32_t my_x = noc_id_reg & NOC_NODE_ID_MASK; + uint32_t my_y = (noc_id_reg >> NOC_ADDR_NODE_ID_BITS) & NOC_NODE_ID_MASK; + uint64_t xy_local_addr = NOC_XY_ADDR(my_x, my_y, 0); + + uint32_t noc_rd_cmd_field = NOC_CMD_CPY | NOC_CMD_RD | NOC_CMD_RESP_MARKED | NOC_CMD_VC_STATIC | NOC_CMD_STATIC_VC(1); + + // program brisc cmd_buf 0 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_BRISC_RD_CMD_BUF, NOC_CTRL, noc_rd_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_BRISC_RD_CMD_BUF, NOC_RET_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + + // program brisc cmd_buf 1 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_BRISC_WR_CMD_BUF, NOC_TARG_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + + // program ncrisc cmd_buf 2 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_NCRISC_RD_CMD_BUF, NOC_CTRL, noc_rd_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_NCRISC_RD_CMD_BUF, NOC_RET_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + + // program ncrisc cmd_buf 3 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_NCRISC_WR_CMD_BUF, NOC_TARG_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + } +} + // set noc local memory state for a single kernel from the global state inline __attribute__((always_inline)) void noc_local_state_init(int noc) { noc_reads_num_issued[noc] = NOC_STATUS_READ_REG(noc, NIU_MST_RD_RESP_RECEIVED); @@ -336,6 +378,7 @@ inline __attribute__((always_inline)) void noc_fast_write_dw_inline( } } +template inline __attribute__((always_inline)) void noc_fast_atomic_increment( uint32_t noc, uint32_t cmd_buf, @@ -344,8 +387,17 @@ inline __attribute__((always_inline)) void noc_fast_atomic_increment( uint32_t incr, uint32_t wrap, bool linked, - bool posted = false) { + bool posted = false, + uint32_t atomic_ret_val = 0) { while (!noc_cmd_buf_ready(noc, cmd_buf)); + if constexpr (noc_mode == DM_DYNAMIC_NOC) { + uint32_t noc_id_reg = NOC_CMD_BUF_READ_REG(noc, 0, NOC_NODE_ID); + uint32_t my_x = noc_id_reg & NOC_NODE_ID_MASK; + uint32_t my_y = (noc_id_reg >> NOC_ADDR_NODE_ID_BITS) & NOC_NODE_ID_MASK; + uint64_t atomic_ret_addr = NOC_XY_ADDR(my_x, my_y, atomic_ret_val); + NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_RET_ADDR_LO, (uint32_t)(atomic_ret_addr & 0xFFFFFFFF)); + NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_RET_ADDR_COORDINATE, (uint32_t)(atomic_ret_addr >> NOC_ADDR_COORD_SHIFT)); + } NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_TARG_ADDR_LO, (uint32_t)(addr & 0xFFFFFFFF)); NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_TARG_ADDR_MID, (uint32_t)(addr >> 32) & 0x1000000F); NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_TARG_ADDR_COORDINATE, (uint32_t)(addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); diff --git a/tt_metal/hw/inc/dataflow_api.h b/tt_metal/hw/inc/dataflow_api.h index a0d6da80996..e38abdbd062 100644 --- a/tt_metal/hw/inc/dataflow_api.h +++ b/tt_metal/hw/inc/dataflow_api.h @@ -28,11 +28,41 @@ #include "debug/assert.h" #include "dev_msgs.h" +#if defined(KERNEL_BUILD) +constexpr uint8_t noc_index = NOC_INDEX; +constexpr uint8_t noc_mode = NOC_MODE; +#else extern uint8_t noc_index; +constexpr uint8_t noc_mode = DM_DEDICATED_NOC; +#endif extern uint32_t tt_l1_ptr *rta_l1_base; extern uint32_t tt_l1_ptr *crta_l1_base; extern uint32_t tt_l1_ptr *sem_l1_base[]; +#if defined(KERNEL_BUILD) +#if defined(COMPILE_FOR_BRISC) +constexpr uint32_t read_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? BRISC_RD_CMD_BUF : DYNAMIC_NOC_BRISC_RD_CMD_BUF; +constexpr uint32_t write_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? BRISC_WR_CMD_BUF : DYNAMIC_NOC_BRISC_WR_CMD_BUF; +constexpr uint32_t write_reg_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? BRISC_WR_REG_CMD_BUF : DYNAMIC_NOC_BRISC_WR_REG_CMD_BUF; +constexpr uint32_t write_at_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? BRISC_AT_CMD_BUF : DYNAMIC_NOC_BRISC_AT_CMD_BUF; +#elif defined(COMPILE_FOR_NCRISC) +constexpr uint32_t read_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_RD_CMD_BUF : DYNAMIC_NOC_NCRISC_RD_CMD_BUF; +constexpr uint32_t write_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_WR_CMD_BUF : DYNAMIC_NOC_NCRISC_WR_CMD_BUF; +constexpr uint32_t write_reg_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_WR_REG_CMD_BUF : DYNAMIC_NOC_NCRISC_WR_REG_CMD_BUF; +constexpr uint32_t write_at_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_AT_CMD_BUF : DYNAMIC_NOC_NCRISC_AT_CMD_BUF; +#else // use the default cmf buffers for compute/eth +constexpr uint32_t read_cmd_buf __attribute__((used)) = NCRISC_RD_CMD_BUF; +constexpr uint32_t write_cmd_buf __attribute__((used)) = NCRISC_WR_CMD_BUF; +constexpr uint32_t write_reg_cmd_buf __attribute__((used)) = NCRISC_WR_REG_CMD_BUF; +constexpr uint32_t write_at_cmd_buf __attribute__((used)) = NCRISC_AT_CMD_BUF; +#endif +#else // FW build +constexpr uint32_t read_cmd_buf __attribute__((used)) = NCRISC_RD_CMD_BUF; +constexpr uint32_t write_cmd_buf __attribute__((used)) = NCRISC_WR_CMD_BUF; +constexpr uint32_t write_reg_cmd_buf __attribute__((used)) = NCRISC_WR_REG_CMD_BUF; +constexpr uint32_t write_at_cmd_buf __attribute__((used)) = NCRISC_AT_CMD_BUF; +#endif + /** @file */ /** @@ -88,11 +118,11 @@ uint32_t get_bank_index(uint32_t id, uint32_t bank_offset_index) { template FORCE_INLINE -uint32_t get_noc_xy(uint32_t bank_index) { +uint32_t get_noc_xy(uint32_t bank_index, uint8_t noc = noc_index) { if constexpr (DRAM) { // DRAM - return dram_bank_to_noc_xy[noc_index][bank_index]; + return dram_bank_to_noc_xy[noc][bank_index]; } else { // L1 - return l1_bank_to_noc_xy[noc_index][bank_index]; + return l1_bank_to_noc_xy[noc][bank_index]; } } @@ -490,22 +520,23 @@ std::uint64_t get_noc_multicast_addr( std::uint32_t noc_y_start, std::uint32_t noc_x_end, std::uint32_t noc_y_end, - std::uint32_t addr) { + std::uint32_t addr, + uint8_t noc = noc_index) { /* Get an encoding which contains tensix core and address you want to read from/write to via the noc */ - return NOC_MULTICAST_ADDR(NOC_X(noc_x_start), NOC_Y(noc_y_start), NOC_X(noc_x_end), NOC_Y(noc_y_end), addr); + return NOC_MULTICAST_ADDR(DYNAMIC_NOC_X(noc, noc_x_start), DYNAMIC_NOC_Y(noc, noc_y_start), DYNAMIC_NOC_X(noc, noc_x_end), DYNAMIC_NOC_Y(noc, noc_y_end), addr); } FORCE_INLINE -std::uint64_t get_noc_addr(std::uint32_t noc_x, std::uint32_t noc_y, std::uint32_t addr) { +std::uint64_t get_noc_addr(std::uint32_t noc_x, std::uint32_t noc_y, std::uint32_t addr, uint8_t noc = noc_index) { /* Get an encoding which contains tensix core and address you want to write to via the noc multicast */ - return NOC_XY_ADDR(NOC_X(noc_x), NOC_Y(noc_y), addr); + return NOC_XY_ADDR(DYNAMIC_NOC_X(noc, noc_x), DYNAMIC_NOC_Y(noc, noc_y), addr); } /* @@ -523,38 +554,38 @@ std::uint64_t get_noc_addr_helper(std::uint32_t noc_xy, std::uint32_t addr) { -uint64_t get_dram_noc_addr(const uint32_t id, const uint32_t page_size, const uint32_t bank_base_address, const uint32_t offset = 0) { +uint64_t get_dram_noc_addr(const uint32_t id, const uint32_t page_size, const uint32_t bank_base_address, const uint32_t offset = 0, uint8_t noc = noc_index) { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t addr = (bank_offset_index * align(page_size, ALLOCATOR_ALIGNMENT)) + bank_base_address + offset + bank_to_dram_offset[bank_index]; - uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); uint64_t noc_addr = get_noc_addr_helper(noc_xy, addr); return noc_addr; } -uint64_t get_l1_noc_addr(const uint32_t id, const uint32_t page_size, const uint32_t bank_base_address, const uint32_t offset = 0) { +uint64_t get_l1_noc_addr(const uint32_t id, const uint32_t page_size, const uint32_t bank_base_address, const uint32_t offset = 0, uint8_t noc = noc_index) { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t addr = (bank_offset_index * align(page_size, ALLOCATOR_ALIGNMENT)) + bank_base_address + offset + bank_to_dram_offset[bank_index]; - uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); uint64_t noc_addr = get_noc_addr_helper(noc_xy, addr); return noc_addr; } -uint64_t get_system_memory_noc_addr(const uint32_t id, const uint32_t page_size, const uint32_t base_addr, const uint32_t offset = 0) { - uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(NOC_X(PCIE_NOC_X), NOC_Y(PCIE_NOC_Y), noc_index)); +uint64_t get_system_memory_noc_addr(const uint32_t id, const uint32_t page_size, const uint32_t base_addr, const uint32_t offset = 0, uint8_t noc = noc_index) { + uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(DYNAMIC_NOC_X(noc, PCIE_NOC_X), DYNAMIC_NOC_Y(noc, PCIE_NOC_Y), noc)); uint32_t addr = base_addr + page_size * id + offset; uint64_t noc_addr = pcie_core_noc_encoding | addr; return noc_addr; } FORCE_INLINE -std::uint64_t get_noc_addr(std::uint32_t addr) { +std::uint64_t get_noc_addr(std::uint32_t addr, uint8_t noc = noc_index) { /* Get an encoding which contains the address in L1 on the current core that you want to read from/write to via the noc */ - return NOC_XY_ADDR(my_x[noc_index], my_y[noc_index], addr); + return NOC_XY_ADDR(my_x[noc], my_y[noc], addr); } /** @@ -574,43 +605,43 @@ std::uint64_t get_noc_addr(std::uint32_t addr) { * | size | Size of data transfer in bytes | uint32_t | 0..1MB | Yes | */ inline -void noc_async_read(std::uint64_t src_noc_addr, std::uint32_t dst_local_l1_addr, std::uint32_t size) { +void noc_async_read(std::uint64_t src_noc_addr, std::uint32_t dst_local_l1_addr, std::uint32_t size, uint8_t noc = noc_index) { /* Read requests - use static VC Read responses - assigned VCs dynamically */ WAYPOINT("NARW"); - DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc_index, src_noc_addr, dst_local_l1_addr, size); - ncrisc_noc_fast_read_any_len(noc_index, NCRISC_RD_CMD_BUF, src_noc_addr, dst_local_l1_addr, size); + DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc, src_noc_addr, dst_local_l1_addr, size); + ncrisc_noc_fast_read_any_len(noc, read_cmd_buf, src_noc_addr, dst_local_l1_addr, size); WAYPOINT("NARD"); } // TODO: write docs // this issues only a single packet with size <= NOC_MAX_BURST_SIZE (ie maximum packet size) FORCE_INLINE -void noc_async_read_one_packet(std::uint64_t src_noc_addr, std::uint32_t dst_local_l1_addr, std::uint32_t size) { +void noc_async_read_one_packet(std::uint64_t src_noc_addr, std::uint32_t dst_local_l1_addr, std::uint32_t size, uint8_t noc = noc_index) { /* Read requests - use static VC Read responses - assigned VCs dynamically */ WAYPOINT("RPW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("RPD"); WAYPOINT("NARW"); - DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc_index, src_noc_addr, dst_local_l1_addr, size); + DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc, src_noc_addr, dst_local_l1_addr, size); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_RET_ADDR_LO, dst_local_l1_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_LO, (uint32_t)src_noc_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_RET_ADDR_LO, dst_local_l1_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_LO, (uint32_t)src_noc_addr); #ifdef ARCH_BLACKHOLE // Handles reading from PCIe - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_MID, (uint32_t)(src_noc_addr >> 32) & 0x1000000F); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_MID, (uint32_t)(src_noc_addr >> 32) & 0x1000000F); #endif - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_COORDINATE, (uint32_t)(src_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_AT_LEN_BE, size); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); - noc_reads_num_issued[noc_index] += 1; + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_COORDINATE, (uint32_t)(src_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_AT_LEN_BE, size); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_reads_num_issued[noc] += 1; WAYPOINT("NARD"); } @@ -618,24 +649,24 @@ void noc_async_read_one_packet(std::uint64_t src_noc_addr, std::uint32_t dst_loc // TODO: write docs // this issues only a single packet with size <= NOC_MAX_BURST_SIZE (ie maximum packet size) FORCE_INLINE -void noc_async_read_one_packet_set_state(std::uint64_t src_noc_addr, std::uint32_t size) { +void noc_async_read_one_packet_set_state(std::uint64_t src_noc_addr, std::uint32_t size, uint8_t noc = noc_index) { /* Read requests - use static VC Read responses - assigned VCs dynamically */ WAYPOINT("RPW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("RPD"); WAYPOINT("NARW"); #ifdef ARCH_BLACKHOLE // Handles reading from PCIe - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_MID, (uint32_t)(src_noc_addr >> 32) & 0x1000000F); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_MID, (uint32_t)(src_noc_addr >> 32) & 0x1000000F); #endif - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_COORDINATE, (uint32_t)(src_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_AT_LEN_BE, size); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_COORDINATE, (uint32_t)(src_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_AT_LEN_BE, size); WAYPOINT("NARD"); } @@ -644,27 +675,27 @@ void noc_async_read_one_packet_set_state(std::uint64_t src_noc_addr, std::uint32 // this issues only a single packet with size <= NOC_MAX_BURST_SIZE (ie maximum packet size) template FORCE_INLINE -void noc_async_read_one_packet_with_state(std::uint32_t src_noc_addr, std::uint32_t dst_local_l1_addr) { +void noc_async_read_one_packet_with_state(std::uint32_t src_noc_addr, std::uint32_t dst_local_l1_addr, uint8_t noc = noc_index) { /* Read requests - use static VC Read responses - assigned VCs dynamically */ WAYPOINT("RPW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("RPD"); WAYPOINT("NARW"); // In order to sanitize, need to grab full noc addr + xfer size from state. - DEBUG_SANITIZE_NOC_READ_TRANSACTION_WITH_ADDR_AND_SIZE_STATE(noc_index, src_noc_addr, dst_local_l1_addr); + DEBUG_SANITIZE_NOC_READ_TRANSACTION_WITH_ADDR_AND_SIZE_STATE(noc, src_noc_addr, dst_local_l1_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_RET_ADDR_LO, dst_local_l1_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_LO, src_noc_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_RET_ADDR_LO, dst_local_l1_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_LO, src_noc_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); if constexpr (inc_num_issued) { - noc_reads_num_issued[noc_index] += 1; + noc_reads_num_issued[noc] += 1; } WAYPOINT("NARD"); @@ -672,7 +703,7 @@ void noc_async_read_one_packet_with_state(std::uint32_t src_noc_addr, std::uint3 // TODO: write docs FORCE_INLINE -void noc_async_read_set_state(std::uint64_t src_noc_addr) { +void noc_async_read_set_state(std::uint64_t src_noc_addr, uint8_t noc = noc_index) { /* Read requests - use static VC Read responses - assigned VCs dynamically @@ -680,14 +711,14 @@ void noc_async_read_set_state(std::uint64_t src_noc_addr) { WAYPOINT("NARW"); WAYPOINT("RPW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("RPD"); #ifdef ARCH_BLACKHOLE // Handles reading from PCIe - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_MID, (uint32_t)(src_noc_addr >> 32) & 0x1000000F); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_MID, (uint32_t)(src_noc_addr >> 32) & 0x1000000F); #endif - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_COORDINATE, (uint32_t)(src_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_COORDINATE, (uint32_t)(src_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); WAYPOINT("NARD"); } @@ -695,7 +726,7 @@ void noc_async_read_set_state(std::uint64_t src_noc_addr) { // TODO: write docs template FORCE_INLINE -void noc_async_read_with_state(std::uint32_t src_noc_addr, std::uint32_t dst_local_l1_addr, std::uint32_t size) { +void noc_async_read_with_state(std::uint32_t src_noc_addr, std::uint32_t dst_local_l1_addr, std::uint32_t size, uint8_t noc = noc_index) { /* Read requests - use static VC Read responses - assigned VCs dynamically @@ -703,54 +734,54 @@ void noc_async_read_with_state(std::uint32_t src_noc_addr, std::uint32_t dst_loc WAYPOINT("NARW"); // In order to sanitize, need to grab full noc addr + xfer size from state. - DEBUG_SANITIZE_NOC_READ_TRANSACTION_WITH_ADDR_STATE(noc_index, src_noc_addr, dst_local_l1_addr, size); + DEBUG_SANITIZE_NOC_READ_TRANSACTION_WITH_ADDR_STATE(noc, src_noc_addr, dst_local_l1_addr, size); while (size > NOC_MAX_BURST_SIZE) { WAYPOINT("RPW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("RPD"); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_RET_ADDR_LO, dst_local_l1_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_LO, src_noc_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_AT_LEN_BE, NOC_MAX_BURST_SIZE); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_RET_ADDR_LO, dst_local_l1_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_LO, src_noc_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_AT_LEN_BE, NOC_MAX_BURST_SIZE); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); size -= NOC_MAX_BURST_SIZE; src_noc_addr += NOC_MAX_BURST_SIZE; dst_local_l1_addr += NOC_MAX_BURST_SIZE; if constexpr (inc_num_issued) { - noc_reads_num_issued[noc_index] += 1; + noc_reads_num_issued[noc] += 1; } } // left-over packet WAYPOINT("RPW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("RPD"); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_RET_ADDR_LO, dst_local_l1_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_LO, src_noc_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_AT_LEN_BE, size); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_RET_ADDR_LO, dst_local_l1_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_LO, src_noc_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_AT_LEN_BE, size); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); if constexpr (inc_num_issued) { - noc_reads_num_issued[noc_index] += 1; + noc_reads_num_issued[noc] += 1; } WAYPOINT("NARD"); } FORCE_INLINE -void noc_async_read_inc_num_issued(std::uint32_t num_issued_reads_inc) { - noc_reads_num_issued[noc_index] += num_issued_reads_inc; +void noc_async_read_inc_num_issued(std::uint32_t num_issued_reads_inc, uint8_t noc = noc_index) { + noc_reads_num_issued[noc] += num_issued_reads_inc; } // TODO: write docs // this issues only a single packet with size <= NOC_MAX_BURST_SIZE (ie maximum packet size) FORCE_INLINE -void noc_async_write_one_packet(std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr, std::uint32_t size) { +void noc_async_write_one_packet(std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr, std::uint32_t size, uint8_t noc = noc_index) { WAYPOINT("NWPW"); - DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc_index, dst_noc_addr, src_local_l1_addr, size); - while (!noc_cmd_buf_ready(noc_index, NCRISC_WR_CMD_BUF)); + DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc, dst_noc_addr, src_local_l1_addr, size); + while (!noc_cmd_buf_ready(noc, write_cmd_buf)); WAYPOINT("NWPD"); uint32_t noc_cmd_field = NOC_CMD_CPY | NOC_CMD_WR | NOC_CMD_VC_STATIC | @@ -758,18 +789,18 @@ void noc_async_write_one_packet(std::uint32_t src_local_l1_addr, std::uint64_t d 0x0 | // (mcast ? (NOC_CMD_PATH_RESERVE | NOC_CMD_BRCST_PACKET) : 0x0) NOC_CMD_RESP_MARKED; - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CTRL, noc_cmd_field); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_TARG_ADDR_LO, src_local_l1_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_LO, (uint32_t)dst_noc_addr); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CTRL, noc_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_TARG_ADDR_LO, src_local_l1_addr); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_LO, (uint32_t)dst_noc_addr); #ifdef ARCH_BLACKHOLE // Handles writing to PCIe - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_MID, (uint32_t)(dst_noc_addr >> 32) & 0x1000000F); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_MID, (uint32_t)(dst_noc_addr >> 32) & 0x1000000F); #endif - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_COORDINATE, (uint32_t)(dst_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_AT_LEN_BE, size); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); - noc_nonposted_writes_num_issued[noc_index] += 1; - noc_nonposted_writes_acked[noc_index] += 1; // num_dests + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_COORDINATE, (uint32_t)(dst_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_AT_LEN_BE, size); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_nonposted_writes_num_issued[noc] += 1; + noc_nonposted_writes_acked[noc] += 1; // num_dests } // TODO: write docs @@ -781,10 +812,11 @@ void noc_async_write_multicast_one_packet( std::uint32_t size, std::uint32_t num_dests, bool linked = false, - bool multicast_path_reserve = true) { + bool multicast_path_reserve = true, + uint8_t noc = noc_index) { WAYPOINT("NMPW"); - DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc_index, dst_noc_addr_multicast, src_local_l1_addr, size); - while (!noc_cmd_buf_ready(noc_index, NCRISC_WR_CMD_BUF)); + DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc, dst_noc_addr_multicast, src_local_l1_addr, size); + while (!noc_cmd_buf_ready(noc, write_cmd_buf)); WAYPOINT("NWPD"); uint32_t noc_cmd_field = @@ -795,64 +827,64 @@ void noc_async_write_multicast_one_packet( ((multicast_path_reserve ? NOC_CMD_PATH_RESERVE : 0) | NOC_CMD_BRCST_PACKET) | NOC_CMD_RESP_MARKED; - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CTRL, noc_cmd_field); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_TARG_ADDR_LO, src_local_l1_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_LO, (uint32_t)dst_noc_addr_multicast); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CTRL, noc_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_TARG_ADDR_LO, src_local_l1_addr); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_LO, (uint32_t)dst_noc_addr_multicast); #ifdef ARCH_BLACKHOLE // Handles writing to PCIe - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_MID, (uint32_t)(dst_noc_addr_multicast >> 32) & 0x1000000F); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_MID, (uint32_t)(dst_noc_addr_multicast >> 32) & 0x1000000F); #endif - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_COORDINATE, (uint32_t)(dst_noc_addr_multicast >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_AT_LEN_BE, size); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); - noc_nonposted_writes_num_issued[noc_index] += 1; - noc_nonposted_writes_acked[noc_index] += num_dests; + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_COORDINATE, (uint32_t)(dst_noc_addr_multicast >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_AT_LEN_BE, size); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_nonposted_writes_num_issued[noc] += 1; + noc_nonposted_writes_acked[noc] += num_dests; } // TODO: write docs // this sets the state for issuing a single packet with size <= NOC_MAX_BURST_SIZE (ie maximum packet size) template FORCE_INLINE -void noc_async_write_one_packet_set_state(std::uint64_t dst_noc_addr, std::uint32_t size) { +void noc_async_write_one_packet_set_state(std::uint64_t dst_noc_addr, std::uint32_t size, uint8_t noc = noc_index, uint8_t vc = NOC_UNICAST_WRITE_VC) { WAYPOINT("NWPW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_WR_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, write_cmd_buf)); WAYPOINT("NWPD"); uint32_t noc_cmd_field = NOC_CMD_CPY | NOC_CMD_WR | NOC_CMD_VC_STATIC | - NOC_CMD_STATIC_VC(NOC_UNICAST_WRITE_VC) | 0x0 | // (linked ? NOC_CMD_VC_LINKED : 0x0) + NOC_CMD_STATIC_VC(vc) | 0x0 | // (linked ? NOC_CMD_VC_LINKED : 0x0) 0x0 | // (mcast ? (NOC_CMD_PATH_RESERVE | NOC_CMD_BRCST_PACKET) : 0x0) (non_posted ? NOC_CMD_RESP_MARKED : 0x0); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CTRL, noc_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CTRL, noc_cmd_field); #ifdef ARCH_BLACKHOLE // Handles writing to PCIe - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_MID, (uint32_t)(dst_noc_addr >> 32) & 0x1000000F); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_MID, (uint32_t)(dst_noc_addr >> 32) & 0x1000000F); #endif - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_COORDINATE, (uint32_t)(dst_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_AT_LEN_BE, size); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_COORDINATE, (uint32_t)(dst_noc_addr >> NOC_ADDR_COORD_SHIFT) & NOC_COORDINATE_MASK); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_AT_LEN_BE, size); } // TODO: write docs // this issues only a single packet with cmd buf state with size <= NOC_MAX_BURST_SIZE (ie maximum packet size) template FORCE_INLINE -void noc_async_write_one_packet_with_state(std::uint32_t src_local_l1_addr, std::uint32_t dst_noc_addr) { +void noc_async_write_one_packet_with_state(std::uint32_t src_local_l1_addr, std::uint32_t dst_noc_addr, uint8_t noc = noc_index) { WAYPOINT("NWPW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_WR_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, write_cmd_buf)); WAYPOINT("NWPD"); // In order to sanitize, need to grab full noc addr + xfer size from state. - DEBUG_SANITIZE_NOC_WRITE_TRANSACTION_WITH_ADDR_AND_SIZE_STATE(noc_index, dst_noc_addr, src_local_l1_addr); + DEBUG_SANITIZE_NOC_WRITE_TRANSACTION_WITH_ADDR_AND_SIZE_STATE(noc, dst_noc_addr, src_local_l1_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_TARG_ADDR_LO, src_local_l1_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_LO, dst_noc_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_TARG_ADDR_LO, src_local_l1_addr); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_LO, dst_noc_addr); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); if constexpr (non_posted) { - noc_nonposted_writes_num_issued[noc_index] += 1; - noc_nonposted_writes_acked[noc_index] += 1; // num_dests + noc_nonposted_writes_num_issued[noc] += 1; + noc_nonposted_writes_acked[noc] += 1; // num_dests } } @@ -868,19 +900,19 @@ struct InterleavedAddrGen { } FORCE_INLINE - std::uint64_t get_noc_addr(const uint32_t id, const uint32_t offset = 0) const { + std::uint64_t get_noc_addr(const uint32_t id, const uint32_t offset = 0, uint8_t noc = noc_index) const { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t addr = this->get_addr(id, bank_offset_index, bank_index, offset); - uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); uint64_t noc_addr = get_noc_addr_helper(noc_xy, addr); return noc_addr; } FORCE_INLINE - void noc_async_read_page(const uint32_t id, const uint32_t dest_addr, const uint32_t offset = 0) const { - noc_async_read(this->get_noc_addr(id, offset), dest_addr, page_size); + void noc_async_read_page(const uint32_t id, const uint32_t dest_addr, const uint32_t offset = 0, uint8_t noc = noc_index) const { + noc_async_read(this->get_noc_addr(id, offset), dest_addr, page_size, noc); } }; @@ -899,11 +931,11 @@ struct InterleavedPow2AddrGen { } FORCE_INLINE - std::uint64_t get_noc_addr(const uint32_t id, const uint32_t offset = 0) const { + std::uint64_t get_noc_addr(const uint32_t id, const uint32_t offset = 0, uint8_t noc = noc_index) const { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t addr = this->get_addr(id, bank_offset_index, bank_index, offset); - uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); uint64_t noc_addr = get_noc_addr_helper(noc_xy, addr); return noc_addr; @@ -923,46 +955,46 @@ struct InterleavedAddrGenFast { } FORCE_INLINE - std::uint64_t get_noc_addr(const uint32_t id, const uint32_t offset = 0) const { + std::uint64_t get_noc_addr(const uint32_t id, const uint32_t offset = 0, uint8_t noc = noc_index) const { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t addr = this->get_addr(id, bank_offset_index, bank_index, offset); - uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); uint64_t noc_addr = get_noc_addr_helper(noc_xy, addr); return noc_addr; } FORCE_INLINE - void noc_async_read_tile(const uint32_t id, uint32_t dest_addr, const uint32_t offset = 0) const { + void noc_async_read_tile(const uint32_t id, uint32_t dest_addr, const uint32_t offset = 0, uint8_t noc = noc_index) const { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t src_addr = this->get_addr(id, bank_offset_index, bank_index, offset); - uint32_t src_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t src_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); WAYPOINT("NRTW"); - DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc_index, get_noc_addr_helper(src_noc_xy, src_addr), dest_addr, this->page_size); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc, get_noc_addr_helper(src_noc_xy, src_addr), dest_addr, this->page_size); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("NRTD"); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_RET_ADDR_LO, dest_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_LO, src_addr); // (uint32_t)src_addr - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_COORDINATE, src_noc_xy); // src_addr >> 32 - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_AT_LEN_BE, this->page_size); // len_bytes - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); - noc_reads_num_issued[noc_index] += 1; + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_RET_ADDR_LO, dest_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_LO, src_addr); // (uint32_t)src_addr + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_COORDINATE, src_noc_xy); // src_addr >> 32 + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_AT_LEN_BE, this->page_size); // len_bytes + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_reads_num_issued[noc] += 1; } FORCE_INLINE - void noc_async_write_tile(const uint32_t id, uint32_t src_addr) const { + void noc_async_write_tile(const uint32_t id, uint32_t src_addr, uint8_t noc = noc_index) const { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t dest_addr = this->get_addr(id, bank_offset_index, bank_index); - uint32_t dest_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t dest_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); WAYPOINT("NWTW"); - DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc_index, get_noc_addr_helper(dest_noc_xy, dest_addr), src_addr, this->page_size); - while (!noc_cmd_buf_ready(noc_index, NCRISC_WR_CMD_BUF)); + DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc, get_noc_addr_helper(dest_noc_xy, dest_addr), src_addr, this->page_size); + while (!noc_cmd_buf_ready(noc, write_cmd_buf)); WAYPOINT("NWTD"); constexpr uint32_t noc_cmd_field = NOC_CMD_CPY | NOC_CMD_WR | NOC_CMD_VC_STATIC | @@ -970,14 +1002,14 @@ struct InterleavedAddrGenFast { 0x0 | // (mcast ? (NOC_CMD_PATH_RESERVE | NOC_CMD_BRCST_PACKET) : 0x0) NOC_CMD_RESP_MARKED; - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CTRL, noc_cmd_field); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_TARG_ADDR_LO, src_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_LO, dest_addr); // (uint32_t)dest_addr - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_COORDINATE, dest_noc_xy); // dest_addr >> 32 - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_AT_LEN_BE, this->page_size); // len_bytes - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); - noc_nonposted_writes_num_issued[noc_index] += 1; - noc_nonposted_writes_acked[noc_index] += 1; // num_dests + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CTRL, noc_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_TARG_ADDR_LO, src_addr); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_LO, dest_addr); // (uint32_t)dest_addr + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_COORDINATE, dest_noc_xy); // dest_addr >> 32 + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_AT_LEN_BE, this->page_size); // len_bytes + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_nonposted_writes_num_issued[noc] += 1; + noc_nonposted_writes_acked[noc] += 1; // num_dests } }; @@ -997,67 +1029,67 @@ struct InterleavedPow2AddrGenFast { } FORCE_INLINE - std::uint64_t get_noc_addr(const uint32_t id, const uint32_t offset = 0) const { + std::uint64_t get_noc_addr(const uint32_t id, const uint32_t offset = 0, uint8_t noc = noc_index) const { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t addr = this->get_addr(id, bank_offset_index, bank_index, offset); - uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); uint64_t noc_addr = get_noc_addr_helper(noc_xy, addr); return noc_addr; } FORCE_INLINE - void noc_async_read_page(const uint32_t id, uint32_t dest_addr, const uint32_t offset = 0) const { + void noc_async_read_page(const uint32_t id, uint32_t dest_addr, const uint32_t offset = 0, uint8_t noc = noc_index) const { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t src_addr = this->get_addr(id, bank_offset_index, bank_index, offset); - uint32_t src_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t src_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); WAYPOINT("NRPW"); - DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc_index, get_noc_addr_helper(src_noc_xy, src_addr), dest_addr, 1 << this->aligned_log_base_2_of_page_size); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc, get_noc_addr_helper(src_noc_xy, src_addr), dest_addr, 1 << this->aligned_log_base_2_of_page_size); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("NRPD"); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_RET_ADDR_LO, dest_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_LO, src_addr); // (uint32_t)src_addr - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_COORDINATE, src_noc_xy); // src_addr >> 32 - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_AT_LEN_BE, 1 << this->aligned_log_base_2_of_page_size); // len_bytes - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); - noc_reads_num_issued[noc_index] += 1; + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_RET_ADDR_LO, dest_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_LO, src_addr); // (uint32_t)src_addr + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_COORDINATE, src_noc_xy); // src_addr >> 32 + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_AT_LEN_BE, 1 << this->aligned_log_base_2_of_page_size); // len_bytes + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_reads_num_issued[noc] += 1; } FORCE_INLINE - void noc_async_read_partial_page(const uint32_t id, uint32_t dest_addr, const uint32_t size, const uint32_t offset) const { + void noc_async_read_partial_page(const uint32_t id, uint32_t dest_addr, const uint32_t size, const uint32_t offset, uint8_t noc = noc_index) const { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t src_addr = this->get_addr(id, bank_offset_index, bank_index, offset); - uint32_t src_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t src_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); WAYPOINT("RPW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("RPD"); - DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc_index, get_noc_addr_helper(src_noc_xy, src_addr), dest_addr, size); - - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_RET_ADDR_LO, dest_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_LO, src_addr); // (uint32_t)src_addr - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_COORDINATE, src_noc_xy); // src_addr >> 32 - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_AT_LEN_BE, size); // len_bytes - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); - noc_reads_num_issued[noc_index] += 1; + DEBUG_SANITIZE_NOC_READ_TRANSACTION(noc, get_noc_addr_helper(src_noc_xy, src_addr), dest_addr, size); + + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_RET_ADDR_LO, dest_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_LO, src_addr); // (uint32_t)src_addr + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_COORDINATE, src_noc_xy); // src_addr >> 32 + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_AT_LEN_BE, size); // len_bytes + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_reads_num_issued[noc] += 1; } FORCE_INLINE - void noc_async_write_page(const uint32_t id, uint32_t src_addr, const uint32_t write_size_bytes, const uint32_t offset = 0) const { + void noc_async_write_page(const uint32_t id, uint32_t src_addr, const uint32_t write_size_bytes, const uint32_t offset = 0, uint8_t noc = noc_index) const { uint32_t bank_offset_index = interleaved_addr_gen::get_bank_offset_index(id); uint32_t bank_index = interleaved_addr_gen::get_bank_index(id, bank_offset_index); uint32_t dest_addr = this->get_addr(id, bank_offset_index, bank_index, offset); - uint32_t dest_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index); + uint32_t dest_noc_xy = interleaved_addr_gen::get_noc_xy(bank_index, noc); WAYPOINT("NWPW"); - DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc_index, get_noc_addr_helper(dest_noc_xy, dest_addr), src_addr, write_size_bytes); - while (!noc_cmd_buf_ready(noc_index, NCRISC_WR_CMD_BUF)); + DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc, get_noc_addr_helper(dest_noc_xy, dest_addr), src_addr, write_size_bytes); + while (!noc_cmd_buf_ready(noc, write_cmd_buf)); WAYPOINT("NWPD"); constexpr uint32_t noc_cmd_field = NOC_CMD_CPY | NOC_CMD_WR | NOC_CMD_VC_STATIC | @@ -1065,19 +1097,19 @@ struct InterleavedPow2AddrGenFast { 0x0 | // (mcast ? (NOC_CMD_PATH_RESERVE | NOC_CMD_BRCST_PACKET) : 0x0) NOC_CMD_RESP_MARKED; - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CTRL, noc_cmd_field); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_TARG_ADDR_LO, src_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_LO, dest_addr); // (uint32_t)dest_addr - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_RET_ADDR_COORDINATE, dest_noc_xy); // dest_addr >> 32 - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_AT_LEN_BE, write_size_bytes); // len_bytes - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_WR_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); - noc_nonposted_writes_num_issued[noc_index] += 1; - noc_nonposted_writes_acked[noc_index] += 1; // num_dests + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CTRL, noc_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_TARG_ADDR_LO, src_addr); + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_LO, dest_addr); // (uint32_t)dest_addr + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_RET_ADDR_COORDINATE, dest_noc_xy); // dest_addr >> 32 + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_AT_LEN_BE, write_size_bytes); // len_bytes + NOC_CMD_BUF_WRITE_REG(noc, write_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_nonposted_writes_num_issued[noc] += 1; + noc_nonposted_writes_acked[noc] += 1; // num_dests } }; template -FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedAddrGen& s, uint32_t offset = 0) { +FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedAddrGen& s, uint32_t offset = 0, uint8_t noc = noc_index) { /* Alternative API for getting the noc address when we are reading using a swizzled layout. This version assumes bank unit size can be arbitrary size. Use @@ -1089,11 +1121,11 @@ FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedAddr InterleavedAddrGen: Check struct for attribute definitions. */ - return s.get_noc_addr(id, offset); + return s.get_noc_addr(id, offset, noc); } template -FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedPow2AddrGen& s, uint32_t offset = 0) { +FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedPow2AddrGen& s, uint32_t offset = 0, uint8_t noc = noc_index) { /* Alternative API for getting the noc address when we are reading using a swizzled layout. This version assumes bank unit size is a power of 2. For arbitrary bank @@ -1105,11 +1137,11 @@ FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedPow2 InterleavedPow2AddrGen: Check struct for attribute definitions. */ - return s.get_noc_addr(id, offset); + return s.get_noc_addr(id, offset, noc); } template -FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedAddrGenFast& s, uint32_t offset = 0) { +FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedAddrGenFast& s, uint32_t offset = 0, uint8_t noc = noc_index) { /* Alternative API for getting the noc address when we are reading using a swizzled layout. This version assumes bank unit size can be arbitrary size. Use @@ -1121,27 +1153,27 @@ FORCE_INLINE std::uint64_t get_noc_addr(const uint32_t id, const InterleavedAddr InterleavedAddrGen: Check struct for attribute definitions. */ - return s.get_noc_addr(id, offset); + return s.get_noc_addr(id, offset, noc); } template FORCE_INLINE void noc_async_read_page( - const uint32_t id, const InterleavedAddrGen& s, std::uint32_t dst_local_l1_addr, uint32_t offset = 0) { + const uint32_t id, const InterleavedAddrGen& s, std::uint32_t dst_local_l1_addr, uint32_t offset = 0, uint8_t noc = noc_index) { /* Read requests - use static VC Read responses - assigned VCs dynamically */ - s.noc_async_read_page(id, dst_local_l1_addr, offset); + s.noc_async_read_page(id, dst_local_l1_addr, offset, noc); } template FORCE_INLINE void noc_async_read_tile( - const uint32_t id, const InterleavedAddrGenFast& s, std::uint32_t dst_local_l1_addr, uint32_t offset = 0) { + const uint32_t id, const InterleavedAddrGenFast& s, std::uint32_t dst_local_l1_addr, uint32_t offset = 0, uint8_t noc = noc_index) { /* Read requests - use static VC Read responses - assigned VCs dynamically */ - s.noc_async_read_tile(id, dst_local_l1_addr, offset); + s.noc_async_read_tile(id, dst_local_l1_addr, offset, noc); } /** @@ -1164,15 +1196,15 @@ FORCE_INLINE void noc_async_read_tile( */ template inline -void noc_async_write(std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr, std::uint32_t size) { +void noc_async_write(std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr, std::uint32_t size, uint8_t noc = noc_index) { if constexpr (max_page_size <= NOC_MAX_BURST_SIZE) { noc_async_write_one_packet(src_local_l1_addr, dst_noc_addr, size); } else { WAYPOINT("NAWW"); - DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc_index, dst_noc_addr, src_local_l1_addr,size); + DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc, dst_noc_addr, src_local_l1_addr,size); ncrisc_noc_fast_write_any_len( - noc_index, - NCRISC_WR_CMD_BUF, + noc, + write_cmd_buf, src_local_l1_addr, dst_noc_addr, size, @@ -1187,8 +1219,8 @@ void noc_async_write(std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr template FORCE_INLINE void noc_async_write_tile( - const uint32_t id, const InterleavedAddrGenFast& s, std::uint32_t src_local_l1_addr) { - s.noc_async_write_tile(id, src_local_l1_addr); + const uint32_t id, const InterleavedAddrGenFast& s, std::uint32_t src_local_l1_addr, uint8_t noc = noc_index) { + s.noc_async_write_tile(id, src_local_l1_addr, noc); } template @@ -1198,12 +1230,12 @@ uint32_t get_semaphore(uint32_t semaphore_id) { } inline -void noc_semaphore_set_remote(std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr) { +void noc_semaphore_set_remote(std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr, uint8_t noc = noc_index) { WAYPOINT("NSSW"); - DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc_index, dst_noc_addr, src_local_l1_addr, 4); + DEBUG_SANITIZE_NOC_WRITE_TRANSACTION(noc, dst_noc_addr, src_local_l1_addr, 4); ncrisc_noc_fast_write_any_len( - noc_index, - NCRISC_WR_REG_CMD_BUF, + noc, + write_reg_cmd_buf, src_local_l1_addr, dst_noc_addr, 4 /* size in bytes */, @@ -1255,15 +1287,16 @@ void noc_async_write_multicast( std::uint32_t size, std::uint32_t num_dests, bool linked = false, - bool multicast_path_reserve = true) { + bool multicast_path_reserve = true, + uint8_t noc = noc_index) { if constexpr (max_page_size <= NOC_MAX_BURST_SIZE) { noc_async_write_multicast_one_packet(src_local_l1_addr, dst_noc_addr_multicast, size, num_dests, linked, multicast_path_reserve); } else { WAYPOINT("NMWW"); - DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc_index, dst_noc_addr_multicast, src_local_l1_addr,size); + DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc, dst_noc_addr_multicast, src_local_l1_addr,size); ncrisc_noc_fast_write_any_len( - noc_index, - NCRISC_WR_CMD_BUF, + noc, + write_cmd_buf, src_local_l1_addr, dst_noc_addr_multicast, size, @@ -1302,12 +1335,12 @@ void noc_async_write_multicast( */ inline void noc_semaphore_set_multicast( - std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr_multicast, std::uint32_t num_dests, bool linked = false, bool multicast_path_reserve = true) { + std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr_multicast, std::uint32_t num_dests, bool linked = false, bool multicast_path_reserve = true, uint8_t noc = noc_index) { WAYPOINT("NSMW"); - DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc_index, dst_noc_addr_multicast, src_local_l1_addr, 4); + DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc, dst_noc_addr_multicast, src_local_l1_addr, 4); ncrisc_noc_fast_write_any_len( - noc_index, - NCRISC_WR_REG_CMD_BUF, + noc, + write_reg_cmd_buf, src_local_l1_addr, dst_noc_addr_multicast, 4 /*size in bytes*/, @@ -1344,12 +1377,12 @@ void noc_semaphore_set_multicast( */ inline void noc_semaphore_set_multicast_loopback_src( - std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr_multicast, std::uint32_t num_dests, bool linked = false, bool multicast_path_reserve = true) { + std::uint32_t src_local_l1_addr, std::uint64_t dst_noc_addr_multicast, std::uint32_t num_dests, bool linked = false, bool multicast_path_reserve = true, uint8_t noc = noc_index) { WAYPOINT("NSMW"); - DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc_index, dst_noc_addr_multicast, src_local_l1_addr, 4); + DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc, dst_noc_addr_multicast, src_local_l1_addr, 4); ncrisc_noc_fast_write_any_len_loopback_src( - noc_index, - NCRISC_WR_REG_CMD_BUF, + noc, + write_reg_cmd_buf, src_local_l1_addr, dst_noc_addr_multicast, 4 /*size in bytes*/, @@ -1368,12 +1401,13 @@ void noc_async_write_multicast_loopback_src( std::uint32_t size, std::uint32_t num_dests, bool linked = false, - bool multicast_path_reserve = true) { + bool multicast_path_reserve = true, + uint8_t noc = noc_index) { WAYPOINT("NMLW"); - DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc_index, dst_noc_addr_multicast, src_local_l1_addr, size); + DEBUG_SANITIZE_NOC_MULTI_WRITE_TRANSACTION(noc, dst_noc_addr_multicast, src_local_l1_addr, size); ncrisc_noc_fast_write_any_len_loopback_src( - noc_index, - NCRISC_WR_CMD_BUF, + noc, + write_cmd_buf, src_local_l1_addr, dst_noc_addr_multicast, size, @@ -1393,12 +1427,12 @@ void noc_async_write_multicast_loopback_src( * * Return value: None */ -void noc_async_read_barrier() { +void noc_async_read_barrier(uint8_t noc = noc_index) { WAYPOINT("NRBW"); // BH cache is write-through so reader must invalidate if reading any address that was previously read do { invalidate_l1_cache(); - } while (!ncrisc_noc_reads_flushed(noc_index)); + } while (!ncrisc_noc_reads_flushed(noc)); WAYPOINT("NRBD"); } @@ -1411,9 +1445,9 @@ void noc_async_read_barrier() { * Return value: None */ FORCE_INLINE -void noc_async_write_barrier() { +void noc_async_write_barrier(uint8_t noc = noc_index) { WAYPOINT("NWBW"); - while (!ncrisc_noc_nonposted_writes_flushed(noc_index)) + while (!ncrisc_noc_nonposted_writes_flushed(noc)) ; WAYPOINT("NWBD"); } @@ -1424,9 +1458,9 @@ void noc_async_write_barrier() { * for them to complete */ FORCE_INLINE -void noc_async_writes_flushed() { +void noc_async_writes_flushed(uint8_t noc = noc_index) { WAYPOINT("NWFW"); - while (!ncrisc_noc_nonposted_writes_sent(noc_index)) + while (!ncrisc_noc_nonposted_writes_sent(noc)) ; WAYPOINT("NWFD"); } @@ -1532,13 +1566,13 @@ void noc_semaphore_set(volatile tt_l1_ptr uint32_t* sem_addr, uint32_t val) { * | be | Byte-enable | uint8_t | 0x1-0xF | False | */ FORCE_INLINE -void noc_inline_dw_write(uint64_t addr, uint32_t val, uint8_t be = 0xF) { +void noc_inline_dw_write(uint64_t addr, uint32_t val, uint8_t be = 0xF, uint8_t noc = noc_index) { WAYPOINT("NWIW"); - DEBUG_SANITIZE_NOC_ADDR(noc_index, addr, 4); + DEBUG_SANITIZE_NOC_ADDR(noc, addr, 4); noc_fast_write_dw_inline( - noc_index, - NCRISC_WR_REG_CMD_BUF, + noc, + write_reg_cmd_buf, val, addr, be, // byte-enable @@ -1573,7 +1607,7 @@ void noc_semaphore_inc(uint64_t addr, uint32_t incr, uint8_t noc_id = noc_index) WAYPOINT("NSIW"); DEBUG_SANITIZE_NOC_ADDR(noc_id, addr, 4); DEBUG_INSERT_DELAY(TransactionAtomic); - noc_fast_atomic_increment(noc_id, NCRISC_AT_CMD_BUF, addr, NOC_UNICAST_WRITE_VC, incr, 31 /*wrap*/, false /*linked*/, false /*posted*/); + noc_fast_atomic_increment(noc_id, write_at_cmd_buf, addr, NOC_UNICAST_WRITE_VC, incr, 31 /*wrap*/, false /*linked*/, false /*posted*/, MEM_NOC_ATOMIC_RET_VAL_ADDR); WAYPOINT("NSID"); } @@ -1589,69 +1623,69 @@ uint32_t min(uint32_t a, uint32_t b) { return (a < b) ? a: b; } template FORCE_INLINE -uint32_t noc_async_read_tile_dram_sharded_set_state(uint32_t bank_base_address, uint32_t bank_id = 0, const uint32_t vc = 0) { +uint32_t noc_async_read_tile_dram_sharded_set_state(uint32_t bank_base_address, uint32_t bank_id = 0, const uint32_t vc = 0, uint8_t noc = noc_index) { uint32_t src_addr_; uint32_t src_noc_xy; src_addr_ = bank_base_address + bank_to_dram_offset[bank_id]; - src_noc_xy = dram_bank_to_noc_xy[noc_index][bank_id]; + src_noc_xy = dram_bank_to_noc_xy[noc][bank_id]; WAYPOINT("NRTW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("NRTD"); if constexpr(use_vc) { uint32_t noc_rd_cmd_field = NOC_CMD_CPY | NOC_CMD_RD | NOC_CMD_RESP_MARKED | NOC_CMD_VC_STATIC | NOC_CMD_STATIC_VC(vc); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CTRL, noc_rd_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_CTRL, noc_rd_cmd_field); } - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_COORDINATE, src_noc_xy); // src_addr >> 32 - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_AT_LEN_BE, page_size); // len_bytes + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_COORDINATE, src_noc_xy); // src_addr >> 32 + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_AT_LEN_BE, page_size); // len_bytes return src_addr_; } FORCE_INLINE -void noc_async_read_tile_dram_sharded_with_state(uint32_t src_base_addr, uint32_t src_addr, uint32_t dest_addr, uint32_t trid = 0) { +void noc_async_read_tile_dram_sharded_with_state(uint32_t src_base_addr, uint32_t src_addr, uint32_t dest_addr, uint32_t trid = 0, uint8_t noc = noc_index) { uint32_t src_addr_; src_addr_ = src_base_addr + src_addr; WAYPOINT("NRTW"); - while (!noc_cmd_buf_ready(noc_index, NCRISC_RD_CMD_BUF)); + while (!noc_cmd_buf_ready(noc, read_cmd_buf)); WAYPOINT("NRTD"); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_RET_ADDR_LO, dest_addr); - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_TARG_ADDR_LO, src_addr_); // (uint32_t)src_addr - NOC_CMD_BUF_WRITE_REG(noc_index, NCRISC_RD_CMD_BUF, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); - noc_reads_num_issued[noc_index] += 1; + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_RET_ADDR_LO, dest_addr); + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_TARG_ADDR_LO, src_addr_); // (uint32_t)src_addr + NOC_CMD_BUF_WRITE_REG(noc, read_cmd_buf, NOC_CMD_CTRL, NOC_CTRL_SEND_REQ); + noc_reads_num_issued[noc] += 1; } FORCE_INLINE -void noc_async_read_tile_dram_sharded_with_state_with_trid(uint32_t src_base_addr, uint32_t src_addr, uint32_t dest_addr, uint32_t trid = 0) { +void noc_async_read_tile_dram_sharded_with_state_with_trid(uint32_t src_base_addr, uint32_t src_addr, uint32_t dest_addr, uint32_t trid = 0, uint8_t noc = noc_index) { WAYPOINT("NRDW"); #ifndef ARCH_GRAYSKULL - ncrisc_noc_fast_read_with_transaction_id(noc_index, NCRISC_RD_CMD_BUF, src_base_addr, src_addr, dest_addr, trid); + ncrisc_noc_fast_read_with_transaction_id(noc, read_cmd_buf, src_base_addr, src_addr, dest_addr, trid); #endif WAYPOINT("NRDD"); } FORCE_INLINE -void noc_async_read_tile_dram_sharded_set_trid(uint32_t trid = 0) { +void noc_async_read_tile_dram_sharded_set_trid(uint32_t trid = 0, uint8_t noc = noc_index) { WAYPOINT("NSTW"); #ifndef ARCH_GRAYSKULL - ncrisc_noc_set_transaction_id(noc_index, NCRISC_RD_CMD_BUF, trid); + ncrisc_noc_set_transaction_id(noc, read_cmd_buf, trid); #endif WAYPOINT("NSTD"); } FORCE_INLINE -void noc_async_read_barrier_with_trid(uint32_t trid) { +void noc_async_read_barrier_with_trid(uint32_t trid, uint8_t noc = noc_index) { WAYPOINT("NBTW"); #ifndef ARCH_GRAYSKULL do { invalidate_l1_cache(); - } while (!ncrisc_noc_read_with_transaction_id_flushed(noc_index, trid)); + } while (!ncrisc_noc_read_with_transaction_id_flushed(noc, trid)); #endif WAYPOINT("NBTD"); } diff --git a/tt_metal/hw/inc/debug/sanitize_noc.h b/tt_metal/hw/inc/debug/sanitize_noc.h index de70b3d68ea..e0a46def931 100644 --- a/tt_metal/hw/inc/debug/sanitize_noc.h +++ b/tt_metal/hw/inc/debug/sanitize_noc.h @@ -24,7 +24,6 @@ #include "watcher_common.h" -extern uint8_t noc_index; #include "dev_msgs.h" #include "noc_overlay_parameters.h" diff --git a/tt_metal/hw/inc/dev_msgs.h b/tt_metal/hw/inc/dev_msgs.h index 20439c4d79f..18cad7e3778 100644 --- a/tt_metal/hw/inc/dev_msgs.h +++ b/tt_metal/hw/inc/dev_msgs.h @@ -68,6 +68,17 @@ enum dispatch_core_processor_masks { DISPATCH_CLASS_MASK_ETH_DM0 = 1 << DISPATCH_CLASS_ETH_DM0, }; +enum noc_index { + NOC_0 = 0, + NOC_1 = 1, +}; + +enum noc_mode : uint8_t { + DM_DEDICATED_NOC = 0, + DM_DYNAMIC_NOC = 1, + DM_INVALID_NOC = 2, +}; + // Address offsets to kernel runtime configuration components // struct to densely packs values used by each processor struct dyn_mem_map_t { @@ -89,9 +100,9 @@ struct kernel_config_msg_t { volatile uint8_t mode; // dispatch mode host/dev volatile uint8_t brisc_noc_id; + volatile uint8_t brisc_noc_mode; volatile uint8_t max_cb_index; volatile uint8_t exit_erisc_kernel; - volatile uint8_t pad; volatile uint8_t enables; } __attribute__((packed)); diff --git a/tt_metal/hw/inc/grayskull/noc_nonblocking_api.h b/tt_metal/hw/inc/grayskull/noc_nonblocking_api.h index 0298243b385..8fe58acf1a6 100644 --- a/tt_metal/hw/inc/grayskull/noc_nonblocking_api.h +++ b/tt_metal/hw/inc/grayskull/noc_nonblocking_api.h @@ -7,13 +7,29 @@ #include #include "noc_parameters.h" +#include "dev_msgs.h" //// -const uint32_t NCRISC_WR_CMD_BUF = 0; // for large writes -const uint32_t NCRISC_RD_CMD_BUF = 1; // for all reads -const uint32_t NCRISC_WR_REG_CMD_BUF = 2; // for small writes (e.g., registers, semaphores) -const uint32_t NCRISC_AT_CMD_BUF = 3; // for atomics +constexpr uint32_t DYNAMIC_NOC_NCRISC_WR_CMD_BUF = 2; // all writes share cmd buf +constexpr uint32_t DYNAMIC_NOC_NCRISC_WR_REG_CMD_BUF = 2; +constexpr uint32_t DYNAMIC_NOC_NCRISC_AT_CMD_BUF = 2; +constexpr uint32_t DYNAMIC_NOC_NCRISC_RD_CMD_BUF = 3; + +constexpr uint32_t DYNAMIC_NOC_BRISC_WR_CMD_BUF = 0; // all writes share cmd buf +constexpr uint32_t DYNAMIC_NOC_BRISC_WR_REG_CMD_BUF = 0; +constexpr uint32_t DYNAMIC_NOC_BRISC_AT_CMD_BUF = 0; +constexpr uint32_t DYNAMIC_NOC_BRISC_RD_CMD_BUF = 1; + +constexpr uint32_t NCRISC_WR_CMD_BUF = 0; // for large writes +constexpr uint32_t NCRISC_RD_CMD_BUF = 1; // for all reads +constexpr uint32_t NCRISC_WR_REG_CMD_BUF = 2; // for small writes (e.g., registers, semaphores) +constexpr uint32_t NCRISC_AT_CMD_BUF = 3; // for atomics + +constexpr uint32_t BRISC_WR_CMD_BUF = 0; // for large writes +constexpr uint32_t BRISC_RD_CMD_BUF = 1; // for all reads +constexpr uint32_t BRISC_WR_REG_CMD_BUF = 2; // for small writes (e.g., registers, semaphores) +constexpr uint32_t BRISC_AT_CMD_BUF = 3; // for atomics // 32 bits of address followed by coordinate. First address goes into lo register, coordinates are in the mid register constexpr uint32_t NOC_ADDR_COORD_SHIFT = 32; // address is lower 36 bits and upper bits are the coordinates, 32 bits in lo reg and rest goes to mid @@ -159,6 +175,31 @@ inline __attribute__((always_inline)) void noc_init(uint32_t atomic_ret_val) { } } +inline __attribute__((always_inline)) void dynamic_noc_init() { +#pragma GCC unroll 0 + for (int noc = 0; noc < NUM_NOCS; noc++) { + uint32_t noc_id_reg = NOC_CMD_BUF_READ_REG(noc, 0, NOC_NODE_ID); + uint32_t my_x = noc_id_reg & NOC_NODE_ID_MASK; + uint32_t my_y = (noc_id_reg >> NOC_ADDR_NODE_ID_BITS) & NOC_NODE_ID_MASK; + uint64_t xy_local_addr = NOC_XY_ADDR(my_x, my_y, 0); + + uint32_t noc_rd_cmd_field = NOC_CMD_CPY | NOC_CMD_RD | NOC_CMD_RESP_MARKED | NOC_CMD_VC_STATIC | NOC_CMD_STATIC_VC(1); + + // program brisc cmd_buf 0 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_BRISC_RD_CMD_BUF, NOC_CTRL, noc_rd_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_BRISC_RD_CMD_BUF, NOC_RET_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + + // program brisc cmd_buf 1 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_BRISC_WR_CMD_BUF, NOC_TARG_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + + // program ncrisc cmd_buf 2 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_NCRISC_RD_CMD_BUF, NOC_CTRL, noc_rd_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_NCRISC_RD_CMD_BUF, NOC_RET_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + + // program ncrisc cmd_buf 3 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_NCRISC_WR_CMD_BUF, NOC_TARG_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + } +} // set noc local memory state for a single kernel from the global state inline __attribute__((always_inline)) void noc_local_state_init(int noc) { @@ -254,8 +295,17 @@ inline __attribute__((always_inline)) void noc_fast_write_dw_inline(uint32_t noc } } -inline __attribute__((always_inline)) void noc_fast_atomic_increment(uint32_t noc, uint32_t cmd_buf, uint64_t addr, uint32_t vc, uint32_t incr, uint32_t wrap, bool linked, bool posted = false) { +template +inline __attribute__((always_inline)) void noc_fast_atomic_increment(uint32_t noc, uint32_t cmd_buf, uint64_t addr, uint32_t vc, uint32_t incr, uint32_t wrap, bool linked, bool posted = false, uint32_t atomic_ret_val = 0) { while (!noc_cmd_buf_ready(noc, cmd_buf)); + if constexpr (noc_mode == DM_DYNAMIC_NOC) { + uint32_t noc_id_reg = NOC_CMD_BUF_READ_REG(noc, 0, NOC_NODE_ID); + uint32_t my_x = noc_id_reg & NOC_NODE_ID_MASK; + uint32_t my_y = (noc_id_reg >> NOC_ADDR_NODE_ID_BITS) & NOC_NODE_ID_MASK; + uint64_t atomic_ret_addr = NOC_XY_ADDR(my_x, my_y, atomic_ret_val); + NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_RET_ADDR_LO, (uint32_t)(atomic_ret_addr & 0xFFFFFFFF)); + NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_RET_ADDR_COORDINATE, (uint32_t)(atomic_ret_addr >> NOC_ADDR_COORD_SHIFT)); + } NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_TARG_ADDR_LO, (uint32_t)(addr & 0xFFFFFFFF)); NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_TARG_ADDR_COORDINATE, (uint32_t)(addr >> NOC_ADDR_COORD_SHIFT)); NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_CTRL, diff --git a/tt_metal/hw/inc/mod_div_lib.h b/tt_metal/hw/inc/mod_div_lib.h index da4b84f967e..2a0d3a30b66 100644 --- a/tt_metal/hw/inc/mod_div_lib.h +++ b/tt_metal/hw/inc/mod_div_lib.h @@ -40,6 +40,22 @@ inline __attribute__((always_inline)) uint32_t fast_udiv_56(uint32_t n) return (((uint64_t) n * 0x24924925) >> 32) >> 3; } +inline __attribute__((always_inline)) uint32_t fast_udiv_70(uint32_t n) +{ + // Uses embedding style magic number + // * fixed point 1/70 then shifting. + // https://web.archive.org/web/20190703172151/http://www.hackersdelight.org/magic.htm + return (((uint64_t) n * 0xEA0EA0EB) >> 32) >> 6; +} + +inline __attribute__((always_inline)) uint32_t fast_udiv_80(uint32_t n) +{ + // Uses embedding style magic number + // * fixed point 1/80 then shifting. + // https://web.archive.org/web/20190703172151/http://www.hackersdelight.org/magic.htm + return (((uint64_t) n * 0xCCCCCCCD) >> 32) >> 6; +} + inline __attribute__((always_inline)) uint32_t fast_udiv_94(uint32_t n) { // Uses embedding style magic number @@ -78,7 +94,11 @@ inline __attribute__((always_inline)) uint32_t udivsi3_const_divisor(uint32_t n) } else if constexpr (d == 56) { // fast divide for 56 divisor. Handles Banked L1 address generation for N300 return fast_udiv_56(n); - } else if constexpr (d == 94) { + } else if constexpr (d == 70) { + return fast_udiv_70(n); + } else if constexpr (d == 80) { + return fast_udiv_80(n); + } else if constexpr (d == 94) { // fast divide for 94 divisor. Handles Banked L1 address generation for E75 return fast_udiv_94(n); } else if constexpr (d == 124) { diff --git a/tt_metal/hw/inc/risc_common.h b/tt_metal/hw/inc/risc_common.h index ac7cf5edc81..c92a42651b1 100644 --- a/tt_metal/hw/inc/risc_common.h +++ b/tt_metal/hw/inc/risc_common.h @@ -20,6 +20,8 @@ #define NOC_X(x) NOC_0_X(noc_index, noc_size_x, (x)) #define NOC_Y(y) NOC_0_Y(noc_index, noc_size_y, (y)) +#define DYNAMIC_NOC_X(noc, x) NOC_0_X(noc, noc_size_x, (x)) +#define DYNAMIC_NOC_Y(noc, y) NOC_0_Y(noc, noc_size_y, (y)) #define TILE_WORD_2_BIT ((256 + 64 + 32) >> 4) #define TILE_WORD_4_BIT ((512 + 64 + 32) >> 4) diff --git a/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h b/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h index 40d4a6ec39f..48b6411911d 100644 --- a/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h +++ b/tt_metal/hw/inc/wormhole/noc_nonblocking_api.h @@ -7,13 +7,32 @@ #include #include "noc_parameters.h" +#include "dev_msgs.h" //// -const uint32_t NCRISC_WR_CMD_BUF = 0; // for large writes -const uint32_t NCRISC_RD_CMD_BUF = 1; // for all reads -const uint32_t NCRISC_WR_REG_CMD_BUF = 2; // for small writes (e.g., registers, semaphores) -const uint32_t NCRISC_AT_CMD_BUF = 3; // for atomics +// Use VC 1 for unicast writes, and VC 4 for mcast writes + +// used for ops with USE_MULTI_MOC defined +constexpr uint32_t DYNAMIC_NOC_NCRISC_WR_CMD_BUF = 2; // all writes share cmd buf +constexpr uint32_t DYNAMIC_NOC_NCRISC_WR_REG_CMD_BUF = 2; +constexpr uint32_t DYNAMIC_NOC_NCRISC_AT_CMD_BUF = 2; +constexpr uint32_t DYNAMIC_NOC_NCRISC_RD_CMD_BUF = 3; + +constexpr uint32_t DYNAMIC_NOC_BRISC_WR_CMD_BUF = 0; // all writes share cmd buf +constexpr uint32_t DYNAMIC_NOC_BRISC_WR_REG_CMD_BUF = 0; +constexpr uint32_t DYNAMIC_NOC_BRISC_AT_CMD_BUF = 0; +constexpr uint32_t DYNAMIC_NOC_BRISC_RD_CMD_BUF = 1; + +constexpr uint32_t NCRISC_WR_CMD_BUF = 0; // for large writes +constexpr uint32_t NCRISC_RD_CMD_BUF = 1; // for all reads +constexpr uint32_t NCRISC_WR_REG_CMD_BUF = 2; // for small writes (e.g., registers, semaphores) +constexpr uint32_t NCRISC_AT_CMD_BUF = 3; // for atomics + +constexpr uint32_t BRISC_WR_CMD_BUF = 0; // for large writes +constexpr uint32_t BRISC_RD_CMD_BUF = 1; // for all reads +constexpr uint32_t BRISC_WR_REG_CMD_BUF = 2; // for small writes (e.g., registers, semaphores) +constexpr uint32_t BRISC_AT_CMD_BUF = 3; // for atomics // 36 bits of address followed by coordinate. First 32 bits of address go into lo register, remaining address bits and coordinates are in the mid register constexpr uint32_t NOC_ADDR_COORD_SHIFT = 32; // address is lower 36 bits and upper bits are the coordinates, 32 bits in lo reg and rest goes to mid @@ -169,6 +188,32 @@ inline __attribute__((always_inline)) void noc_init(uint32_t atomic_ret_val) { } } +inline __attribute__((always_inline)) void dynamic_noc_init() { +#pragma GCC unroll 0 + for (int noc = 0; noc < NUM_NOCS; noc++) { + uint32_t noc_id_reg = NOC_CMD_BUF_READ_REG(noc, 0, NOC_NODE_ID); + uint32_t my_x = noc_id_reg & NOC_NODE_ID_MASK; + uint32_t my_y = (noc_id_reg >> NOC_ADDR_NODE_ID_BITS) & NOC_NODE_ID_MASK; + uint64_t xy_local_addr = NOC_XY_ADDR(my_x, my_y, 0); + + uint32_t noc_rd_cmd_field = NOC_CMD_CPY | NOC_CMD_RD | NOC_CMD_RESP_MARKED | NOC_CMD_VC_STATIC | NOC_CMD_STATIC_VC(1); + + // program brisc cmd_buf 0 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_BRISC_RD_CMD_BUF, NOC_CTRL, noc_rd_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_BRISC_RD_CMD_BUF, NOC_RET_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + + // program brisc cmd_buf 1 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_BRISC_WR_CMD_BUF, NOC_TARG_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + + // program ncrisc cmd_buf 2 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_NCRISC_RD_CMD_BUF, NOC_CTRL, noc_rd_cmd_field); + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_NCRISC_RD_CMD_BUF, NOC_RET_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + + // program ncrisc cmd_buf 3 + NOC_CMD_BUF_WRITE_REG(noc, DYNAMIC_NOC_NCRISC_WR_CMD_BUF, NOC_TARG_ADDR_COORDINATE, (uint32_t)(xy_local_addr >> NOC_ADDR_COORD_SHIFT)); + } +} + // set noc local memory state for a single kernel from the global state inline __attribute__((always_inline)) void noc_local_state_init(int noc) { @@ -264,8 +309,17 @@ inline __attribute__((always_inline)) void noc_fast_write_dw_inline(uint32_t noc } } -inline __attribute__((always_inline)) void noc_fast_atomic_increment(uint32_t noc, uint32_t cmd_buf, uint64_t addr, uint32_t vc, uint32_t incr, uint32_t wrap, bool linked, bool posted = false) { +template +inline __attribute__((always_inline)) void noc_fast_atomic_increment(uint32_t noc, uint32_t cmd_buf, uint64_t addr, uint32_t vc, uint32_t incr, uint32_t wrap, bool linked, bool posted = false, uint32_t atomic_ret_val = 0) { while (!noc_cmd_buf_ready(noc, cmd_buf)); + if constexpr (noc_mode == DM_DYNAMIC_NOC) { + uint32_t noc_id_reg = NOC_CMD_BUF_READ_REG(noc, 0, NOC_NODE_ID); + uint32_t my_x = noc_id_reg & NOC_NODE_ID_MASK; + uint32_t my_y = (noc_id_reg >> NOC_ADDR_NODE_ID_BITS) & NOC_NODE_ID_MASK; + uint64_t atomic_ret_addr = NOC_XY_ADDR(my_x, my_y, atomic_ret_val); + NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_RET_ADDR_LO, (uint32_t)(atomic_ret_addr & 0xFFFFFFFF)); + NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_RET_ADDR_COORDINATE, (uint32_t)(atomic_ret_addr >> NOC_ADDR_COORD_SHIFT)); + } NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_TARG_ADDR_LO, (uint32_t)(addr & 0xFFFFFFFF)); NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_TARG_ADDR_COORDINATE, (uint32_t)(addr >> NOC_ADDR_COORD_SHIFT)); NOC_CMD_BUF_WRITE_REG(noc, cmd_buf, NOC_CTRL, diff --git a/tt_metal/impl/allocator/allocator.cpp b/tt_metal/impl/allocator/allocator.cpp index 8e8675b4ca7..6c7fd50acff 100644 --- a/tt_metal/impl/allocator/allocator.cpp +++ b/tt_metal/impl/allocator/allocator.cpp @@ -37,7 +37,7 @@ void validate_num_banks(uint32_t num_banks, const BufferType &buffer_type) { // Dataflow API does not have a working implementation of generic modulo to determine bank_id for interleaved // address gen For non pow2 num banks, special cases need to be added to avoid falling back to generic // implementation. See https://github.com/tenstorrent/tt-metal/issues/3321 - std::unordered_set acceptable_num_non_pow2_mem_banks = {12, 56, 94, 124, 130, 140}; + std::unordered_set acceptable_num_non_pow2_mem_banks = {12, 56, 70, 80, 94, 124, 130, 140}; bool custom_mod_bank_id_calculation_exists = acceptable_num_non_pow2_mem_banks.count(num_banks) > 0; bool doesnt_support_interleaved = buffer_type == BufferType::L1_SMALL; bool valid_num_banks = (is_pow2_num_banks or custom_mod_bank_id_calculation_exists or doesnt_support_interleaved); diff --git a/tt_metal/impl/kernels/data_types.hpp b/tt_metal/impl/kernels/data_types.hpp index 787bc74be95..834058c7b6b 100644 --- a/tt_metal/impl/kernels/data_types.hpp +++ b/tt_metal/impl/kernels/data_types.hpp @@ -18,6 +18,11 @@ enum NOC : uint8_t { NOC_1 = 1, }; +enum NOC_MODE : uint8_t { + DM_DEDICATED_NOC = 0, + DM_DYNAMIC_NOC = 1, +}; + enum Eth : uint8_t { SENDER = 0, RECEIVER = 1, diff --git a/tt_metal/impl/kernels/kernel.cpp b/tt_metal/impl/kernels/kernel.cpp index a714fc2de8b..5e78e50e741 100644 --- a/tt_metal/impl/kernels/kernel.cpp +++ b/tt_metal/impl/kernels/kernel.cpp @@ -105,6 +105,7 @@ void DataMovementKernel::process_defines( const std::function callback) const { Kernel::process_defines(callback); callback("NOC_INDEX", std::to_string(this->config_.noc)); + callback("NOC_MODE", std::to_string(this->config_.noc_mode)); } void ComputeKernel::process_defines( @@ -112,12 +113,16 @@ void ComputeKernel::process_defines( for (const auto &[define, value] : this->defines_) { callback(define, value); } + // pass default noc mode as compute does not need it, just for compile to pass + callback("NOC_MODE", std::to_string(NOC_MODE::DM_DEDICATED_NOC)); } void EthernetKernel::process_defines( const std::function callback) const { Kernel::process_defines(callback); callback("NOC_INDEX", std::to_string(this->config_.noc)); + // pass default noc mode as eth does not need it, just for compile to pass + callback("NOC_MODE", std::to_string(NOC_MODE::DM_DEDICATED_NOC)); } void Kernel::process_compile_time_args(const std::function callback) const { diff --git a/tt_metal/impl/kernels/kernel_types.hpp b/tt_metal/impl/kernels/kernel_types.hpp index a1edd32eb50..6a878c49544 100644 --- a/tt_metal/impl/kernels/kernel_types.hpp +++ b/tt_metal/impl/kernels/kernel_types.hpp @@ -20,6 +20,7 @@ using KernelHandle = std::uint16_t; struct DataMovementConfig { DataMovementProcessor processor = DataMovementProcessor::RISCV_0; // For data transfer kernels: NCRISC & BRISC NOC noc = NOC::RISCV_0_default; + NOC_MODE noc_mode = NOC_MODE::DM_DEDICATED_NOC; std::vector compile_args; // Will cause CompileProgram to emit a file hlk_defines_generated.h // Each unique combination of defines will produce a unique compiled instantiation @@ -32,6 +33,7 @@ struct ReaderDataMovementConfig : public DataMovementConfig { DataMovementConfig{ .processor = DataMovementProcessor::RISCV_1, .noc = detail::GetPreferredNOCForDRAMRead(tt::Cluster::instance().arch()), + .noc_mode = NOC_MODE::DM_DEDICATED_NOC, .compile_args = compile_args, .defines = defines} {} }; @@ -41,6 +43,7 @@ struct WriterDataMovementConfig : public DataMovementConfig { DataMovementConfig{ .processor = DataMovementProcessor::RISCV_0, .noc = detail::GetPreferredNOCForDRAMWrite(tt::Cluster::instance().arch()), + .noc_mode = NOC_MODE::DM_DEDICATED_NOC, .compile_args = compile_args, .defines = defines} {} }; diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index 19a96a4e5d0..1877b9855de 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -160,6 +160,7 @@ KernelGroup::KernelGroup( this->programmable_core_type_index = programmable_core_type_index; this->core_ranges = this->core_ranges.merge(new_ranges); this->kernel_ids = kernel_ids; + this->launch_msg.kernel_config.brisc_noc_mode = NOC_MODE::DM_DEDICATED_NOC; std::memset(&this->launch_msg, 0, sizeof(launch_msg_t)); @@ -183,10 +184,18 @@ KernelGroup::KernelGroup( if (class_id == DISPATCH_CLASS_TENSIX_DM0) { // Use brisc's noc if brisc specifies a noc this->launch_msg.kernel_config.brisc_noc_id = std::get(kernel->config()).noc; + // if noc mode is already set to DM_DYNAMIC_NOC then we can't change back to DM_DEDICATED_NOC + if (std::get(kernel->config()).noc_mode == NOC_MODE::DM_DYNAMIC_NOC) { + this->launch_msg.kernel_config.brisc_noc_mode = NOC_MODE::DM_DYNAMIC_NOC; + } } else if (class_id == DISPATCH_CLASS_TENSIX_DM1) { // Use 1-ncrisc's noc (the other noc) if ncrisc specifies a noc // If both brisc and ncrisc set the noc, then this is safe due to prior correctness validation this->launch_msg.kernel_config.brisc_noc_id = 1 - std::get(kernel->config()).noc; + // if noc mode is already set to DM_DYNAMIC_NOC then we can't change back to DM_DEDICATED_NOC + if (this->launch_msg.kernel_config.brisc_noc_mode == NOC_MODE::DM_DYNAMIC_NOC) { + this->launch_msg.kernel_config.brisc_noc_mode = NOC_MODE::DM_DYNAMIC_NOC; + } this->launch_msg.kernel_config.ncrisc_kernel_size16 = kernel->get_binary_size16(); } } From d616ca68a018d15f8ba7ba531ac228254ec9a9dd Mon Sep 17 00:00:00 2001 From: Shaw Nguyen Date: Mon, 7 Oct 2024 08:16:54 +0000 Subject: [PATCH 17/58] #13463: Develop Clone Operation --- .../sweep_tests/pytests/tt_dnn/test_copy.py | 62 ----- .../ttnn/unit_tests/operations/test_clone.py | 254 ++++++++++++++++++ ttnn/CMakeLists.txt | 4 + .../operations/data_movement/clone/clone.cpp | 14 + .../operations/data_movement/clone/clone.hpp | 18 ++ .../data_movement/clone/clone_pybind.cpp | 37 +++ .../data_movement/clone/clone_pybind.hpp | 13 + .../clone/device/clone_device_operation.cpp | 64 +++++ .../clone/device/clone_device_operation.hpp | 61 +++++ .../clone/device/clone_program_factory.cpp | 171 ++++++++++++ .../clone/device/kernels/compute_kernel.cpp | 31 +++ .../clone/device/kernels/read_kernel.cpp | 36 +++ .../clone/device/kernels/read_kernel_rm.cpp | 36 +++ .../clone/device/kernels/write_kernel.cpp | 30 +++ .../clone/device/kernels/write_kernel_rm.cpp | 36 +++ .../operations/data_movement/copy/copy.cpp | 63 ++--- .../operations/data_movement/copy/copy.hpp | 42 +-- .../data_movement/copy/copy_pybind.cpp | 58 +--- .../data_movement/copy/copy_pybind.hpp | 1 - .../data_movement/data_movement_pybind.hpp | 84 +++--- .../experimental/auto_format/auto_format.cpp | 84 +++--- 21 files changed, 938 insertions(+), 261 deletions(-) create mode 100644 tests/ttnn/unit_tests/operations/test_clone.py create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/clone.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/clone.hpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/clone_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/clone_pybind.hpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.hpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel_rm.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel.cpp create mode 100644 ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel_rm.cpp diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_copy.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_copy.py index 86c46e8b8f6..3a238ff96b7 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_copy.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_copy.py @@ -66,65 +66,3 @@ def test_run_copy_op( device, test_args, ) - - -@pytest.mark.parametrize( - "input_shapes", - [ - [[1, 1, 1, 30]], # Single core - [[1, 1, 300, 380]], # multi core - [[1, 3, 320, 380]], # multi core - [[1, 1, 32, 32]], # Single core - [[1, 1, 320, 384]], # Multi core - [[1, 3, 320, 384]], # Multi core - ], -) -@pytest.mark.parametrize( - "input_mem_config", - mem_configs, -) -@pytest.mark.parametrize( - "dst_mem_config", - mem_configs, -) -@pytest.mark.parametrize( - "output_type", - [ - ttnn.bfloat16, - ], -) -@pytest.mark.parametrize( - "input_type", - [ - torch.float32, - torch.float16, - torch.bfloat16, - ], -) -class TestClone: - def test_run_clone_op( - self, - input_type, - output_type, - input_shapes, - input_mem_config, - dst_mem_config, - device, - function_level_defaults, - ): - datagen_func = [ - generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-100, high=100), input_type) - ] - test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0] - test_args["input_mem_config"] = [input_mem_config] - test_args["dtype"] = [output_type] - test_args.update({"output_mem_config": dst_mem_config}) - comparison_func = partial(comparison_funcs.comp_allclose, rtol=1e-1, atol=1e-1) - run_single_pytorch_test( - "clone", - input_shapes, - datagen_func, - comparison_func, - device, - test_args, - ) diff --git a/tests/ttnn/unit_tests/operations/test_clone.py b/tests/ttnn/unit_tests/operations/test_clone.py new file mode 100644 index 00000000000..c5ba6d7df0c --- /dev/null +++ b/tests/ttnn/unit_tests/operations/test_clone.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import ttnn +from models.utility_functions import comp_allclose_and_pcc +from loguru import logger + +from tests.ttnn.unit_tests.operations.test_utils import ( + to_cpu, + to_npu, +) + + +def get_lib_dtype(lib, dtype): + """ + Maps string-based data types to their corresponding library-specific dtypes. + + Parameters: + lib: library module (e.g., torch, ttnn) + The library for which the dtype mapping is required. + dtype: str + The string representation of the data type (e.g., 'bfloat16', 'float32', 'int32'). + + Returns: + Corresponding library-specific dtype or None if not found. + """ + dtype_map = { + "bfloat16": lib.bfloat16, + "float32": lib.float32, + "int32": lib.int32, + } + return dtype_map.get(dtype, None) + + +def run_clone( + shape, + input_memory_config, + output_memory_config, + input_dtype, + output_dtype, + tilized, + device, +): + """ + Function to test the clone operation on NPU. Generates random input data, clones it on NPU, + and compares the output with the CPU clone for correctness. + + Parameters: + shape: tuple + Shape of the input tensor. + input_memory_config: MemoryConfig + Memory configuration for the input tensor on NPU. + output_memory_config: MemoryConfig + Memory configuration for the output tensor on NPU. + input_dtype: str + Data type of the input tensor ('int32' or other). + output_dtype: str or None + Data type of the output tensor (must be None or match input_dtype when not tilized). + tilized: bool + Whether to use TILE_LAYOUT or ROW_MAJOR_LAYOUT for NPU tensor. + device: ttnn.device + Device where the operation is performed (e.g., NPU device). + + Raises: + pytest.skip: When certain conditions on dtype mismatch or layout are not met. + """ + if input_dtype == "int32": + cpu_input = torch.randint(low=-10, high=11, size=shape, dtype=get_lib_dtype(torch, input_dtype)) + else: + cpu_input = 2 * torch.rand(size=shape, dtype=get_lib_dtype(torch, input_dtype)) - 1 + + if input_dtype == "int32": + if output_dtype and output_dtype != "int32": + pytest.skip("For int32 input, output_dtype must be None or int32.") + if output_dtype == "int32" and input_dtype != "int32": + pytest.skip("For int32 output, input_dtype must also be int32.") + if output_dtype != input_dtype and output_dtype and not tilized: + pytest.skip("When not tilized, dtype conversion is not supported.") + + npu_input = to_npu( + cpu_input, + device, + npu_dtype=get_lib_dtype(ttnn, input_dtype), + npu_layout=ttnn.TILE_LAYOUT if tilized else ttnn.ROW_MAJOR_LAYOUT, + ).to(device, input_memory_config) + + npu_output = ttnn.clone( + npu_input, + dtype=get_lib_dtype(ttnn, output_dtype), + memory_config=output_memory_config, + ) + + cpu_output = to_cpu(npu_output, shape) + + passing, out = comp_allclose_and_pcc(torch.ops.aten.clone(cpu_input), cpu_output, rtol=0.01, atol=0.01) + logger.info(out) + assert passing + + +memory_config_list = [ + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.DRAM), + ttnn.MemoryConfig(ttnn.TensorMemoryLayout.INTERLEAVED, ttnn.BufferType.L1), +] + + +@pytest.mark.parametrize( + "shape", + [ + [10], # 1d + [10, 10], # 2d + [10, 10, 10], # 3d + [10, 10, 10, 10], # 4d + [1, 1, 1, 30], # Single core + [1, 1, 300, 380], # Multi core + [1, 3, 320, 380], # Multi core + [1, 1, 32, 32], # Single core + [1, 1, 320, 384], # Multi core + [1, 3, 320, 384], # Multi core + ], +) +@pytest.mark.parametrize( + "tilized", + [True, False], +) +def test_clone_shape( + shape, + tilized, + device, +): + """ + Test case to verify the clone operation on different tensor shapes and layouts (tilized or not). + """ + torch.manual_seed(2024) + run_clone( + shape, + memory_config_list[0], + memory_config_list[0], + "bfloat16", + None, + tilized, + device, + ) + + +@pytest.mark.parametrize( + "input_memory_config", + memory_config_list, +) +@pytest.mark.parametrize( + "output_memory_config", + [*memory_config_list, None], +) +@pytest.mark.parametrize( + "tilized", + [True, False], +) +def test_clone_memory_config( + input_memory_config, + output_memory_config, + tilized, + device, +): + """ + Test case to verify the clone operation with different memory configurations (input/output) + and layout configurations (tilized or not). + """ + torch.manual_seed(2024) + run_clone( + [1, 3, 320, 384], + input_memory_config, + output_memory_config, + "bfloat16", + None, + tilized, + device, + ) + + +@pytest.mark.parametrize( + "input_dtype", + [ + "bfloat16", + "float32", + "int32", + ], +) +@pytest.mark.parametrize( + "output_dtype", + [ + "bfloat16", + "float32", + "int32", + None, + ], +) +@pytest.mark.parametrize( + "tilized", + [True, False], +) +def test_clone_dtype_conversion( + input_dtype, + output_dtype, + tilized, + device, +): + """ + Test case to verify the clone operation with various input/output dtype combinations. + """ + torch.manual_seed(2024) + run_clone( + [1, 3, 320, 384], + memory_config_list[0], + memory_config_list[0], + input_dtype, + output_dtype, + tilized, + device, + ) + + +@pytest.mark.parametrize( + "tilized", + [True, False], +) +def test_clone_callback( + tilized, + device, + use_program_cache, +): + """ + Test case to verify the clone operation with various input/output dtype combinations. + """ + torch.manual_seed(2024) + num_program_cache_entries_list = [] + for i in range(2): + run_clone( + [1, 3, 320, 384], + memory_config_list[0], + memory_config_list[0], + "bfloat16", + None, + tilized, + device, + ) + torch_dummy = torch.randn([32, 32]) + tt_dummy = to_npu(torch_dummy, device) + num_program_cache_entries_list.append(device.num_program_cache_entries()) + logger.info(f"num_program_cache_entries_list={num_program_cache_entries_list}") + assert num_program_cache_entries_list[0] > 0 + assert num_program_cache_entries_list[0] == num_program_cache_entries_list[1] diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index f2e7ec8f91b..4a18128dc8e 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -352,6 +352,10 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded/reshard/reshard_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/clone.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/clone_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/clone.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/clone.cpp new file mode 100644 index 00000000000..136bba1d970 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/clone.cpp @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "clone.hpp" + +#include "device/clone_device_operation.hpp" + +namespace ttnn::operations::data_movement::clone { +Tensor Clone::invoke( + const Tensor& input, const std::optional& dtype, const std::optional& memory_config) { + return ttnn::prim::clone(input, dtype, memory_config); +} +} // namespace ttnn::operations::data_movement::clone diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/clone.hpp b/ttnn/cpp/ttnn/operations/data_movement/clone/clone.hpp new file mode 100644 index 00000000000..f4f2b6737e5 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/clone.hpp @@ -0,0 +1,18 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "ttnn/decorators.hpp" + +namespace ttnn::operations::data_movement::clone { +struct Clone { + static Tensor invoke( + const Tensor& input, const std::optional& dtype, const std::optional& memory_config); +}; +} // namespace ttnn::operations::data_movement::clone + +namespace ttnn { +constexpr auto clone = + ttnn::register_operation_with_auto_launch_op<"ttnn::clone", ttnn::operations::data_movement::clone::Clone>(); +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/clone_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/clone_pybind.cpp new file mode 100644 index 00000000000..6485b569a58 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/clone_pybind.cpp @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "clone_pybind.hpp" + +#include "clone.hpp" +#include "pybind11/decorators.hpp" + +namespace ttnn::operations::data_movement::clone { +void bind_clone_operation(py::module& module) { + auto doc = R"doc(clone(input: Tensor, dtype: DataType, memory_config: MemoryConfig) -> Tensor + + Clones the input, creating a copy with the specified `memory_config` and converting its data type to `dtype`. + This operation does not alter the tensor's layout. + - ROW_MAJOR_LAYOUT: Returns the tensor unpadded in the last two dimensions. + - TILE_LAYOUT: Pads the tensor to ensure its width and height are multiples of 32. + If the input's current layout matches the specified layout, padding adjustments are applied to the last two dimensions as necessary. + + Args: + * :attr:`input`: The tensor to be cloned. + * :attr:`dtype`: The target data type of the cloned tensor. + * :attr:`memory_config`: The memory configuration for the clone, options include DRAM_MEMORY_CONFIG or L1_MEMORY_CONFIG. + )doc"; + + bind_registered_operation( + module, + ttnn::clone, + doc, + ttnn::pybind_arguments_t{ + py::arg("input"), + py::kw_only(), + py::arg("dtype") = std::nullopt, + py::arg("memory_config") = std::nullopt, + }); +} +} // namespace ttnn::operations::data_movement::clone diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/clone_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/clone/clone_pybind.hpp new file mode 100644 index 00000000000..b201401ca35 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/clone_pybind.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::data_movement::clone { +void bind_clone_operation(py::module& module); +} // namespace ttnn::operations::data_movement::clone diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp new file mode 100644 index 00000000000..f97d6afa9ea --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp @@ -0,0 +1,64 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "clone_device_operation.hpp" + +namespace ttnn::operations::data_movement::clone { +void CloneOperation::validate_inputs( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& input = tensor_args.input; + if (operation_attributes.dtype != input.get_dtype()) + TT_FATAL(input.get_layout() == Layout::TILE, "dtype conversion is only supported with tile layout"); + TT_FATAL(input.storage_type() == StorageType::DEVICE, "input to clone must be on device"); + TT_FATAL(input.buffer() != nullptr, "input to clone must be allocated in buffer on device"); + TT_FATAL( + input.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, + "clone does not currently support sharding"); + TT_FATAL( + operation_attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED, + "clone does not currently support sharding"); +} + +CloneOperation::program_factory_t CloneOperation::select_program_factory( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + return ProgramFactory{}; +} + +void CloneOperation::validate_on_program_cache_miss( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_inputs(operation_attributes, tensor_args); +}; + +void CloneOperation::validate_on_program_cache_hit( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_inputs(operation_attributes, tensor_args); +}; + +CloneOperation::shape_return_value_t CloneOperation::compute_output_shapes( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + return tensor_args.input.get_shape(); +}; + +CloneOperation::tensor_return_value_t CloneOperation::create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& input = tensor_args.input; + return create_device_tensor( + compute_output_shapes(operation_attributes, tensor_args), + operation_attributes.dtype, + input.get_layout(), + input.device(), + operation_attributes.memory_config); +} + +std::tuple CloneOperation::invoke( + const Tensor& input, const std::optional& dtype, const std::optional& memory_config) { + return { + operation_attributes_t{ + dtype.value_or(input.get_dtype()), + memory_config.value_or(input.memory_config()), + }, + tensor_args_t{input}, + }; +} +} // namespace ttnn::operations::data_movement::clone diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.hpp new file mode 100644 index 00000000000..4e03caef5a2 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.hpp @@ -0,0 +1,61 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/decorators.hpp" + +namespace ttnn::operations::data_movement::clone { + +struct CloneOperation { + struct operation_attributes_t { + const DataType dtype; + const MemoryConfig memory_config; + }; + + struct tensor_args_t { + const Tensor& input; + }; + + using shape_return_value_t = Shape; + using tensor_return_value_t = Tensor; + + struct ProgramFactory { + struct shared_variables_t { + KernelHandle unary_reader_kernel_id; + KernelHandle unary_writer_kernel_id; + std::vector cores; + }; + + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output); + }; + + using program_factory_t = std::variant; + + static void validate_inputs(const operation_attributes_t&, const tensor_args_t&); + static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); + static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); + + static std::tuple invoke( + const Tensor& input, const std::optional& dtype, const std::optional& memory_config); +}; + +} // namespace ttnn::operations::data_movement::clone + +namespace ttnn::prim { +constexpr auto clone = + ttnn::register_operation<"ttnn::prim::clone", ttnn::operations::data_movement::clone::CloneOperation>(); +} // namespace ttnn::prim diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp new file mode 100644 index 00000000000..5dca80d22e9 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp @@ -0,0 +1,171 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "clone_device_operation.hpp" +#include "tt_metal/common/work_split.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/math.hpp" + +namespace ttnn::operations::data_movement::clone { +CloneOperation::ProgramFactory::cached_program_t CloneOperation::ProgramFactory::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output) { + const auto& input = tensor_args.input; + Program program = Program(); + + bool tilized = output.get_layout() == Layout::TILE; + + tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); + tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + bool convert_dtype = input_cb_data_format != output_cb_data_format; + uint32_t input_unit_size = tilized ? tt::tt_metal::detail::TileSize(input_cb_data_format) + : input.get_legacy_shape()[-1] * input.element_size(); + uint32_t output_unit_size = tilized ? tt::tt_metal::detail::TileSize(output_cb_data_format) + : output.get_legacy_shape()[-1] * output.element_size(); + + uint32_t num_units = + tilized ? output.volume() / tt::constants::TILE_HW : output.volume() / output.get_legacy_shape()[-1]; + + tt::tt_metal::Device* device = output.device(); + + auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); + uint32_t num_cores_x = compute_with_storage_grid_size.x; + uint32_t num_cores_y = compute_with_storage_grid_size.y; + auto [num_cores, all_cores, core_group_1, core_group_2, num_units_per_core_group_1, num_units_per_core_group_2] = + tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_units); + + uint32_t src0_cb_index = tt::CB::c_in0; + uint32_t num_input_units = 2; + uint32_t aligned_input_unit_size = round_up_to_mul32(input_unit_size); + tt::tt_metal::CircularBufferConfig cb_src0_config = + tt::tt_metal::CircularBufferConfig( + num_input_units * aligned_input_unit_size, {{src0_cb_index, input_cb_data_format}}) + .set_page_size(src0_cb_index, aligned_input_unit_size); + auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + + uint32_t output_cb_index = src0_cb_index; + if (convert_dtype) { + output_cb_index = tt::CB::c_out0; + uint32_t num_output_units = 2; + uint32_t aligned_output_unit_size = round_up_to_mul32(output_unit_size); + tt::tt_metal::CircularBufferConfig output_cb_config = + tt::tt_metal::CircularBufferConfig( + num_output_units * aligned_output_unit_size, {{output_cb_index, output_cb_data_format}}) + .set_page_size(output_cb_index, aligned_output_unit_size); + auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config); + } + + auto src_buffer = input.buffer(); + auto dst_buffer = output.buffer(); + bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + + vector reader_compile_time_args, writer_compile_time_args; + if (tilized) { + reader_compile_time_args = {(uint32_t)src_is_dram}; + writer_compile_time_args = {(uint32_t)output_cb_index, (uint32_t)dst_is_dram}; + } else { + bool src_stick_size_is_power_of_two = is_power_of_two_at_least_32(input_unit_size); + uint32_t src_log2_stick_size = src_stick_size_is_power_of_two ? (uint32_t)log2(input_unit_size) : 0; + reader_compile_time_args = { + (uint32_t)src0_cb_index, + (uint32_t)src_is_dram, + (uint32_t)src_stick_size_is_power_of_two, + (uint32_t)src_log2_stick_size}; + bool dst_stick_size_is_power_of_two = is_power_of_two_at_least_32(output_unit_size); + uint32_t dst_log2_stick_size = dst_stick_size_is_power_of_two ? (uint32_t)log2(output_unit_size) : 0; + writer_compile_time_args = { + (uint32_t)output_cb_index, + (uint32_t)dst_is_dram, + (uint32_t)dst_stick_size_is_power_of_two, + (uint32_t)dst_log2_stick_size}; + } + map kernel_defines; + tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + program, + tilized ? "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel.cpp" + : "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel_rm.cpp", + all_cores, + tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, kernel_defines)); + + tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + program, + tilized ? "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel.cpp" + : "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel_rm.cpp", + all_cores, + tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args, kernel_defines)); + + if (convert_dtype) { + vector compute_kernel_args_group_1 = {num_units_per_core_group_1}; + auto eltwise_unary_kernel_group_1 = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp", + core_group_1, + tt::tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_1}); + + if (!core_group_2.ranges().empty()) { + vector compute_kernel_args_group_2 = {num_units_per_core_group_2}; + auto eltwise_unary_kernel_group_2 = tt::tt_metal::CreateKernel( + program, + "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp", + core_group_2, + tt::tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_2}); + } + } + + uint32_t start_id = 0; + uint32_t g1_numcores = core_group_1.num_cores(); + auto cores = grid_to_cores(num_cores, num_cores_x, num_cores_y, false); + + for (uint32_t i = 0; i < cores.size(); ++i) { + const CoreCoord& core = cores.at(i); + uint32_t num_units_per_core = i < g1_numcores ? num_units_per_core_group_1 : num_units_per_core_group_2; + + if (tilized) { + tt::tt_metal::SetRuntimeArgs( + program, unary_reader_kernel_id, core, {src_buffer->address(), num_units_per_core, start_id}); + tt::tt_metal::SetRuntimeArgs( + program, unary_writer_kernel_id, core, {dst_buffer->address(), num_units_per_core, start_id}); + } else { + tt::tt_metal::SetRuntimeArgs( + program, + unary_reader_kernel_id, + core, + {src_buffer->address(), input_unit_size, num_units_per_core, start_id}); + tt::tt_metal::SetRuntimeArgs( + program, + unary_writer_kernel_id, + core, + {dst_buffer->address(), output_unit_size, num_units_per_core, start_id}); + } + start_id += num_units_per_core; + } + + return {std::move(program), {unary_reader_kernel_id, unary_writer_kernel_id, cores}}; +} + +void CloneOperation::ProgramFactory::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output) { + const auto& program = cached_program.program; + const auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id; + const auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; + auto cores = cached_program.shared_variables.cores; + + auto src_buffer_address = tensor_args.input.buffer()->address(); + auto dst_buffer_address = output.buffer()->address(); + for (const auto& core : cores) { + { + auto& runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); + runtime_args[0] = src_buffer_address; + } + { + auto& runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core); + runtime_args[0] = dst_buffer_address; + } + } +} +} // namespace ttnn::operations::data_movement::clone diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp new file mode 100644 index 00000000000..173632fe2d6 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp @@ -0,0 +1,31 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "compute_kernel_api/common.h" +#include "compute_kernel_api/eltwise_unary/eltwise_unary.h" +#include "compute_kernel_api/tile_move_copy.h" + +namespace NAMESPACE { +void MAIN { + uint32_t per_core_tile_count = get_compile_time_arg_val(0); + unary_op_init_common(tt::CB::c_in0); + for (uint32_t tile_index = 0; tile_index < per_core_tile_count; ++tile_index) { + acquire_dst(tt::DstMode::Half); + + cb_wait_front(tt::CB::c_in0, 1); + cb_reserve_back(tt::CB::c_out0, 1); + + copy_tile(tt::CB::c_in0, 0, 0); + + pack_tile(0, tt::CB::c_out0); + + cb_pop_front(tt::CB::c_in0, 1); + cb_push_back(tt::CB::c_out0, 1); + + release_dst(tt::DstMode::Half); + } +} +} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel.cpp new file mode 100644 index 00000000000..b90bb2ccb91 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel.cpp @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +void kernel_main() { + uint32_t src_addr = get_arg_val(0); + uint32_t num_tiles = get_arg_val(1); + uint32_t start_id = get_arg_val(2); + + constexpr bool src_is_dram = get_compile_time_arg_val(0) == 1; + + constexpr uint32_t cb_id_in0 = 0; + constexpr uint32_t one_tile = 1; + + const uint32_t tile_bytes = get_tile_size(cb_id_in0); + const DataFormat data_format = get_dataformat(cb_id_in0); + + const InterleavedAddrGenFast s = { + .bank_base_address = src_addr, + .page_size = tile_bytes, + .data_format = data_format, + }; + + uint32_t end_id = start_id + num_tiles; + for (uint32_t i = start_id; i < end_id; ++i) { + cb_reserve_back(cb_id_in0, one_tile); + uint32_t l1_write_addr = get_write_ptr(cb_id_in0); + noc_async_read_tile(i, s, l1_write_addr); + noc_async_read_barrier(); + cb_push_back(cb_id_in0, one_tile); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel_rm.cpp new file mode 100644 index 00000000000..7f19c3f2636 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel_rm.cpp @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +void kernel_main() { + uint32_t src_addr = get_arg_val(0); + uint32_t stick_size = get_arg_val(1); + uint32_t num_sticks = get_arg_val(2); + uint32_t start_id = get_arg_val(3); + + constexpr uint32_t cb_id_in0 = get_compile_time_arg_val(0); + constexpr bool src0_is_dram = get_compile_time_arg_val(1) == 1; + +#define src_stick_size_is_pow2 get_compile_time_arg_val(2) == 1 +#if (src_stick_size_is_pow2) + constexpr uint32_t src_log_base_2_of_page_size = get_compile_time_arg_val(3); + const InterleavedPow2AddrGen s0 = { + .bank_base_address = src_addr, .log_base_2_of_page_size = src_log_base_2_of_page_size}; +#else + const InterleavedAddrGen s0 = {.bank_base_address = src_addr, .page_size = stick_size}; +#endif + + uint32_t end_id = start_id + num_sticks; + for (uint32_t i = start_id; i < end_id; ++i) { + cb_reserve_back(cb_id_in0, 1); + uint32_t l1_write_addr = get_write_ptr(cb_id_in0); + uint64_t src_noc_addr = get_noc_addr(i, s0); + noc_async_read(src_noc_addr, l1_write_addr, stick_size); + noc_async_read_barrier(); + cb_push_back(cb_id_in0, 1); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel.cpp new file mode 100644 index 00000000000..ce7c3420744 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel.cpp @@ -0,0 +1,30 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "dataflow_api.h" + +void kernel_main() { + uint32_t dst_addr = get_arg_val(0); + uint32_t num_tiles = get_arg_val(1); + uint32_t start_id = get_arg_val(2); + + constexpr uint32_t cb_id_out = get_compile_time_arg_val(0); + constexpr bool dst_is_dram = get_compile_time_arg_val(1) == 1; + + constexpr uint32_t one_tile = 1; + const uint32_t tile_bytes = get_tile_size(cb_id_out); + const DataFormat data_format = get_dataformat(cb_id_out); + + const InterleavedAddrGenFast s = { + .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; + + uint32_t end_id = start_id + num_tiles; + for (uint32_t i = start_id; i < end_id; ++i) { + cb_wait_front(cb_id_out, one_tile); + uint32_t l1_read_addr = get_read_ptr(cb_id_out); + noc_async_write_tile(i, s, l1_read_addr); + noc_async_write_barrier(); + cb_pop_front(cb_id_out, one_tile); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel_rm.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel_rm.cpp new file mode 100644 index 00000000000..9b5e399e321 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel_rm.cpp @@ -0,0 +1,36 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "dataflow_api.h" + +void kernel_main() { + uint32_t dst_addr = get_arg_val(0); + uint32_t stick_size = get_arg_val(1); + uint32_t num_sticks = get_arg_val(2); + uint32_t start_id = get_arg_val(3); + + constexpr uint32_t cb_id_out0 = get_compile_time_arg_val(0); + constexpr bool dst0_is_dram = get_compile_time_arg_val(1) == 1; + +#define dst_stick_size_is_pow2 get_compile_time_arg_val(2) == 1 +#if (dst_stick_size_is_pow2) + constexpr uint32_t dst_log_base_2_of_page_size = get_compile_time_arg_val(3); + const InterleavedPow2AddrGen s0 = { + .bank_base_address = dst_addr, .log_base_2_of_page_size = dst_log_base_2_of_page_size}; +#else + const InterleavedAddrGen s0 = {.bank_base_address = dst_addr, .page_size = stick_size}; +#endif + + uint32_t end_id = start_id + num_sticks; + for (uint32_t i = start_id; i < end_id; ++i) { + cb_wait_front(cb_id_out0, 1); + uint32_t l1_read_addr = get_read_ptr(cb_id_out0); + uint64_t dst_noc_addr = get_noc_addr(i, s0); + noc_async_write(l1_read_addr, dst_noc_addr, stick_size); + noc_async_write_barrier(); + cb_pop_front(cb_id_out0, 1); + } +} diff --git a/ttnn/cpp/ttnn/operations/data_movement/copy/copy.cpp b/ttnn/cpp/ttnn/operations/data_movement/copy/copy.cpp index 1d09453941b..6be441de8cc 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/copy/copy.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/copy/copy.cpp @@ -2,72 +2,57 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ttnn/operations/data_movement/copy/copy.hpp" +#include "device/copy_device_operation.hpp" #include "ttnn/common/constants.hpp" -#include "ttnn/run_operation.hpp" #include "ttnn/decorators.hpp" -#include "ttnn/operations/data_movement/copy/copy.hpp" -#include "device/copy_device_operation.hpp" +#include "ttnn/run_operation.hpp" namespace ttnn::operations::data_movement { -ttnn::Tensor CopyOperation::invoke( - uint8_t queue_id, - const Tensor& src_tensor, - const Tensor& dst_tensor) { - operation::run(CopyDeviceOperation{dst_tensor.memory_config(), dst_tensor.get_dtype()}, {src_tensor, dst_tensor}, {}, {}, queue_id); +ttnn::Tensor CopyOperation::invoke(uint8_t queue_id, const Tensor& src_tensor, const Tensor& dst_tensor) { + operation::run( + CopyDeviceOperation{dst_tensor.memory_config(), dst_tensor.get_dtype()}, + {src_tensor, dst_tensor}, + {}, + {}, + queue_id); return dst_tensor; } -ttnn::Tensor CopyOperation::invoke( - const Tensor& src_tensor, - const Tensor& dst_tensor) { +ttnn::Tensor CopyOperation::invoke(const Tensor& src_tensor, const Tensor& dst_tensor) { return invoke(ttnn::DefaultQueueId, src_tensor, dst_tensor); } -ttnn::Tensor CloneOperation::invoke( - uint8_t queue_id, - const Tensor& input_tensor, - const std::optional& output_mem_config, - const std::optional output_dtype) { - return operation::run(CopyDeviceOperation{output_mem_config.value_or(input_tensor.memory_config()), output_dtype.value_or(input_tensor.get_dtype())}, {input_tensor}, {}, {}, queue_id).at(0); -} - -ttnn::Tensor CloneOperation::invoke( - const ttnn::Tensor& input_tensor, - const std::optional& output_mem_config, - const std::optional output_dtype) { - return invoke(ttnn::DefaultQueueId, input_tensor, output_mem_config, output_dtype); -} - ttnn::Tensor AssignOperation::invoke( uint8_t queue_id, const Tensor& input, const MemoryConfig& output_mem_config, std::optional output_dtype, std::optional optional_output_tensor) { - return operation::run(CopyDeviceOperation{output_mem_config, output_dtype.value_or(input.get_dtype())}, {input}, {}, {optional_output_tensor}, queue_id).at(0); + return operation::run( + CopyDeviceOperation{output_mem_config, output_dtype.value_or(input.get_dtype())}, + {input}, + {}, + {optional_output_tensor}, + queue_id) + .at(0); } ttnn::Tensor AssignOperation::invoke( - const Tensor& input, - const MemoryConfig& output_mem_config, - std::optional output_dtype) { + const Tensor& input, const MemoryConfig& output_mem_config, std::optional output_dtype) { return invoke(ttnn::DefaultQueueId, input, output_mem_config, output_dtype); } -ttnn::Tensor AssignOperation::invoke( - uint8_t queue_id, - const Tensor& input_a, - const Tensor& input_b) { - operation::run(CopyDeviceOperation{input_b.memory_config(), input_b.get_dtype()}, {input_a, input_b}, {}, {}, queue_id); +ttnn::Tensor AssignOperation::invoke(uint8_t queue_id, const Tensor& input_a, const Tensor& input_b) { + operation::run( + CopyDeviceOperation{input_b.memory_config(), input_b.get_dtype()}, {input_a, input_b}, {}, {}, queue_id); return input_b; } -ttnn::Tensor AssignOperation::invoke( - const Tensor& input_a, - const Tensor& input_b) { +ttnn::Tensor AssignOperation::invoke(const Tensor& input_a, const Tensor& input_b) { return invoke(ttnn::DefaultQueueId, input_a, input_b); } -} // ttnn::operations::data_movement namespace +} // namespace ttnn::operations::data_movement diff --git a/ttnn/cpp/ttnn/operations/data_movement/copy/copy.hpp b/ttnn/cpp/ttnn/operations/data_movement/copy/copy.hpp index e90af705234..85cb979f2f3 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/copy/copy.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/copy/copy.hpp @@ -4,35 +4,17 @@ #pragma once -#include "ttnn/decorators.hpp" #include +#include "ttnn/decorators.hpp" + namespace ttnn { namespace operations::data_movement { struct CopyOperation { - static ttnn::Tensor invoke( - uint8_t queue_id, - const Tensor& src_tensor, - const Tensor& dst_tensor); - - static ttnn::Tensor invoke( - const Tensor& src_tensor, - const Tensor& dst_tensor); -}; + static ttnn::Tensor invoke(uint8_t queue_id, const Tensor& src_tensor, const Tensor& dst_tensor); -struct CloneOperation { - static ttnn::Tensor invoke( - uint8_t queue_id, - const Tensor& input_tensor, - const std::optional& output_mem_config = std::nullopt, - const std::optional output_dtype = std::nullopt); - - - static ttnn::Tensor invoke( - const Tensor& input_tensor, - const std::optional& output_mem_config = std::nullopt, - const std::optional output_dtype = std::nullopt); + static ttnn::Tensor invoke(const Tensor& src_tensor, const Tensor& dst_tensor); }; struct AssignOperation { @@ -48,20 +30,16 @@ struct AssignOperation { const MemoryConfig& output_mem_config, std::optional output_dtype = std::nullopt); - static ttnn::Tensor invoke( - uint8_t queue_id, - const Tensor& input_a, - const Tensor& input_b); + static ttnn::Tensor invoke(uint8_t queue_id, const Tensor& input_a, const Tensor& input_b); - static ttnn::Tensor invoke( - const Tensor& input_a, - const Tensor& input_b); + static ttnn::Tensor invoke(const Tensor& input_a, const Tensor& input_b); }; } // namespace operations::data_movement -constexpr auto copy = ttnn::register_operation_with_auto_launch_op<"ttnn::copy", ttnn::operations::data_movement::CopyOperation>(); -constexpr auto clone = ttnn::register_operation_with_auto_launch_op<"ttnn::clone", ttnn::operations::data_movement::CloneOperation>(); -constexpr auto assign = ttnn::register_operation_with_auto_launch_op<"ttnn::assign", ttnn::operations::data_movement::AssignOperation>(); +constexpr auto copy = + ttnn::register_operation_with_auto_launch_op<"ttnn::copy", ttnn::operations::data_movement::CopyOperation>(); +constexpr auto assign = + ttnn::register_operation_with_auto_launch_op<"ttnn::assign", ttnn::operations::data_movement::AssignOperation>(); } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/data_movement/copy/copy_pybind.cpp b/ttnn/cpp/ttnn/operations/data_movement/copy/copy_pybind.cpp index c876ce7f232..edfdb1e725d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/copy/copy_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/copy/copy_pybind.cpp @@ -77,38 +77,6 @@ void py_bind_copy(py::module& module) { py::arg("queue_id") = 0}); } -void py_bind_clone(py::module& module) { - auto doc = R"doc(clone(tensor: ttnn.Tensor, memory_config: MemoryConfig, dtype: DataType) -> ttnn.Tensor - - Clones the tensor by copying it with the given `memory config`. Also, converts the dataype to `dtype`. - Note: clone does not change the layout of the tensor. - Organizes the `ttnn.Tensor` :attr:`tensor` into either ROW_MAJOR_LAYOUT or TILE_LAYOUT. When requesting ROW_MAJOR_LAYOUT - the tensor will be returned unpadded in the last two dimensions. When requesting TILE_LAYOUT the tensor will be automatically - padded where the width and height become multiples of 32. - In the case where the layout is the same, the operation simply pad or unpad the last two dimensions depending on layout requested. - - Args: - * :attr:`tensor`: the ttnn.Tensor - * :attr:`memory_config`: the `ttnn` memory config, DRAM_MEMORY_CONFIG or L1_MEMORY_CONFIG. - * :attr:`dtype`: the `ttnn` data type.)doc"; - - bind_registered_operation( - module, - ttnn::clone, - doc, - ttnn::pybind_overload_t{ - [](const decltype(ttnn::clone)& self, - const ttnn::Tensor& input_tensor, - const std::optional& memory_config, - const std::optional dtype, - uint8_t queue_id) { return self(queue_id, input_tensor, memory_config, dtype); }, - py::arg("input_tensor").noconvert(), - py::kw_only(), - py::arg("memory_config") = std::nullopt, - py::arg("dtype") = std::nullopt, - py::arg("queue_id") = 0}); -} - void py_bind_assign(py::module& module) { auto doc = get_unary_doc_string( "assign", "input", R"doc( Returns a new tensor which is a new copy of input tensor ``{0}``. @@ -135,20 +103,18 @@ void py_bind_assign(py::module& module) { ttnn::assign, doc, ttnn::pybind_overload_t{ - [] (const decltype(ttnn::assign)& self, - const ttnn::Tensor& input, - const ttnn::MemoryConfig memory_config, - const std::optional dtype, - std::optional &optional_output_tensor, - uint8_t queue_id) { - return self(queue_id, input, memory_config, dtype, optional_output_tensor); - }, - py::arg("input_tensor").noconvert(), - py::kw_only(), - py::arg("memory_config"), - py::arg("dtype") = std::nullopt, - py::arg("output_tensor") = std::nullopt, - py::arg("queue_id") = 0}, + [](const decltype(ttnn::assign)& self, + const ttnn::Tensor& input, + const ttnn::MemoryConfig memory_config, + const std::optional dtype, + std::optional& optional_output_tensor, + uint8_t queue_id) { return self(queue_id, input, memory_config, dtype, optional_output_tensor); }, + py::arg("input_tensor").noconvert(), + py::kw_only(), + py::arg("memory_config"), + py::arg("dtype") = std::nullopt, + py::arg("output_tensor") = std::nullopt, + py::arg("queue_id") = 0}, ttnn::pybind_overload_t{ [](const decltype(ttnn::assign)& self, const ttnn::Tensor& input_a, diff --git a/ttnn/cpp/ttnn/operations/data_movement/copy/copy_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/copy/copy_pybind.hpp index f98cde14bb3..51032e8c44a 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/copy/copy_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/copy/copy_pybind.hpp @@ -9,7 +9,6 @@ namespace ttnn::operations::data_movement::detail { void py_bind_copy(pybind11::module& m); -void py_bind_clone(pybind11::module& m); void py_bind_assign(pybind11::module& m); } // namespace ttnn::operations::data_movement::detail diff --git a/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp index b1717978d33..4a55184a73d 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp @@ -4,40 +4,41 @@ #pragma once - #include #include #include "ttnn/cpp/pybind11/decorators.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded_pybind.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/reshard_pybind.hpp" +#include "ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved_pybind.hpp" +#include "ttnn/operations/data_movement/bcast/bcast_pybind.hpp" +#include "ttnn/operations/data_movement/clone/clone_pybind.hpp" #include "ttnn/operations/data_movement/concat/concat_pybind.hpp" +#include "ttnn/operations/data_movement/copy/copy_pybind.hpp" +#include "ttnn/operations/data_movement/fill_rm/fill_rm_pybind.hpp" +#include "ttnn/operations/data_movement/fold/fold_pybind.hpp" +#include "ttnn/operations/data_movement/indexed_fill/indexed_fill_pybind.hpp" +#include "ttnn/operations/data_movement/move/move_pybind.hpp" +#include "ttnn/operations/data_movement/non_zero_indices/non_zero_indices_pybind.hpp" #include "ttnn/operations/data_movement/pad/pad_pybind.hpp" #include "ttnn/operations/data_movement/permute/permute_pybind.hpp" +#include "ttnn/operations/data_movement/repeat/repeat_pybind.hpp" +#include "ttnn/operations/data_movement/repeat_interleave/repeat_interleave_pybind.hpp" +#include "ttnn/operations/data_movement/reshape/reshape_pybind.hpp" +#include "ttnn/operations/data_movement/reshape_on_device/reshape_pybind.hpp" +#include "ttnn/operations/data_movement/reshape_view/reshape_pybind.hpp" +#include "ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/interleaved_to_sharded_partial_pybind.hpp" +#include "ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/sharded_to_interleaved_partial_pybind.hpp" #include "ttnn/operations/data_movement/slice/slice_pybind.hpp" +#include "ttnn/operations/data_movement/split/split_pybind.hpp" +#include "ttnn/operations/data_movement/squeeze/squeeze_pybind.hpp" #include "ttnn/operations/data_movement/tilize/tilize_pybind.hpp" #include "ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding_pybind.hpp" -#include "ttnn/operations/data_movement/repeat_interleave/repeat_interleave_pybind.hpp" #include "ttnn/operations/data_movement/transpose/transpose_pybind.hpp" -#include "ttnn/operations/data_movement/split/split_pybind.hpp" +#include "ttnn/operations/data_movement/unsqueeze/unsqueeze_pybind.hpp" #include "ttnn/operations/data_movement/untilize/untilize_pybind.hpp" -#include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding_pybind.hpp" #include "ttnn/operations/data_movement/untilize_with_halo_v2/untilize_with_halo_v2_pybind.hpp" -#include "ttnn/operations/data_movement/non_zero_indices/non_zero_indices_pybind.hpp" -#include "ttnn/operations/data_movement/fill_rm/fill_rm_pybind.hpp" -#include "ttnn/operations/data_movement/repeat/repeat_pybind.hpp" -#include "ttnn/operations/data_movement/fold/fold_pybind.hpp" -#include "ttnn/operations/data_movement/sharded_partial/sharded_to_interleaved_partial/sharded_to_interleaved_partial_pybind.hpp" -#include "ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/interleaved_to_sharded_partial_pybind.hpp" -#include "ttnn/operations/data_movement/reshape_on_device/reshape_pybind.hpp" -#include "ttnn/operations/data_movement/reshape_view/reshape_pybind.hpp" -#include "ttnn/operations/data_movement/unsqueeze/unsqueeze_pybind.hpp" -#include "ttnn/operations/data_movement/squeeze/squeeze_pybind.hpp" -#include "ttnn/cpp/ttnn/operations/data_movement/sharded/interleaved_to_sharded/interleaved_to_sharded_pybind.hpp" -#include "ttnn/cpp/ttnn/operations/data_movement/sharded/sharded_to_interleaved/sharded_to_interleaved_pybind.hpp" -#include "ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/reshard_pybind.hpp" -#include "ttnn/operations/data_movement/indexed_fill/indexed_fill_pybind.hpp" -#include "ttnn/operations/data_movement/copy/copy_pybind.hpp" -#include "ttnn/operations/data_movement/move/move_pybind.hpp" -#include "ttnn/operations/data_movement/bcast/bcast_pybind.hpp" +#include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding_pybind.hpp" namespace py = pybind11; @@ -45,42 +46,41 @@ namespace ttnn { namespace operations { namespace data_movement { - void py_module(py::module& module) { - detail::bind_permute(module); + bind_fill_rm(module); + bind_fold_operation(module); + bind_non_zero_indices(module); + clone::bind_clone_operation(module); detail::bind_concat(module); + detail::bind_indexed_fill(module); detail::bind_pad(module); - detail::bind_slice(module); + detail::bind_permute(module); detail::bind_repeat_interleave(module); + detail::bind_slice(module); + detail::bind_split(module); detail::bind_tilize(module); detail::bind_tilize_with_val_padding(module); detail::bind_tilize_with_zero_padding(module); detail::bind_transpose(module); - detail::bind_split(module); detail::bind_untilize(module); - detail::bind_untilize_with_unpadding(module); detail::bind_untilize_with_halo_v2(module); - bind_non_zero_indices(module); - bind_fill_rm(module); - py_bind_repeat(module); - py_bind_reshape(module); - py_bind_reshape_view(module); - py_bind_unsqueeze(module); - py_bind_squeeze(module); - detail::bind_indexed_fill(module); - bind_fold_operation(module); - py_bind_sharded_to_interleaved_partial(module); - py_bind_interleaved_to_sharded_partial(module); - detail::py_bind_copy(module); - detail::py_bind_clone(module); + detail::bind_untilize_with_unpadding(module); detail::py_bind_assign(module); + detail::py_bind_bcast(module); + detail::py_bind_clone(module); + detail::py_bind_copy(module); detail::py_bind_move(module); - py_bind_sharded_to_interleaved(module); py_bind_interleaved_to_sharded(module); + py_bind_interleaved_to_sharded_partial(module); + py_bind_repeat(module); + py_bind_reshape(module); + py_bind_reshape_view(module); py_bind_reshard(module); - detail::py_bind_bcast(module); + py_bind_sharded_to_interleaved(module); + py_bind_sharded_to_interleaved_partial(module); + py_bind_squeeze(module); + py_bind_unsqueeze(module); } - } // namespace data_movement } // namespace operations } // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp index 39de9930da3..f4af25a64b7 100644 --- a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp @@ -3,20 +3,20 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/operations/experimental/auto_format/auto_format.hpp" -#include "ttnn/tensor/tensor.hpp" + +#include "tt_metal/common/constants.hpp" +#include "tt_metal/host_api.hpp" +#include "ttnn/operations/data_movement/clone/clone.hpp" #include "ttnn/operations/data_movement/data_transfer/data_transfer.hpp" #include "ttnn/operations/data_movement/pad/pad.hpp" +#include "ttnn/operations/data_movement/slice/slice.hpp" #include "ttnn/operations/data_movement/tilize/tilize.hpp" #include "ttnn/operations/data_movement/tilize_with_val_padding/tilize_with_val_padding.hpp" -#include "ttnn/operations/data_movement/slice/slice.hpp" #include "ttnn/operations/data_movement/untilize/untilize.hpp" #include "ttnn/operations/data_movement/untilize_with_unpadding/untilize_with_unpadding.hpp" -#include "ttnn/operations/data_movement/copy/copy.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/host_api.hpp" - +#include "ttnn/tensor/tensor.hpp" -namespace ttnn::operations::experimental::auto_format{ +namespace ttnn::operations::experimental::auto_format { Tensor AutoFormat::move_tensor_to_device(const Tensor& input, Device* device, const MemoryConfig& mem_config) { if (input.storage_type() != StorageType::DEVICE) { @@ -30,7 +30,7 @@ Tensor AutoFormat::move_tensor_to_mem_config(const Tensor& input, const MemoryCo if (input.storage_type() != StorageType::DEVICE) { return ttnn::data_transfer_to_device(input, AutoFormat::GetDefaultDevice(), mem_config); } else if (input.memory_config() != mem_config) { - return ttnn::clone(input, mem_config); + return ttnn::clone(input, std::nullopt, mem_config); } else { return input; } @@ -40,17 +40,18 @@ Tensor AutoFormat::move_tensor_to_mem_config(const Tensor& input, const MemoryCo // are not quite ready. So here we basically just put the tensor back on device. // Used in backward_ops.cpp // See: Remove auto format within permute_op.cpp #9404 -Tensor AutoFormat::move_tensor_to_device_and_pad(const Tensor& input, Device *device, Layout target_layout, std::optional target_mem_config){ +Tensor AutoFormat::move_tensor_to_device_and_pad( + const Tensor& input, Device* device, Layout target_layout, std::optional target_mem_config) { using namespace tt::constants; const auto intended_shape = input.get_shape(); const auto device_shape = input.get_legacy_shape(); - const auto new_intended_shape = std::array{intended_shape[0], intended_shape[1], intended_shape[-2], intended_shape[-1]}; - const auto new_device_shape = std::array{ + const auto new_intended_shape = + std::array{intended_shape[0], intended_shape[1], intended_shape[-2], intended_shape[-1]}; + const auto new_device_shape = std::array{ device_shape[0], device_shape[1], (device_shape[-2] % TILE_HEIGHT != 0 ? (device_shape[-2] / TILE_HEIGHT + 1) * TILE_HEIGHT : device_shape[-2]), - (device_shape[-1] % TILE_WIDTH != 0 ? (device_shape[-1] / TILE_WIDTH + 1) * TILE_WIDTH : device_shape[-1]) - }; + (device_shape[-1] % TILE_WIDTH != 0 ? (device_shape[-1] / TILE_WIDTH + 1) * TILE_WIDTH : device_shape[-1])}; const auto new_shape = tt::tt_metal::LegacyShape(new_intended_shape, new_device_shape); return AutoFormat::format_input_tensor(input, device, new_shape, 0.0, target_layout, target_mem_config); } @@ -90,14 +91,28 @@ Tensor AutoFormat::format_input_tensor( } } else if (!convert_layout && pad_input) { if (formatted_input.get_layout() == Layout::ROW_MAJOR || formatted_input.get_layout() == Layout::TILE) { - return ttnn::pad(0, (const ttnn::Tensor) formatted_input, padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), pad_value, false, mem_config); + return ttnn::pad( + 0, + (const ttnn::Tensor)formatted_input, + padded_shape.to_array_4D(), + tt::tt_metal::Array4D({0, 0, 0, 0}), + pad_value, + false, + mem_config); } } else if (convert_layout && pad_input) { if (formatted_input.get_layout() == Layout::ROW_MAJOR && target_layout == Layout::TILE) { return ttnn::tilize_with_val_padding(formatted_input, padded_shape, pad_value, mem_config); } else if (formatted_input.get_layout() == Layout::TILE && target_layout == Layout::ROW_MAJOR) { formatted_input = ttnn::untilize(formatted_input, mem_config); - return ttnn::pad(0, (const ttnn::Tensor) formatted_input, padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), pad_value, false, mem_config); + return ttnn::pad( + 0, + (const ttnn::Tensor)formatted_input, + padded_shape.to_array_4D(), + tt::tt_metal::Array4D({0, 0, 0, 0}), + pad_value, + false, + mem_config); } } // Fall back to host conversions @@ -110,7 +125,11 @@ Tensor AutoFormat::format_input_tensor( formatted_input = formatted_input.to(Layout::ROW_MAJOR); convert_layout = formatted_input.get_layout() != target_layout; } - formatted_input = ttnn::pad((const ttnn::Tensor)formatted_input, padded_shape.to_array_4D(), tt::tt_metal::Array4D({0, 0, 0, 0}), pad_value); + formatted_input = ttnn::pad( + (const ttnn::Tensor)formatted_input, + padded_shape.to_array_4D(), + tt::tt_metal::Array4D({0, 0, 0, 0}), + pad_value); } if (convert_layout) { @@ -162,13 +181,7 @@ Tensor AutoFormat::format_output_tensor( auto ends = std::array({shape[0], shape[1], shape[2], shape[3]}); auto step = std::array({1, 1, 1, 1}); - formatted_output = ttnn::slice( - DefaultQueueId, - formatted_output, - begins, - ends, - step, - mem_config); + formatted_output = ttnn::slice(DefaultQueueId, formatted_output, begins, ends, step, mem_config); return formatted_output; // Output is tile but shape cannot be tile. We leave in RM } else if (formatted_output.get_layout() == Layout::TILE && AutoFormat::legal_rm_shape(shape)) { @@ -189,16 +202,10 @@ Tensor AutoFormat::format_output_tensor( } else if ( formatted_output.get_layout() == Layout::ROW_MAJOR && target_layout == Layout::TILE && AutoFormat::legal_tile_shape(shape)) { - auto begins = std::array({0, 0, 0, 0}); - auto ends = std::array({shape[0], shape[1], shape[2], shape[3]}); - auto step = std::array({1, 1, 1, 1}); - formatted_output = ttnn::slice( - DefaultQueueId, - formatted_output, - begins, - ends, - step, - mem_config); + auto begins = std::vector({0, 0, 0, 0}); + auto ends = std::vector({shape[0], shape[1], shape[2], shape[3]}); + auto step = std::vector({1, 1, 1, 1}); + formatted_output = ttnn::slice(DefaultQueueId, formatted_output, begins, ends, step, mem_config); formatted_output = ttnn::tilize(formatted_output, mem_config); return formatted_output; } @@ -214,11 +221,10 @@ Tensor AutoFormat::format_output_tensor( formatted_output = formatted_output.to(Layout::ROW_MAJOR); convert_layout = formatted_output.get_layout() != target_layout; } - auto begins = std::array({0, 0, 0, 0}); - auto ends = std::array({shape[0], shape[1], shape[2], shape[3]}); - auto step = std::array({1, 1, 1, 1}); - formatted_output = - ttnn::slice(formatted_output, begins, ends, step, std::nullopt); + auto begins = std::vector({0, 0, 0, 0}); + auto ends = std::vector({shape[0], shape[1], shape[2], shape[3]}); + auto step = std::vector({1, 1, 1, 1}); + formatted_output = ttnn::slice(formatted_output, begins, ends, step, std::nullopt); } if (convert_layout) { @@ -243,4 +249,4 @@ Tensor AutoFormat::format_output_tensor( return formatted_output; } -} //namespace ttnn::operations::auto_format +} // namespace ttnn::operations::experimental::auto_format From d0c74b54f569632bb4f1ac04a075b4d6a5d3397b Mon Sep 17 00:00:00 2001 From: Shaw Nguyen Date: Mon, 7 Oct 2024 09:15:30 +0000 Subject: [PATCH 18/58] #13463: Revise Reviews 1 --- .../clone/device/clone_program_factory.cpp | 67 ++++++++++--------- .../clone/device/kernels/compute_kernel.cpp | 11 ++- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp index 5dca80d22e9..5d8bd7a59d8 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp @@ -11,55 +11,58 @@ CloneOperation::ProgramFactory::cached_program_t CloneOperation::ProgramFactory: const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args, tensor_return_value_t& output) { + using namespace tt; + using namespace tt::tt_metal; + using namespace tt::tt_metal::detail; + const auto& input = tensor_args.input; Program program = Program(); bool tilized = output.get_layout() == Layout::TILE; - tt::DataFormat input_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - tt::DataFormat output_cb_data_format = tt::tt_metal::datatype_to_dataformat_converter(output.get_dtype()); + tt::DataFormat input_cb_data_format = datatype_to_dataformat_converter(input.get_dtype()); + tt::DataFormat output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); bool convert_dtype = input_cb_data_format != output_cb_data_format; - uint32_t input_unit_size = tilized ? tt::tt_metal::detail::TileSize(input_cb_data_format) - : input.get_legacy_shape()[-1] * input.element_size(); - uint32_t output_unit_size = tilized ? tt::tt_metal::detail::TileSize(output_cb_data_format) - : output.get_legacy_shape()[-1] * output.element_size(); + uint32_t input_unit_size = + tilized ? TileSize(input_cb_data_format) : input.get_legacy_shape()[-1] * input.element_size(); + uint32_t output_unit_size = + tilized ? TileSize(output_cb_data_format) : output.get_legacy_shape()[-1] * output.element_size(); uint32_t num_units = - tilized ? output.volume() / tt::constants::TILE_HW : output.volume() / output.get_legacy_shape()[-1]; + tilized ? output.volume() / constants::TILE_HW : output.volume() / output.get_legacy_shape()[-1]; - tt::tt_metal::Device* device = output.device(); + Device* device = output.device(); auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); uint32_t num_cores_x = compute_with_storage_grid_size.x; uint32_t num_cores_y = compute_with_storage_grid_size.y; auto [num_cores, all_cores, core_group_1, core_group_2, num_units_per_core_group_1, num_units_per_core_group_2] = - tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_units); + split_work_to_cores(compute_with_storage_grid_size, num_units); - uint32_t src0_cb_index = tt::CB::c_in0; + uint32_t src0_cb_index = CB::c_in0; uint32_t num_input_units = 2; uint32_t aligned_input_unit_size = round_up_to_mul32(input_unit_size); - tt::tt_metal::CircularBufferConfig cb_src0_config = - tt::tt_metal::CircularBufferConfig( - num_input_units * aligned_input_unit_size, {{src0_cb_index, input_cb_data_format}}) + auto cb_src0_config = + CircularBufferConfig(num_input_units * aligned_input_unit_size, {{src0_cb_index, input_cb_data_format}}) .set_page_size(src0_cb_index, aligned_input_unit_size); - auto cb_src0 = tt::tt_metal::CreateCircularBuffer(program, all_cores, cb_src0_config); + auto cb_src0 = CreateCircularBuffer(program, all_cores, cb_src0_config); uint32_t output_cb_index = src0_cb_index; if (convert_dtype) { - output_cb_index = tt::CB::c_out0; + output_cb_index = CB::c_out0; uint32_t num_output_units = 2; uint32_t aligned_output_unit_size = round_up_to_mul32(output_unit_size); - tt::tt_metal::CircularBufferConfig output_cb_config = - tt::tt_metal::CircularBufferConfig( + auto output_cb_config = + CircularBufferConfig( num_output_units * aligned_output_unit_size, {{output_cb_index, output_cb_data_format}}) .set_page_size(output_cb_index, aligned_output_unit_size); - auto cb_output = tt::tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config); + auto cb_output = CreateCircularBuffer(program, all_cores, output_cb_config); } auto src_buffer = input.buffer(); auto dst_buffer = output.buffer(); - bool src_is_dram = src_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; - bool dst_is_dram = dst_buffer->buffer_type() == tt::tt_metal::BufferType::DRAM ? 1 : 0; + bool src_is_dram = src_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; + bool dst_is_dram = dst_buffer->buffer_type() == BufferType::DRAM ? 1 : 0; vector reader_compile_time_args, writer_compile_time_args; if (tilized) { @@ -82,35 +85,35 @@ CloneOperation::ProgramFactory::cached_program_t CloneOperation::ProgramFactory: (uint32_t)dst_log2_stick_size}; } map kernel_defines; - tt::tt_metal::KernelHandle unary_reader_kernel_id = tt::tt_metal::CreateKernel( + KernelHandle unary_reader_kernel_id = CreateKernel( program, tilized ? "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel.cpp" : "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/read_kernel_rm.cpp", all_cores, - tt::tt_metal::ReaderDataMovementConfig(reader_compile_time_args, kernel_defines)); + ReaderDataMovementConfig(reader_compile_time_args, kernel_defines)); - tt::tt_metal::KernelHandle unary_writer_kernel_id = tt::tt_metal::CreateKernel( + KernelHandle unary_writer_kernel_id = CreateKernel( program, tilized ? "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel.cpp" : "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/write_kernel_rm.cpp", all_cores, - tt::tt_metal::WriterDataMovementConfig(writer_compile_time_args, kernel_defines)); + WriterDataMovementConfig(writer_compile_time_args, kernel_defines)); if (convert_dtype) { vector compute_kernel_args_group_1 = {num_units_per_core_group_1}; - auto eltwise_unary_kernel_group_1 = tt::tt_metal::CreateKernel( + auto eltwise_unary_kernel_group_1 = CreateKernel( program, "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp", core_group_1, - tt::tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_1}); + ComputeConfig{.compile_args = compute_kernel_args_group_1}); if (!core_group_2.ranges().empty()) { vector compute_kernel_args_group_2 = {num_units_per_core_group_2}; - auto eltwise_unary_kernel_group_2 = tt::tt_metal::CreateKernel( + auto eltwise_unary_kernel_group_2 = CreateKernel( program, "ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp", core_group_2, - tt::tt_metal::ComputeConfig{.compile_args = compute_kernel_args_group_2}); + ComputeConfig{.compile_args = compute_kernel_args_group_2}); } } @@ -123,17 +126,17 @@ CloneOperation::ProgramFactory::cached_program_t CloneOperation::ProgramFactory: uint32_t num_units_per_core = i < g1_numcores ? num_units_per_core_group_1 : num_units_per_core_group_2; if (tilized) { - tt::tt_metal::SetRuntimeArgs( + SetRuntimeArgs( program, unary_reader_kernel_id, core, {src_buffer->address(), num_units_per_core, start_id}); - tt::tt_metal::SetRuntimeArgs( + SetRuntimeArgs( program, unary_writer_kernel_id, core, {dst_buffer->address(), num_units_per_core, start_id}); } else { - tt::tt_metal::SetRuntimeArgs( + SetRuntimeArgs( program, unary_reader_kernel_id, core, {src_buffer->address(), input_unit_size, num_units_per_core, start_id}); - tt::tt_metal::SetRuntimeArgs( + SetRuntimeArgs( program, unary_writer_kernel_id, core, diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp index 173632fe2d6..e93e2a1ed08 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/kernels/compute_kernel.cpp @@ -13,19 +13,16 @@ void MAIN { uint32_t per_core_tile_count = get_compile_time_arg_val(0); unary_op_init_common(tt::CB::c_in0); for (uint32_t tile_index = 0; tile_index < per_core_tile_count; ++tile_index) { - acquire_dst(tt::DstMode::Half); - cb_wait_front(tt::CB::c_in0, 1); cb_reserve_back(tt::CB::c_out0, 1); - + tile_regs_acquire(); copy_tile(tt::CB::c_in0, 0, 0); - + tile_regs_commit(); + tile_regs_wait(); pack_tile(0, tt::CB::c_out0); - + tile_regs_release(); cb_pop_front(tt::CB::c_in0, 1); cb_push_back(tt::CB::c_out0, 1); - - release_dst(tt::DstMode::Half); } } } // namespace NAMESPACE From 957956d564e64398cff35e0ddde7b0becc9aa00b Mon Sep 17 00:00:00 2001 From: Shaw Nguyen Date: Tue, 8 Oct 2024 03:12:41 +0000 Subject: [PATCH 19/58] #13463: Revise Reviews 2 --- .../clone/device/clone_device_operation.cpp | 10 +++++----- .../clone/device/clone_program_factory.cpp | 6 +++--- .../data_movement/data_movement_pybind.hpp | 2 -- .../experimental/auto_format/auto_format.cpp | 12 ++++++------ 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp index f97d6afa9ea..b52a85b1167 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_device_operation.cpp @@ -9,15 +9,15 @@ void CloneOperation::validate_inputs( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { const auto& input = tensor_args.input; if (operation_attributes.dtype != input.get_dtype()) - TT_FATAL(input.get_layout() == Layout::TILE, "dtype conversion is only supported with tile layout"); - TT_FATAL(input.storage_type() == StorageType::DEVICE, "input to clone must be on device"); - TT_FATAL(input.buffer() != nullptr, "input to clone must be allocated in buffer on device"); + TT_FATAL(input.get_layout() == Layout::TILE, "Clone: data type conversion is only supported with tile layout"); + TT_FATAL(input.storage_type() == StorageType::DEVICE, "Clone: input must be on device"); + TT_FATAL(input.buffer() != nullptr, "Clone: input must be allocated in buffer on device"); TT_FATAL( input.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, - "clone does not currently support sharding"); + "Clone: not currently support sharding"); TT_FATAL( operation_attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED, - "clone does not currently support sharding"); + "Clone: not currently support sharding"); } CloneOperation::program_factory_t CloneOperation::select_program_factory( diff --git a/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp index 5d8bd7a59d8..874e7fceba0 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/clone/device/clone_program_factory.cpp @@ -24,12 +24,12 @@ CloneOperation::ProgramFactory::cached_program_t CloneOperation::ProgramFactory: tt::DataFormat output_cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); bool convert_dtype = input_cb_data_format != output_cb_data_format; uint32_t input_unit_size = - tilized ? TileSize(input_cb_data_format) : input.get_legacy_shape()[-1] * input.element_size(); + tilized ? TileSize(input_cb_data_format) : input.get_logical_shape()[-1] * input.element_size(); uint32_t output_unit_size = - tilized ? TileSize(output_cb_data_format) : output.get_legacy_shape()[-1] * output.element_size(); + tilized ? TileSize(output_cb_data_format) : output.get_logical_shape()[-1] * output.element_size(); uint32_t num_units = - tilized ? output.volume() / constants::TILE_HW : output.volume() / output.get_legacy_shape()[-1]; + tilized ? output.volume() / constants::TILE_HW : output.volume() / output.get_logical_shape()[-1]; Device* device = output.device(); diff --git a/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp b/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp index 4a55184a73d..e2600fbe947 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/data_movement_pybind.hpp @@ -24,7 +24,6 @@ #include "ttnn/operations/data_movement/permute/permute_pybind.hpp" #include "ttnn/operations/data_movement/repeat/repeat_pybind.hpp" #include "ttnn/operations/data_movement/repeat_interleave/repeat_interleave_pybind.hpp" -#include "ttnn/operations/data_movement/reshape/reshape_pybind.hpp" #include "ttnn/operations/data_movement/reshape_on_device/reshape_pybind.hpp" #include "ttnn/operations/data_movement/reshape_view/reshape_pybind.hpp" #include "ttnn/operations/data_movement/sharded_partial/interleaved_to_sharded_partial/interleaved_to_sharded_partial_pybind.hpp" @@ -67,7 +66,6 @@ void py_module(py::module& module) { detail::bind_untilize_with_unpadding(module); detail::py_bind_assign(module); detail::py_bind_bcast(module); - detail::py_bind_clone(module); detail::py_bind_copy(module); detail::py_bind_move(module); py_bind_interleaved_to_sharded(module); diff --git a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp index f4af25a64b7..e1782d3afdc 100644 --- a/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/auto_format/auto_format.cpp @@ -202,9 +202,9 @@ Tensor AutoFormat::format_output_tensor( } else if ( formatted_output.get_layout() == Layout::ROW_MAJOR && target_layout == Layout::TILE && AutoFormat::legal_tile_shape(shape)) { - auto begins = std::vector({0, 0, 0, 0}); - auto ends = std::vector({shape[0], shape[1], shape[2], shape[3]}); - auto step = std::vector({1, 1, 1, 1}); + auto begins = std::array({0, 0, 0, 0}); + auto ends = std::array({shape[0], shape[1], shape[2], shape[3]}); + auto step = std::array({1, 1, 1, 1}); formatted_output = ttnn::slice(DefaultQueueId, formatted_output, begins, ends, step, mem_config); formatted_output = ttnn::tilize(formatted_output, mem_config); return formatted_output; @@ -221,9 +221,9 @@ Tensor AutoFormat::format_output_tensor( formatted_output = formatted_output.to(Layout::ROW_MAJOR); convert_layout = formatted_output.get_layout() != target_layout; } - auto begins = std::vector({0, 0, 0, 0}); - auto ends = std::vector({shape[0], shape[1], shape[2], shape[3]}); - auto step = std::vector({1, 1, 1, 1}); + auto begins = std::array({0, 0, 0, 0}); + auto ends = std::array({shape[0], shape[1], shape[2], shape[3]}); + auto step = std::array({1, 1, 1, 1}); formatted_output = ttnn::slice(formatted_output, begins, ends, step, std::nullopt); } From 5d90e47edbf1ca2fbceeb4f3721c38d22f1e65e5 Mon Sep 17 00:00:00 2001 From: Shaw Nguyen Date: Tue, 8 Oct 2024 17:13:11 +0000 Subject: [PATCH 20/58] #13463: Revise Reviews 3 --- tests/ttnn/unit_tests/operations/test_clone.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/ttnn/unit_tests/operations/test_clone.py b/tests/ttnn/unit_tests/operations/test_clone.py index c5ba6d7df0c..6684cd3f8b2 100644 --- a/tests/ttnn/unit_tests/operations/test_clone.py +++ b/tests/ttnn/unit_tests/operations/test_clone.py @@ -120,6 +120,8 @@ def run_clone( [1, 1, 32, 32], # Single core [1, 1, 320, 384], # Multi core [1, 3, 320, 384], # Multi core + [38, 2, 99, 181], # Odd last Dim + [5, 33, 319, 381], # Odd last Dim ], ) @pytest.mark.parametrize( From f9ed0121fe3594c854e0f6a0f2ee891fd4a18ab3 Mon Sep 17 00:00:00 2001 From: Bryan Wilder Field Lozano Date: Wed, 9 Oct 2024 00:18:48 -0400 Subject: [PATCH 21/58] #0: Rename HalMemAddrType (#13616) --- .../perf_microbenchmark/common/util.hpp | 2 +- .../watcher/test_noc_sanitize_delays.cpp | 2 +- .../command_queue/test_EnqueueProgram.cpp | 2 +- tt_metal/impl/debug/debug_helpers.hpp | 2 +- tt_metal/impl/debug/watcher_device_reader.hpp | 2 +- tt_metal/impl/debug/watcher_server.cpp | 6 +-- tt_metal/impl/device/device.cpp | 50 +++++++++---------- tt_metal/impl/device/device.hpp | 4 +- tt_metal/impl/dispatch/command_queue.cpp | 8 +-- .../impl/dispatch/command_queue_interface.hpp | 8 +-- tt_metal/impl/program/program.cpp | 8 +-- tt_metal/jit_build/build.cpp | 2 +- tt_metal/llrt/blackhole/bh_hal.cpp | 2 +- tt_metal/llrt/blackhole/bh_hal_active_eth.cpp | 42 ++++++++-------- tt_metal/llrt/blackhole/bh_hal_idle_eth.cpp | 42 ++++++++-------- tt_metal/llrt/blackhole/bh_hal_tensix.cpp | 42 ++++++++-------- tt_metal/llrt/grayskull/gs_hal.cpp | 42 ++++++++-------- tt_metal/llrt/hal.hpp | 26 +++++----- tt_metal/llrt/llrt.cpp | 2 +- tt_metal/llrt/wormhole/wh_hal_active_eth.cpp | 42 ++++++++-------- tt_metal/llrt/wormhole/wh_hal_idle_eth.cpp | 42 ++++++++-------- tt_metal/llrt/wormhole/wh_hal_tensix.cpp | 42 ++++++++-------- tt_metal/tools/profiler/profiler.cpp | 6 +-- tt_metal/tools/profiler/tt_metal_profiler.cpp | 10 ++-- tt_metal/tt_metal.cpp | 4 +- 25 files changed, 220 insertions(+), 220 deletions(-) diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/common/util.hpp b/tests/tt_metal/tt_metal/perf_microbenchmark/common/util.hpp index 4b8564ea2b7..879d11a1b4c 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/common/util.hpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/common/util.hpp @@ -25,7 +25,7 @@ inline uint64_t get_t0_to_any_riscfw_end_cycle(tt::tt_metal::Device *device, con uint64_t min_cycle = -1; uint64_t max_cycle = 0; dprint_buf_msg_t *dprint_msg = - hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::DPRINT); + hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::DPRINT); // This works for tensix only, will need to be updated for eth vector print_buffer_addrs = { diff --git a/tests/tt_metal/tt_metal/unit_tests_common/watcher/test_noc_sanitize_delays.cpp b/tests/tt_metal/tt_metal/unit_tests_common/watcher/test_noc_sanitize_delays.cpp index 6eccf28268a..2734b791417 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/watcher/test_noc_sanitize_delays.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_common/watcher/test_noc_sanitize_delays.cpp @@ -147,7 +147,7 @@ void RunDelayTestOnCore(WatcherDelayFixture* fixture, Device* device, CoreCoord read_vec = tt::llrt::read_hex_vec_from_core ( device->id(), phys_core, - device->get_dev_addr(phys_core, HalMemAddrType::WATCHER) + offsetof(watcher_msg_t, debug_insert_delays), + device->get_dev_addr(phys_core, HalL1MemAddrType::WATCHER) + offsetof(watcher_msg_t, debug_insert_delays), sizeof(debug_insert_delays_msg_t)); log_info(tt::LogTest, "Read back debug_insert_delays: 0x{:x}", read_vec[0]); diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_EnqueueProgram.cpp b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_EnqueueProgram.cpp index eaab46ef4f1..3194e16e35c 100644 --- a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_EnqueueProgram.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/command_queue/test_EnqueueProgram.cpp @@ -588,7 +588,7 @@ bool test_increment_runtime_args_sanity(Device* device, const DummyProgramConfig break; case tt::RISCV::ERISC: { HalProgrammableCoreType eth_core_type = idle_eth ? HalProgrammableCoreType::IDLE_ETH : HalProgrammableCoreType::ACTIVE_ETH; - unique_args_addr = hal.get_dev_addr(eth_core_type, HalMemAddrType::UNRESERVED); + unique_args_addr = hal.get_dev_addr(eth_core_type, HalL1MemAddrType::UNRESERVED); common_args_addr = unique_args_addr + 1 * 256 * sizeof(uint32_t); compile_args[2] = unique_args_addr; compile_args[3] = common_args_addr; diff --git a/tt_metal/impl/debug/debug_helpers.hpp b/tt_metal/impl/debug/debug_helpers.hpp index 4bfb1207ef1..13ac9e49f88 100644 --- a/tt_metal/impl/debug/debug_helpers.hpp +++ b/tt_metal/impl/debug/debug_helpers.hpp @@ -56,7 +56,7 @@ static CoreDescriptorSet GetDispatchCores(Device* device) { inline uint64_t GetDprintBufAddr(Device *device, const CoreCoord &phys_core, int risc_id) { - dprint_buf_msg_t *buf = device->get_dev_addr(phys_core, HalMemAddrType::DPRINT); + dprint_buf_msg_t *buf = device->get_dev_addr(phys_core, HalL1MemAddrType::DPRINT); return reinterpret_cast(buf->data[risc_id]); } diff --git a/tt_metal/impl/debug/watcher_device_reader.hpp b/tt_metal/impl/debug/watcher_device_reader.hpp index 4cb4465a9f6..7f60ad5d4cf 100644 --- a/tt_metal/impl/debug/watcher_device_reader.hpp +++ b/tt_metal/impl/debug/watcher_device_reader.hpp @@ -7,7 +7,7 @@ namespace tt::watcher { #define GET_WATCHER_DEV_ADDR_FOR_CORE(dev, core, sub_type) \ - (dev->get_dev_addr(core, HalMemAddrType::WATCHER) + offsetof(watcher_msg_t, sub_type)) + (dev->get_dev_addr(core, HalL1MemAddrType::WATCHER) + offsetof(watcher_msg_t, sub_type)) constexpr uint64_t DEBUG_SANITIZE_NOC_SENTINEL_OK_64 = 0xbadabadabadabada; constexpr uint32_t DEBUG_SANITIZE_NOC_SENTINEL_OK_32 = 0xbadabada; diff --git a/tt_metal/impl/debug/watcher_server.cpp b/tt_metal/impl/debug/watcher_server.cpp index 6e108fa9ce0..a644e379cf8 100644 --- a/tt_metal/impl/debug/watcher_server.cpp +++ b/tt_metal/impl/debug/watcher_server.cpp @@ -27,13 +27,13 @@ namespace tt { namespace watcher { #define GET_WATCHER_TENSIX_DEV_ADDR() \ - hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::WATCHER) + hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::WATCHER) #define GET_WATCHER_ERISC_DEV_ADDR() \ - hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::WATCHER) + hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::WATCHER) #define GET_WATCHER_IERISC_DEV_ADDR() \ - hal.get_dev_addr(HalProgrammableCoreType::IDLE_ETH, HalMemAddrType::WATCHER) + hal.get_dev_addr(HalProgrammableCoreType::IDLE_ETH, HalL1MemAddrType::WATCHER) static std::atomic enabled = false; static std::atomic server_running = false; diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp index 80b0bff5c8a..09c5e0759d8 100644 --- a/tt_metal/impl/device/device.cpp +++ b/tt_metal/impl/device/device.cpp @@ -218,7 +218,7 @@ void Device::initialize_allocator(size_t l1_small_size, size_t trace_region_size .dram_bank_size = soc_desc.dram_bank_size, .dram_bank_offsets = {}, .dram_unreserved_base = DRAM_BARRIER_BASE + DRAM_BARRIER_SIZE, // these should come from the HAL - .l1_unreserved_base = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::UNRESERVED), + .l1_unreserved_base = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED), .worker_grid_size = this->logical_grid_size(), .worker_l1_size = static_cast(soc_desc.worker_l1_size), .storage_core_bank_size = get_storage_core_bank_size(id_, num_hw_cqs_, dispatch_core_type), @@ -398,10 +398,10 @@ void Device::initialize_firmware(CoreCoord phys_core, launch_msg_t *launch_msg, // worker cores (Tensix and active eth) configured with DISPATCH_MODE_DEV // When using Slow Dispatch, all cores initialized with DISPATCH_MODE_HOST std::vector init_launch_msg_data(launch_msg_buffer_num_entries, *launch_msg); - tt::Cluster::instance().write_core(init_launch_msg_data.data(), launch_msg_buffer_num_entries * sizeof(launch_msg_t), tt_cxy_pair(this->id(), phys_core), this->get_dev_addr(phys_core, HalMemAddrType::LAUNCH)); - uint32_t go_addr = this->get_dev_addr(phys_core, HalMemAddrType::GO_MSG); + tt::Cluster::instance().write_core(init_launch_msg_data.data(), launch_msg_buffer_num_entries * sizeof(launch_msg_t), tt_cxy_pair(this->id(), phys_core), this->get_dev_addr(phys_core, HalL1MemAddrType::LAUNCH)); + uint32_t go_addr = this->get_dev_addr(phys_core, HalL1MemAddrType::GO_MSG); tt::Cluster::instance().write_core(go_msg, sizeof(go_msg_t), tt_cxy_pair(this->id(), phys_core), go_addr); - uint64_t launch_msg_buffer_read_ptr_addr = this->get_dev_addr(phys_core, HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR); + uint64_t launch_msg_buffer_read_ptr_addr = this->get_dev_addr(phys_core, HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR); std::vector zero = {0}; tt::Cluster::instance().write_core(zero.data(), sizeof(uint32_t), tt_cxy_pair(this->id(), phys_core), launch_msg_buffer_read_ptr_addr); } @@ -422,8 +422,8 @@ void Device::reset_cores() { CoreCoord physical_core = this->ethernet_core_from_logical_core(eth_core); std::vector data(sizeof(launch_msg_t) / sizeof(uint32_t)); std::vector go_signal_data(sizeof(go_msg_t) / sizeof(uint32_t)); - DeviceAddr launch_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::LAUNCH); - DeviceAddr go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::GO_MSG); + DeviceAddr launch_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::LAUNCH); + DeviceAddr go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::GO_MSG); data = tt::llrt::read_hex_vec_from_core( this->id(), physical_core, launch_addr, sizeof(launch_msg_t)); @@ -454,8 +454,8 @@ void Device::reset_cores() { // Ethernet cores won't be reset, so just signal the dispatch cores to early exit. std::vector data(sizeof(launch_msg_t) / sizeof(uint32_t)); std::vector go_signal_data(sizeof(go_msg_t) / sizeof(uint32_t)); - DeviceAddr launch_addr = hal.get_dev_addr(HalProgrammableCoreType::IDLE_ETH, HalMemAddrType::LAUNCH); - DeviceAddr go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::GO_MSG); + DeviceAddr launch_addr = hal.get_dev_addr(HalProgrammableCoreType::IDLE_ETH, HalL1MemAddrType::LAUNCH); + DeviceAddr go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::GO_MSG); data = tt::llrt::read_hex_vec_from_core( id_and_cores.first, phys_core, launch_addr, sizeof(launch_msg_t)); go_signal_data = tt::llrt::read_hex_vec_from_core( @@ -581,7 +581,7 @@ void Device::initialize_and_launch_firmware() { if (!this->storage_only_cores_.count(logical_core)) { CoreCoord worker_core = this->worker_core_from_logical_core(logical_core); tt::llrt::write_hex_vec_to_core( - this->id(), worker_core, core_info_vec, this->get_dev_addr(worker_core, HalMemAddrType::CORE_INFO)); + this->id(), worker_core, core_info_vec, this->get_dev_addr(worker_core, HalL1MemAddrType::CORE_INFO)); this->initialize_firmware(worker_core, &launch_msg, &go_msg); not_done_cores.insert(worker_core); } @@ -601,14 +601,14 @@ void Device::initialize_and_launch_firmware() { for (const auto ð_core : this->get_active_ethernet_cores()) { CoreCoord phys_eth_core = this->ethernet_core_from_logical_core(eth_core); tt::llrt::write_hex_vec_to_core( - this->id(), phys_eth_core, core_info_vec, this->get_dev_addr(phys_eth_core, HalMemAddrType::CORE_INFO)); + this->id(), phys_eth_core, core_info_vec, this->get_dev_addr(phys_eth_core, HalL1MemAddrType::CORE_INFO)); this->initialize_firmware(phys_eth_core, &launch_msg, &go_msg); } for (const auto ð_core : this->get_inactive_ethernet_cores()) { CoreCoord phys_eth_core = this->ethernet_core_from_logical_core(eth_core); tt::llrt::write_hex_vec_to_core( - this->id(), phys_eth_core, core_info_vec, this->get_dev_addr(phys_eth_core, HalMemAddrType::CORE_INFO)); + this->id(), phys_eth_core, core_info_vec, this->get_dev_addr(phys_eth_core, HalL1MemAddrType::CORE_INFO)); this->initialize_firmware(phys_eth_core, &launch_msg, &go_msg); not_done_cores.insert(phys_eth_core); } @@ -1530,8 +1530,8 @@ void Device::update_workers_build_settings(std::vector(device_worker_variants[DispatchWorkerType::PREFETCH_D][dispatch_d_idx]); // 1 to 1 mapping bw prefetch_d and dispatch_d auto dispatch_s_settings = std::get<1>(device_worker_variants[DispatchWorkerType::DISPATCH_S][dispatch_d_idx]); // 1 to 1 mapping bw dispatch_s and dispatch_d @@ -1581,8 +1581,8 @@ void Device::update_workers_build_settings(std::vectordispatch_s_enabled()) { - uint32_t tensix_worker_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::GO_MSG); - uint32_t eth_worker_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::GO_MSG); + uint32_t tensix_worker_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG); + uint32_t eth_worker_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::GO_MSG); for (auto&[core, dispatch_s_settings] : device_worker_variants[DispatchWorkerType::DISPATCH_S]) { int dispatch_s_idx = 0; auto prefetch_d_settings = std::get<1>(device_worker_variants[DispatchWorkerType::PREFETCH_D][dispatch_s_idx]); // 1 to 1 mapping bw prefetch_d and dispatch_s @@ -1813,7 +1813,7 @@ void Device::setup_tunnel_for_remote_devices() { tt_cxy_pair demux_location = dispatch_core_manager::instance().demux_core(device_id, channel, 0); settings.worker_physical_core = tt_cxy_pair(demux_location.chip, get_physical_core_coordinate(demux_location, dispatch_core_type)); settings.kernel_file = "tt_metal/impl/dispatch/kernels/packet_demux.cpp"; - settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::UNRESERVED); + settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED); settings.cb_size_bytes = 0x10000; tunnel_core_allocations[DEMUX].push_back(std::make_tuple(demux_location, settings)); } else if (num_prefetchers == 4 || num_prefetchers == 8) { @@ -1836,7 +1836,7 @@ void Device::setup_tunnel_for_remote_devices() { settings.worker_physical_core = tt_cxy_pair(demux_location.chip, get_physical_core_coordinate(demux_location, dispatch_core_type)); settings.semaphores.clear(); settings.kernel_file = "tt_metal/impl/dispatch/kernels/packet_demux.cpp"; - settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::UNRESERVED); + settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED); settings.cb_size_bytes = 0x10000; tunnel_core_allocations[DEMUX].push_back(std::make_tuple(demux_location, settings)); @@ -1844,14 +1844,14 @@ void Device::setup_tunnel_for_remote_devices() { demux_location = dispatch_core_manager::instance().demux_core(device_id, channel, 1); settings.worker_physical_core = tt_cxy_pair(demux_location.chip, get_physical_core_coordinate(demux_location, dispatch_core_type)); settings.kernel_file = "tt_metal/impl/dispatch/kernels/packet_demux.cpp"; - settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::UNRESERVED); + settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED); settings.cb_size_bytes = 0x10000; tunnel_core_allocations[DEMUX].push_back(std::make_tuple(demux_location, settings)); demux_location = dispatch_core_manager::instance().demux_core(device_id, channel, 2); settings.worker_physical_core = tt_cxy_pair(demux_location.chip, get_physical_core_coordinate(demux_location, dispatch_core_type)); settings.kernel_file = "tt_metal/impl/dispatch/kernels/packet_demux.cpp"; - settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::UNRESERVED); + settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED); settings.cb_size_bytes = 0x10000; tunnel_core_allocations[DEMUX].push_back(std::make_tuple(demux_location, settings)); @@ -1904,7 +1904,7 @@ void Device::setup_tunnel_for_remote_devices() { settings.worker_physical_core = tt_cxy_pair(demux_d_location.chip, get_physical_core_coordinate(demux_d_location, dispatch_core_type)); settings.kernel_file = "tt_metal/impl/dispatch/kernels/vc_packet_router.cpp"; settings.producer_semaphore_id = 0; - settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::UNRESERVED); + settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED); settings.cb_size_bytes = 0x8000; if (tunnel.size() > 2) { settings.semaphores.resize(1); @@ -1916,7 +1916,7 @@ void Device::setup_tunnel_for_remote_devices() { settings.worker_physical_core = tt_cxy_pair(demux_d_location.chip, get_physical_core_coordinate(demux_d_location, dispatch_core_type)); settings.kernel_file = "tt_metal/impl/dispatch/kernels/vc_packet_router.cpp"; settings.producer_semaphore_id = 0; - settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::UNRESERVED); + settings.cb_start_address = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED); settings.cb_size_bytes = 0x8000; tunnel_core_allocations[DEMUX_D].push_back(std::make_tuple(demux_d_location, settings)); } @@ -2197,10 +2197,10 @@ void Device::compile_command_queue_programs() { ); auto [tensix_num_worker_cores, tensix_worker_physical_grid] = get_physical_worker_grid_config(this->id(), num_hw_cqs, dispatch_core_type); - uint32_t tensix_worker_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::GO_MSG); + uint32_t tensix_worker_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::GO_MSG); uint32_t eth_worker_go_signal_addr = 0; if (hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH) != -1) { - eth_worker_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::GO_MSG); + eth_worker_go_signal_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::GO_MSG); } std::vector dispatch_compile_args = { dispatch_constants::get(dispatch_core_type).dispatch_buffer_base(), @@ -2779,7 +2779,7 @@ void Device::init_command_queue_device() { launch_msg_t msg = command_queue_program.kernels_on_core(logical_dispatch_core, index)->launch_msg; go_msg_t go_msg = command_queue_program.kernels_on_core(logical_dispatch_core, index)->go_msg; CoreCoord phys_core = this->physical_core_from_logical_core(logical_dispatch_core, core_type); - tt::llrt::write_launch_msg_to_core(this->id(), phys_core, &msg, &go_msg, this->get_dev_addr(phys_core, HalMemAddrType::LAUNCH)); + tt::llrt::write_launch_msg_to_core(this->id(), phys_core, &msg, &go_msg, this->get_dev_addr(phys_core, HalL1MemAddrType::LAUNCH)); } } @@ -2796,7 +2796,7 @@ void Device::init_command_queue_device() { launch_msg_t msg = mmio_command_queue_program.kernels_on_core(logical_dispatch_core, index)->launch_msg; go_msg_t go_msg = mmio_command_queue_program.kernels_on_core(logical_dispatch_core, index)->go_msg; CoreCoord phys_core = mmio_device->physical_core_from_logical_core(logical_dispatch_core, core_type); - tt::llrt::write_launch_msg_to_core(mmio_device_id, phys_core, &msg, &go_msg, mmio_device->get_dev_addr(phys_core, HalMemAddrType::LAUNCH)); + tt::llrt::write_launch_msg_to_core(mmio_device_id, phys_core, &msg, &go_msg, mmio_device->get_dev_addr(phys_core, HalL1MemAddrType::LAUNCH)); } } } diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp index 92dbdb38bb8..3ea349d5618 100644 --- a/tt_metal/impl/device/device.hpp +++ b/tt_metal/impl/device/device.hpp @@ -325,7 +325,7 @@ class Device { HalProgrammableCoreType get_programmable_core_type(CoreCoord phys_core) const; template - T get_dev_addr(CoreCoord phys_core, HalMemAddrType addr_type) const; + T get_dev_addr(CoreCoord phys_core, HalL1MemAddrType addr_type) const; // Returns address where allocator starts allocating buffer template T get_base_allocator_addr(const HalMemType &mem_type) const; @@ -360,7 +360,7 @@ inline HalProgrammableCoreType Device::get_programmable_core_type(CoreCoord phys } template -inline T Device::get_dev_addr(CoreCoord phys_core, HalMemAddrType addr_type) const { +inline T Device::get_dev_addr(CoreCoord phys_core, HalL1MemAddrType addr_type) const { return hal.get_dev_addr(this->get_programmable_core_type(phys_core), addr_type); } diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index 0e94bcf7afd..f73d27b8e23 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -1195,7 +1195,7 @@ void EnqueueProgramCommand::assemble_device_commands( multicast_go_signal_sub_cmds.size() + unicast_go_signal_sub_cmds.size()); // Get the address for the slot this launch_message will be written to - uint32_t multicast_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::LAUNCH) + this->multicast_cores_launch_message_wptr * sizeof(launch_msg_t); + uint32_t multicast_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::LAUNCH) + this->multicast_cores_launch_message_wptr * sizeof(launch_msg_t); uint8_t go_signal_mcast_flag = 0x0; if (multicast_go_signal_sub_cmds.size() > 0) { @@ -1227,7 +1227,7 @@ void EnqueueProgramCommand::assemble_device_commands( } if (unicast_go_signal_sub_cmds.size() > 0) { - uint32_t unicast_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::LAUNCH) + this->unicast_cores_launch_message_wptr * sizeof(launch_msg_t); + uint32_t unicast_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::LAUNCH) + this->unicast_cores_launch_message_wptr * sizeof(launch_msg_t); go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_UNICAST; uint32_t curr_sub_cmd_idx = 0; for (const auto& [num_sub_cmds_in_cmd, unicast_go_signal_payload_sizeB] : unicast_go_signals_payload) { @@ -1300,12 +1300,12 @@ void EnqueueProgramCommand::assemble_device_commands( go_signal->kernel_config.host_assigned_id = program.get_runtime_id(); } // Update launch message addresses to reflect new launch_msg slot in ring buffer - uint32_t multicast_cores_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::LAUNCH) + this->multicast_cores_launch_message_wptr * sizeof(launch_msg_t); + uint32_t multicast_cores_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::LAUNCH) + this->multicast_cores_launch_message_wptr * sizeof(launch_msg_t); for (auto launch_msg_cmd_ptr : cached_program_command_sequence.launch_msg_write_packed_cmd_ptrs) { launch_msg_cmd_ptr->addr = multicast_cores_launch_msg_addr; } if (cached_program_command_sequence.unicast_launch_msg_write_packed_cmd_ptrs.size()) { - uint32_t unicast_cores_launch_message_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::LAUNCH) + this->unicast_cores_launch_message_wptr * sizeof(launch_msg_t); + uint32_t unicast_cores_launch_message_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::LAUNCH) + this->unicast_cores_launch_message_wptr * sizeof(launch_msg_t); for (auto launch_msg_cmd_ptr : cached_program_command_sequence.unicast_launch_msg_write_packed_cmd_ptrs) { launch_msg_cmd_ptr->addr = unicast_cores_launch_message_addr; } diff --git a/tt_metal/impl/dispatch/command_queue_interface.hpp b/tt_metal/impl/dispatch/command_queue_interface.hpp index b5e1fde9159..ec6e1a964c7 100644 --- a/tt_metal/impl/dispatch/command_queue_interface.hpp +++ b/tt_metal/impl/dispatch/command_queue_interface.hpp @@ -146,7 +146,7 @@ struct dispatch_constants { dispatch_buffer_block_size = 512 * 1024; prefetch_d_buffer_size_ = 256 * 1024; dispatch_s_buffer_size_ = 32 * 1024; // dispatch_s only sends Go Signals -> CB can be small - base_device_command_queue_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::UNRESERVED); + base_device_command_queue_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED); } else { prefetch_q_entries_ = 128; max_prefetch_command_size_ = 32 * 1024; @@ -155,7 +155,7 @@ struct dispatch_constants { dispatch_buffer_block_size = 128 * 1024; prefetch_d_buffer_size_ = 128 * 1024; dispatch_s_buffer_size_ = 32 * 1024; // dispatch_s only sends Go Signals -> CB can be small - base_device_command_queue_addr = hal.get_dev_addr(HalProgrammableCoreType::IDLE_ETH, HalMemAddrType::UNRESERVED); + base_device_command_queue_addr = hal.get_dev_addr(HalProgrammableCoreType::IDLE_ETH, HalL1MemAddrType::UNRESERVED); } TT_ASSERT(cmddat_q_size_ >= 2 * max_prefetch_command_size_); TT_ASSERT(scratch_db_size_ % 2 == 0); @@ -533,8 +533,8 @@ class SystemMemoryManager { for (uint32_t index = 0; index < hal.get_programmable_core_type_count(); index++) { this->config_buffer_mgr.init_add_core( - hal.get_dev_addr(hal.get_programmable_core_type(index), HalMemAddrType::KERNEL_CONFIG), - hal.get_dev_size(hal.get_programmable_core_type(index), HalMemAddrType::KERNEL_CONFIG)); + hal.get_dev_addr(hal.get_programmable_core_type(index), HalL1MemAddrType::KERNEL_CONFIG), + hal.get_dev_size(hal.get_programmable_core_type(index), HalL1MemAddrType::KERNEL_CONFIG)); } } diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index 1877b9855de..3c3fbfef000 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -168,7 +168,7 @@ KernelGroup::KernelGroup( // Fast dispatch kernel config mangement happens under the CQ and will re-program the base for (uint32_t index = 0; index < hal.get_programmable_core_type_count(); index++) { this->launch_msg.kernel_config.kernel_config_base[index] = - hal.get_dev_addr(index, HalMemAddrType::KERNEL_CONFIG); + hal.get_dev_addr(index, HalL1MemAddrType::KERNEL_CONFIG); } for (int class_id = 0; class_id < DISPATCH_CLASS_MAX; class_id++) { @@ -545,7 +545,7 @@ size_t Program::num_semaphores() const { return semaphores_.size(); } void Program::init_semaphores(const Device &device, const CoreCoord &logical_core, uint32_t programmable_core_type_index) const { auto semaphores_on_core = this->semaphores_on_core(logical_core); - uint64_t kernel_config_base = hal.get_dev_addr(programmable_core_type_index, HalMemAddrType::KERNEL_CONFIG); + uint64_t kernel_config_base = hal.get_dev_addr(programmable_core_type_index, HalL1MemAddrType::KERNEL_CONFIG); uint64_t addr = kernel_config_base + this->program_configs_[programmable_core_type_index].sem_offset; CoreType core_type = hal.get_core_type(programmable_core_type_index); for (auto semaphore : semaphores_on_core) { @@ -1155,7 +1155,7 @@ uint32_t Program::get_sem_base_addr(Device *device, CoreCoord logical_core, Core uint32_t base_addr = device->using_fast_dispatch ? device->sysmem_manager().get_config_buffer_mgr().get_last_slot_addr(programmable_core_type) : - hal.get_dev_addr(programmable_core_type, HalMemAddrType::KERNEL_CONFIG); + hal.get_dev_addr(programmable_core_type, HalL1MemAddrType::KERNEL_CONFIG); return base_addr + this->program_configs_[index].sem_offset; } @@ -1168,7 +1168,7 @@ uint32_t Program::get_cb_base_addr(Device *device, CoreCoord logical_core, CoreT uint32_t base_addr = device->using_fast_dispatch ? device->sysmem_manager().get_config_buffer_mgr().get_last_slot_addr(programmable_core_type) : - hal.get_dev_addr(programmable_core_type, HalMemAddrType::KERNEL_CONFIG); + hal.get_dev_addr(programmable_core_type, HalL1MemAddrType::KERNEL_CONFIG); return base_addr + this->program_configs_[index].cb_offset; } diff --git a/tt_metal/jit_build/build.cpp b/tt_metal/jit_build/build.cpp index 630b9cabc1b..ec4e93322d9 100644 --- a/tt_metal/jit_build/build.cpp +++ b/tt_metal/jit_build/build.cpp @@ -106,7 +106,7 @@ void JitBuildEnv::init(uint32_t build_key, tt::ARCH arch) { } if (tt::llrt::OptionsG.get_feature_enabled(tt::llrt::RunTimeDebugFeatureDprint)) { - this->defines_ += "-DDEBUG_PRINT_ENABLED -DL1_UNRESERVED_BASE=" + to_string(hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::UNRESERVED)) + " "; + this->defines_ += "-DDEBUG_PRINT_ENABLED -DL1_UNRESERVED_BASE=" + to_string(hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::UNRESERVED)) + " "; } if (tt::llrt::OptionsG.get_record_noc_transfers()) { diff --git a/tt_metal/llrt/blackhole/bh_hal.cpp b/tt_metal/llrt/blackhole/bh_hal.cpp index d6e9d2ba6f1..57194dc56a7 100644 --- a/tt_metal/llrt/blackhole/bh_hal.cpp +++ b/tt_metal/llrt/blackhole/bh_hal.cpp @@ -12,7 +12,7 @@ namespace tt { namespace tt_metal { -static inline int hv (enum HalMemAddrType v) { +static inline int hv (enum HalL1MemAddrType v) { return static_cast(v); } diff --git a/tt_metal/llrt/blackhole/bh_hal_active_eth.cpp b/tt_metal/llrt/blackhole/bh_hal_active_eth.cpp index fc41672507d..15d2aa56148 100644 --- a/tt_metal/llrt/blackhole/bh_hal_active_eth.cpp +++ b/tt_metal/llrt/blackhole/bh_hal_active_eth.cpp @@ -27,29 +27,29 @@ HalCoreInfoType create_active_eth_mem_map() { std::vector mem_map_bases; - mem_map_bases.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_bases[utils::underlying_type(HalMemAddrType::BARRIER)] = MEM_L1_BARRIER; - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH)] = GET_ETH_MAILBOX_ADDRESS_HOST(launch); - mem_map_bases[utils::underlying_type(HalMemAddrType::WATCHER)] = GET_ETH_MAILBOX_ADDRESS_HOST(watcher); - mem_map_bases[utils::underlying_type(HalMemAddrType::DPRINT)] = GET_ETH_MAILBOX_ADDRESS_HOST(dprint_buf); - mem_map_bases[utils::underlying_type(HalMemAddrType::PROFILER)] = GET_ETH_MAILBOX_ADDRESS_HOST(profiler); - mem_map_bases[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = eth_l1_mem::address_map::ERISC_L1_KERNEL_CONFIG_BASE; - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)] = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; - mem_map_bases[utils::underlying_type(HalMemAddrType::CORE_INFO)] = GET_ETH_MAILBOX_ADDRESS_HOST(core_info); - mem_map_bases[utils::underlying_type(HalMemAddrType::GO_MSG)] = GET_ETH_MAILBOX_ADDRESS_HOST(go_message); - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_ETH_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); + mem_map_bases.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::BARRIER)] = MEM_L1_BARRIER; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = GET_ETH_MAILBOX_ADDRESS_HOST(launch); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::WATCHER)] = GET_ETH_MAILBOX_ADDRESS_HOST(watcher); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::DPRINT)] = GET_ETH_MAILBOX_ADDRESS_HOST(dprint_buf); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::PROFILER)] = GET_ETH_MAILBOX_ADDRESS_HOST(profiler); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = eth_l1_mem::address_map::ERISC_L1_KERNEL_CONFIG_BASE; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::CORE_INFO)] = GET_ETH_MAILBOX_ADDRESS_HOST(core_info); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = GET_ETH_MAILBOX_ADDRESS_HOST(go_message); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_ETH_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); std::vector mem_map_sizes; - mem_map_sizes.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_sizes[utils::underlying_type(HalMemAddrType::BARRIER)] = sizeof(uint32_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH)] = sizeof(launch_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::WATCHER)] = sizeof(watcher_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::PROFILER)] = sizeof(profiler_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = eth_l1_mem::address_map::ERISC_L1_KERNEL_CONFIG_SIZE; - mem_map_sizes[utils::underlying_type(HalMemAddrType::UNRESERVED)] = eth_l1_mem::address_map::MAX_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; - mem_map_sizes[utils::underlying_type(HalMemAddrType::GO_MSG)] = sizeof(go_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); + mem_map_sizes.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::BARRIER)] = sizeof(uint32_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = sizeof(launch_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::WATCHER)] = sizeof(watcher_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::PROFILER)] = sizeof(profiler_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = eth_l1_mem::address_map::ERISC_L1_KERNEL_CONFIG_SIZE; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = eth_l1_mem::address_map::MAX_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); return {HalProgrammableCoreType::IDLE_ETH, CoreType::ETH, num_proc_per_idle_eth_core, mem_map_bases, mem_map_sizes, false}; } diff --git a/tt_metal/llrt/blackhole/bh_hal_idle_eth.cpp b/tt_metal/llrt/blackhole/bh_hal_idle_eth.cpp index 3085e42ad82..02d250b78da 100644 --- a/tt_metal/llrt/blackhole/bh_hal_idle_eth.cpp +++ b/tt_metal/llrt/blackhole/bh_hal_idle_eth.cpp @@ -27,29 +27,29 @@ HalCoreInfoType create_idle_eth_mem_map() { std::vector mem_map_bases; - mem_map_bases.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_bases[utils::underlying_type(HalMemAddrType::BARRIER)] = MEM_L1_BARRIER; - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch); - mem_map_bases[utils::underlying_type(HalMemAddrType::WATCHER)] = GET_IERISC_MAILBOX_ADDRESS_HOST(watcher); - mem_map_bases[utils::underlying_type(HalMemAddrType::DPRINT)] = GET_IERISC_MAILBOX_ADDRESS_HOST(dprint_buf); - mem_map_bases[utils::underlying_type(HalMemAddrType::PROFILER)] = GET_IERISC_MAILBOX_ADDRESS_HOST(profiler); - mem_map_bases[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = IDLE_ERISC_L1_KERNEL_CONFIG_BASE; - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; // TODO: this is wrong, need idle eth specific value - mem_map_bases[utils::underlying_type(HalMemAddrType::CORE_INFO)] = GET_IERISC_MAILBOX_ADDRESS_HOST(core_info); - mem_map_bases[utils::underlying_type(HalMemAddrType::GO_MSG)] = GET_IERISC_MAILBOX_ADDRESS_HOST(go_message); - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); + mem_map_bases.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::BARRIER)] = MEM_L1_BARRIER; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::WATCHER)] = GET_IERISC_MAILBOX_ADDRESS_HOST(watcher); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::DPRINT)] = GET_IERISC_MAILBOX_ADDRESS_HOST(dprint_buf); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::PROFILER)] = GET_IERISC_MAILBOX_ADDRESS_HOST(profiler); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = IDLE_ERISC_L1_KERNEL_CONFIG_BASE; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; // TODO: this is wrong, need idle eth specific value + mem_map_bases[utils::underlying_type(HalL1MemAddrType::CORE_INFO)] = GET_IERISC_MAILBOX_ADDRESS_HOST(core_info); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = GET_IERISC_MAILBOX_ADDRESS_HOST(go_message); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); std::vector mem_map_sizes; - mem_map_sizes.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_sizes[utils::underlying_type(HalMemAddrType::BARRIER)] = sizeof(uint32_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH)] = sizeof(launch_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::WATCHER)] = sizeof(watcher_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::PROFILER)] = sizeof(profiler_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; // TODO: this is wrong, need idle eth specific value - mem_map_sizes[utils::underlying_type(HalMemAddrType::UNRESERVED)] = MEM_ETH_SIZE - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)]; - mem_map_sizes[utils::underlying_type(HalMemAddrType::GO_MSG)] = sizeof(go_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); + mem_map_sizes.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::BARRIER)] = sizeof(uint32_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = sizeof(launch_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::WATCHER)] = sizeof(watcher_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::PROFILER)] = sizeof(profiler_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; // TODO: this is wrong, need idle eth specific value + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = MEM_ETH_SIZE - mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)]; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); return {HalProgrammableCoreType::IDLE_ETH, CoreType::ETH, num_proc_per_idle_eth_core, mem_map_bases, mem_map_sizes, false}; } diff --git a/tt_metal/llrt/blackhole/bh_hal_tensix.cpp b/tt_metal/llrt/blackhole/bh_hal_tensix.cpp index e8ca63db912..6d87c007b1e 100644 --- a/tt_metal/llrt/blackhole/bh_hal_tensix.cpp +++ b/tt_metal/llrt/blackhole/bh_hal_tensix.cpp @@ -25,30 +25,30 @@ HalCoreInfoType create_tensix_mem_map() { std::vector mem_map_bases; - mem_map_bases.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_bases[utils::underlying_type(HalMemAddrType::BARRIER)] = MEM_L1_BARRIER; - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH)] = GET_MAILBOX_ADDRESS_HOST(launch); - mem_map_bases[utils::underlying_type(HalMemAddrType::WATCHER)] = GET_MAILBOX_ADDRESS_HOST(watcher); - mem_map_bases[utils::underlying_type(HalMemAddrType::DPRINT)] = GET_MAILBOX_ADDRESS_HOST(dprint_buf); - mem_map_bases[utils::underlying_type(HalMemAddrType::PROFILER)] = GET_MAILBOX_ADDRESS_HOST(profiler); - mem_map_bases[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_BASE; - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; - mem_map_bases[utils::underlying_type(HalMemAddrType::CORE_INFO)] = GET_MAILBOX_ADDRESS_HOST(core_info); - mem_map_bases[utils::underlying_type(HalMemAddrType::GO_MSG)] = GET_MAILBOX_ADDRESS_HOST(go_message); - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); + mem_map_bases.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::BARRIER)] = MEM_L1_BARRIER; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = GET_MAILBOX_ADDRESS_HOST(launch); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::WATCHER)] = GET_MAILBOX_ADDRESS_HOST(watcher); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::DPRINT)] = GET_MAILBOX_ADDRESS_HOST(dprint_buf); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::PROFILER)] = GET_MAILBOX_ADDRESS_HOST(profiler); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_BASE; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::CORE_INFO)] = GET_MAILBOX_ADDRESS_HOST(core_info); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = GET_MAILBOX_ADDRESS_HOST(go_message); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); std::vector mem_map_sizes; - mem_map_sizes.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_sizes[utils::underlying_type(HalMemAddrType::BARRIER)] = sizeof(uint32_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH)] = sizeof(launch_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::WATCHER)] = sizeof(watcher_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::PROFILER)] = sizeof(profiler_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; - mem_map_sizes[utils::underlying_type(HalMemAddrType::UNRESERVED)] = MEM_L1_SIZE - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)]; - mem_map_sizes[utils::underlying_type(HalMemAddrType::GO_MSG)] = sizeof(go_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); + mem_map_sizes.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::BARRIER)] = sizeof(uint32_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = sizeof(launch_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::WATCHER)] = sizeof(watcher_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::PROFILER)] = sizeof(profiler_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = MEM_L1_SIZE - mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)]; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); return {HalProgrammableCoreType::TENSIX, CoreType::WORKER, num_proc_per_tensix_core, mem_map_bases, mem_map_sizes, true}; } diff --git a/tt_metal/llrt/grayskull/gs_hal.cpp b/tt_metal/llrt/grayskull/gs_hal.cpp index f696444f112..047ce394288 100644 --- a/tt_metal/llrt/grayskull/gs_hal.cpp +++ b/tt_metal/llrt/grayskull/gs_hal.cpp @@ -31,29 +31,29 @@ void Hal::initialize_gs() { uint32_t max_alignment = std::max(DRAM_ALIGNMENT, L1_ALIGNMENT); std::vector mem_map_bases; - mem_map_bases.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_bases[utils::underlying_type(HalMemAddrType::BARRIER)] = MEM_L1_BARRIER; - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH)] = GET_MAILBOX_ADDRESS_HOST(launch); - mem_map_bases[utils::underlying_type(HalMemAddrType::WATCHER)] = GET_MAILBOX_ADDRESS_HOST(watcher); - mem_map_bases[utils::underlying_type(HalMemAddrType::DPRINT)] = GET_MAILBOX_ADDRESS_HOST(dprint_buf); - mem_map_bases[utils::underlying_type(HalMemAddrType::PROFILER)] = GET_MAILBOX_ADDRESS_HOST(profiler); - mem_map_bases[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_BASE; - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; - mem_map_bases[utils::underlying_type(HalMemAddrType::CORE_INFO)] = GET_MAILBOX_ADDRESS_HOST(core_info); - mem_map_bases[utils::underlying_type(HalMemAddrType::GO_MSG)] = GET_MAILBOX_ADDRESS_HOST(go_message); - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); + mem_map_bases.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::BARRIER)] = MEM_L1_BARRIER; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = GET_MAILBOX_ADDRESS_HOST(launch); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::WATCHER)] = GET_MAILBOX_ADDRESS_HOST(watcher); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::DPRINT)] = GET_MAILBOX_ADDRESS_HOST(dprint_buf); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::PROFILER)] = GET_MAILBOX_ADDRESS_HOST(profiler); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_BASE; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::CORE_INFO)] = GET_MAILBOX_ADDRESS_HOST(core_info); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = GET_MAILBOX_ADDRESS_HOST(go_message); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); std::vector mem_map_sizes; - mem_map_sizes.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_sizes[utils::underlying_type(HalMemAddrType::BARRIER)] = sizeof(uint32_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH)] = sizeof(launch_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::WATCHER)] = sizeof(watcher_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::PROFILER)] = sizeof(profiler_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; - mem_map_sizes[utils::underlying_type(HalMemAddrType::UNRESERVED)] = MEM_L1_SIZE - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)]; - mem_map_sizes[utils::underlying_type(HalMemAddrType::GO_MSG)] = sizeof(go_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); + mem_map_sizes.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::BARRIER)] = sizeof(uint32_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = sizeof(launch_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::WATCHER)] = sizeof(watcher_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::PROFILER)] = sizeof(profiler_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = MEM_L1_SIZE - mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)]; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); this->core_info_.push_back({HalProgrammableCoreType::TENSIX, CoreType::WORKER, num_proc_per_tensix_core, mem_map_bases, mem_map_sizes, true}); diff --git a/tt_metal/llrt/hal.hpp b/tt_metal/llrt/hal.hpp index 8fa05be5ec5..466af59c142 100644 --- a/tt_metal/llrt/hal.hpp +++ b/tt_metal/llrt/hal.hpp @@ -30,7 +30,7 @@ enum class HalProgrammableCoreType { COUNT = 3 }; -enum class HalMemAddrType : uint8_t { +enum class HalL1MemAddrType : uint8_t { BARRIER = 0, LAUNCH = 1, WATCHER = 2, @@ -72,19 +72,19 @@ class HalCoreInfoType { const std::vector& mem_map_bases, const std::vector& mem_map_sizes, bool supports_cbs); template - T get_dev_addr(HalMemAddrType addr_type) const; - uint32_t get_dev_size(HalMemAddrType addr_type) const; + T get_dev_addr(HalL1MemAddrType addr_type) const; + uint32_t get_dev_size(HalL1MemAddrType addr_type) const; }; template -inline T HalCoreInfoType::get_dev_addr(HalMemAddrType addr_type) const { - uint32_t index = utils::underlying_type(addr_type); +inline T HalCoreInfoType::get_dev_addr(HalL1MemAddrType addr_type) const { + uint32_t index = utils::underlying_type(addr_type); TT_ASSERT(index < this->mem_map_bases_.size()); return reinterpret_cast(this->mem_map_bases_[index]); } -inline uint32_t HalCoreInfoType::get_dev_size(HalMemAddrType addr_type) const { - uint32_t index = utils::underlying_type(addr_type); +inline uint32_t HalCoreInfoType::get_dev_size(HalL1MemAddrType addr_type) const { + uint32_t index = utils::underlying_type(addr_type); TT_ASSERT(index < this->mem_map_sizes_.size()); return this->mem_map_sizes_[index]; } @@ -113,10 +113,10 @@ class Hal { uint32_t get_processor_count(uint32_t core_type_index) const; template - T get_dev_addr(HalProgrammableCoreType programmable_core_type, HalMemAddrType addr_type) const; + T get_dev_addr(HalProgrammableCoreType programmable_core_type, HalL1MemAddrType addr_type) const; template - T get_dev_addr(uint32_t programmable_core_type_index, HalMemAddrType addr_type) const; - uint32_t get_dev_size(HalProgrammableCoreType programmable_core_type, HalMemAddrType addr_type) const; + T get_dev_addr(uint32_t programmable_core_type_index, HalL1MemAddrType addr_type) const; + uint32_t get_dev_size(HalProgrammableCoreType programmable_core_type, HalL1MemAddrType addr_type) const; uint32_t get_alignment(HalMemType memory_type) const; @@ -136,19 +136,19 @@ inline CoreType Hal::get_core_type(uint32_t core_type_index) const { } template -inline T Hal::get_dev_addr(HalProgrammableCoreType programmable_core_type, HalMemAddrType addr_type) const { +inline T Hal::get_dev_addr(HalProgrammableCoreType programmable_core_type, HalL1MemAddrType addr_type) const { uint32_t index = utils::underlying_type(programmable_core_type); TT_ASSERT(index < this->core_info_.size()); return this->core_info_[index].get_dev_addr(addr_type); } template -inline T Hal::get_dev_addr(uint32_t programmable_core_type_index, HalMemAddrType addr_type) const { +inline T Hal::get_dev_addr(uint32_t programmable_core_type_index, HalL1MemAddrType addr_type) const { TT_ASSERT(programmable_core_type_index < this->core_info_.size()); return this->core_info_[programmable_core_type_index].get_dev_addr(addr_type); } -inline uint32_t Hal::get_dev_size(HalProgrammableCoreType programmable_core_type, HalMemAddrType addr_type) const { +inline uint32_t Hal::get_dev_size(HalProgrammableCoreType programmable_core_type, HalL1MemAddrType addr_type) const { uint32_t index = utils::underlying_type(programmable_core_type); TT_ASSERT(index < this->core_info_.size()); return this->core_info_[index].get_dev_size(addr_type); diff --git a/tt_metal/llrt/llrt.cpp b/tt_metal/llrt/llrt.cpp index 4322c9d43c9..152689f8d9f 100644 --- a/tt_metal/llrt/llrt.cpp +++ b/tt_metal/llrt/llrt.cpp @@ -281,7 +281,7 @@ static bool check_if_riscs_on_specified_core_done(chip_id_t chip_id, const CoreC tt_metal::HalProgrammableCoreType dispatch_core_type = is_active_eth_core ? tt_metal::HalProgrammableCoreType::ACTIVE_ETH : is_inactive_eth_core ? tt_metal::HalProgrammableCoreType::IDLE_ETH : tt_metal::HalProgrammableCoreType::TENSIX; - uint64_t go_msg_addr = tt_metal::hal.get_dev_addr(dispatch_core_type, tt_metal::HalMemAddrType::GO_MSG); + uint64_t go_msg_addr = tt_metal::hal.get_dev_addr(dispatch_core_type, tt_metal::HalL1MemAddrType::GO_MSG); auto get_mailbox_is_done = [&](uint64_t go_msg_addr) { constexpr int RUN_MAILBOX_BOGUS = 3; diff --git a/tt_metal/llrt/wormhole/wh_hal_active_eth.cpp b/tt_metal/llrt/wormhole/wh_hal_active_eth.cpp index 496707f07e6..dc1d16aedf7 100644 --- a/tt_metal/llrt/wormhole/wh_hal_active_eth.cpp +++ b/tt_metal/llrt/wormhole/wh_hal_active_eth.cpp @@ -27,29 +27,29 @@ HalCoreInfoType create_active_eth_mem_map() { std::vector mem_map_bases; - mem_map_bases.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_bases[utils::underlying_type(HalMemAddrType::BARRIER)] = MEM_L1_BARRIER; - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH)] = GET_ETH_MAILBOX_ADDRESS_HOST(launch); - mem_map_bases[utils::underlying_type(HalMemAddrType::WATCHER)] = GET_ETH_MAILBOX_ADDRESS_HOST(watcher); - mem_map_bases[utils::underlying_type(HalMemAddrType::DPRINT)] = GET_ETH_MAILBOX_ADDRESS_HOST(dprint_buf); - mem_map_bases[utils::underlying_type(HalMemAddrType::PROFILER)] = GET_ETH_MAILBOX_ADDRESS_HOST(profiler); - mem_map_bases[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = eth_l1_mem::address_map::ERISC_L1_KERNEL_CONFIG_BASE; - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)] = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; - mem_map_bases[utils::underlying_type(HalMemAddrType::CORE_INFO)] = GET_ETH_MAILBOX_ADDRESS_HOST(core_info); - mem_map_bases[utils::underlying_type(HalMemAddrType::GO_MSG)] = GET_ETH_MAILBOX_ADDRESS_HOST(go_message); - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_ETH_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); + mem_map_bases.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::BARRIER)] = MEM_L1_BARRIER; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = GET_ETH_MAILBOX_ADDRESS_HOST(launch); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::WATCHER)] = GET_ETH_MAILBOX_ADDRESS_HOST(watcher); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::DPRINT)] = GET_ETH_MAILBOX_ADDRESS_HOST(dprint_buf); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::PROFILER)] = GET_ETH_MAILBOX_ADDRESS_HOST(profiler); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = eth_l1_mem::address_map::ERISC_L1_KERNEL_CONFIG_BASE; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::CORE_INFO)] = GET_ETH_MAILBOX_ADDRESS_HOST(core_info); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = GET_ETH_MAILBOX_ADDRESS_HOST(go_message); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_ETH_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); std::vector mem_map_sizes; - mem_map_sizes.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_sizes[utils::underlying_type(HalMemAddrType::BARRIER)] = sizeof(uint32_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH)] = sizeof(launch_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::WATCHER)] = sizeof(watcher_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::PROFILER)] = sizeof(profiler_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = eth_l1_mem::address_map::ERISC_L1_KERNEL_CONFIG_SIZE; - mem_map_sizes[utils::underlying_type(HalMemAddrType::UNRESERVED)] = eth_l1_mem::address_map::MAX_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; - mem_map_sizes[utils::underlying_type(HalMemAddrType::GO_MSG)] = sizeof(go_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); + mem_map_sizes.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::BARRIER)] = sizeof(uint32_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = sizeof(launch_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::WATCHER)] = sizeof(watcher_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::PROFILER)] = sizeof(profiler_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = eth_l1_mem::address_map::ERISC_L1_KERNEL_CONFIG_SIZE; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = eth_l1_mem::address_map::MAX_SIZE - eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); return {HalProgrammableCoreType::ACTIVE_ETH, CoreType::ETH, num_proc_per_active_eth_core, mem_map_bases, mem_map_sizes, false}; } diff --git a/tt_metal/llrt/wormhole/wh_hal_idle_eth.cpp b/tt_metal/llrt/wormhole/wh_hal_idle_eth.cpp index 3d04196a46e..da5d509eae6 100644 --- a/tt_metal/llrt/wormhole/wh_hal_idle_eth.cpp +++ b/tt_metal/llrt/wormhole/wh_hal_idle_eth.cpp @@ -27,29 +27,29 @@ HalCoreInfoType create_idle_eth_mem_map() { std::vector mem_map_bases; - mem_map_bases.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_bases[utils::underlying_type(HalMemAddrType::BARRIER)] = MEM_L1_BARRIER; - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch); - mem_map_bases[utils::underlying_type(HalMemAddrType::WATCHER)] = GET_IERISC_MAILBOX_ADDRESS_HOST(watcher); - mem_map_bases[utils::underlying_type(HalMemAddrType::DPRINT)] = GET_IERISC_MAILBOX_ADDRESS_HOST(dprint_buf); - mem_map_bases[utils::underlying_type(HalMemAddrType::PROFILER)] = GET_IERISC_MAILBOX_ADDRESS_HOST(profiler); - mem_map_bases[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = IDLE_ERISC_L1_KERNEL_CONFIG_BASE; - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; // TODO: this is wrong, need idle eth specific value - mem_map_bases[utils::underlying_type(HalMemAddrType::CORE_INFO)] = GET_IERISC_MAILBOX_ADDRESS_HOST(core_info); - mem_map_bases[utils::underlying_type(HalMemAddrType::GO_MSG)] = GET_IERISC_MAILBOX_ADDRESS_HOST(go_message); - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); + mem_map_bases.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::BARRIER)] = MEM_L1_BARRIER; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::WATCHER)] = GET_IERISC_MAILBOX_ADDRESS_HOST(watcher); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::DPRINT)] = GET_IERISC_MAILBOX_ADDRESS_HOST(dprint_buf); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::PROFILER)] = GET_IERISC_MAILBOX_ADDRESS_HOST(profiler); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = IDLE_ERISC_L1_KERNEL_CONFIG_BASE; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; // TODO: this is wrong, need idle eth specific value + mem_map_bases[utils::underlying_type(HalL1MemAddrType::CORE_INFO)] = GET_IERISC_MAILBOX_ADDRESS_HOST(core_info); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = GET_IERISC_MAILBOX_ADDRESS_HOST(go_message); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_IERISC_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); std::vector mem_map_sizes; - mem_map_sizes.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_sizes[utils::underlying_type(HalMemAddrType::BARRIER)] = sizeof(uint32_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH)] = sizeof(launch_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::WATCHER)] = sizeof(watcher_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::PROFILER)] = sizeof(profiler_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; // TODO: this is wrong, need idle eth specific value - mem_map_sizes[utils::underlying_type(HalMemAddrType::UNRESERVED)] = MEM_ETH_SIZE - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)]; - mem_map_sizes[utils::underlying_type(HalMemAddrType::GO_MSG)] = sizeof(go_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); + mem_map_sizes.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::BARRIER)] = sizeof(uint32_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = sizeof(launch_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::WATCHER)] = sizeof(watcher_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::PROFILER)] = sizeof(profiler_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; // TODO: this is wrong, need idle eth specific value + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = MEM_ETH_SIZE - mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)]; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); return {HalProgrammableCoreType::IDLE_ETH, CoreType::ETH, num_proc_per_idle_eth_core, mem_map_bases, mem_map_sizes, false}; } diff --git a/tt_metal/llrt/wormhole/wh_hal_tensix.cpp b/tt_metal/llrt/wormhole/wh_hal_tensix.cpp index bd5e90b3ae6..2dec8efbda3 100644 --- a/tt_metal/llrt/wormhole/wh_hal_tensix.cpp +++ b/tt_metal/llrt/wormhole/wh_hal_tensix.cpp @@ -25,29 +25,29 @@ HalCoreInfoType create_tensix_mem_map() { std::vector mem_map_bases; - mem_map_bases.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_bases[utils::underlying_type(HalMemAddrType::BARRIER)] = MEM_L1_BARRIER; - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH)] = GET_MAILBOX_ADDRESS_HOST(launch); - mem_map_bases[utils::underlying_type(HalMemAddrType::WATCHER)] = GET_MAILBOX_ADDRESS_HOST(watcher); - mem_map_bases[utils::underlying_type(HalMemAddrType::DPRINT)] = GET_MAILBOX_ADDRESS_HOST(dprint_buf); - mem_map_bases[utils::underlying_type(HalMemAddrType::PROFILER)] = GET_MAILBOX_ADDRESS_HOST(profiler); - mem_map_bases[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_BASE; - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; - mem_map_bases[utils::underlying_type(HalMemAddrType::CORE_INFO)] = GET_MAILBOX_ADDRESS_HOST(core_info); - mem_map_bases[utils::underlying_type(HalMemAddrType::GO_MSG)] = GET_MAILBOX_ADDRESS_HOST(go_message); - mem_map_bases[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); + mem_map_bases.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::BARRIER)] = MEM_L1_BARRIER; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = GET_MAILBOX_ADDRESS_HOST(launch); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::WATCHER)] = GET_MAILBOX_ADDRESS_HOST(watcher); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::DPRINT)] = GET_MAILBOX_ADDRESS_HOST(dprint_buf); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::PROFILER)] = GET_MAILBOX_ADDRESS_HOST(profiler); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_BASE; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = ((L1_KERNEL_CONFIG_BASE + L1_KERNEL_CONFIG_SIZE - 1) | (max_alignment - 1)) + 1; + mem_map_bases[utils::underlying_type(HalL1MemAddrType::CORE_INFO)] = GET_MAILBOX_ADDRESS_HOST(core_info); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = GET_MAILBOX_ADDRESS_HOST(go_message); + mem_map_bases[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = GET_MAILBOX_ADDRESS_HOST(launch_msg_rd_ptr); std::vector mem_map_sizes; - mem_map_sizes.resize(utils::underlying_type(HalMemAddrType::COUNT)); - mem_map_sizes[utils::underlying_type(HalMemAddrType::BARRIER)] = sizeof(uint32_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH)] = sizeof(launch_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::WATCHER)] = sizeof(watcher_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::PROFILER)] = sizeof(profiler_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; - mem_map_sizes[utils::underlying_type(HalMemAddrType::UNRESERVED)] = MEM_L1_SIZE - mem_map_bases[utils::underlying_type(HalMemAddrType::UNRESERVED)]; - mem_map_sizes[utils::underlying_type(HalMemAddrType::GO_MSG)] = sizeof(go_msg_t); - mem_map_sizes[utils::underlying_type(HalMemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); + mem_map_sizes.resize(utils::underlying_type(HalL1MemAddrType::COUNT)); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::BARRIER)] = sizeof(uint32_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH)] = sizeof(launch_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::WATCHER)] = sizeof(watcher_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::DPRINT)] = sizeof(dprint_buf_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::PROFILER)] = sizeof(profiler_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::KERNEL_CONFIG)] = L1_KERNEL_CONFIG_SIZE; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::UNRESERVED)] = MEM_L1_SIZE - mem_map_bases[utils::underlying_type(HalL1MemAddrType::UNRESERVED)]; + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::GO_MSG)] = sizeof(go_msg_t); + mem_map_sizes[utils::underlying_type(HalL1MemAddrType::LAUNCH_MSG_BUFFER_RD_PTR)] = sizeof(uint32_t); return {HalProgrammableCoreType::TENSIX, CoreType::WORKER, num_proc_per_tensix_core, mem_map_bases, mem_map_sizes, true}; } diff --git a/tt_metal/tools/profiler/profiler.cpp b/tt_metal/tools/profiler/profiler.cpp index cae81629ad9..f55611b1a3a 100644 --- a/tt_metal/tools/profiler/profiler.cpp +++ b/tt_metal/tools/profiler/profiler.cpp @@ -38,13 +38,13 @@ void DeviceProfiler::readRiscProfilerResults( auto ethCores = soc_d.get_physical_ethernet_cores() ; if (std::find(ethCores.begin(), ethCores.end(), worker_core) == ethCores.end()) { - profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::PROFILER); + profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::PROFILER); CoreType = HalProgrammableCoreType::TENSIX; riscCount = 5; } else { - profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::PROFILER); + profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::PROFILER); CoreType = HalProgrammableCoreType::ACTIVE_ETH; riscCount = 1; } @@ -189,7 +189,7 @@ void DeviceProfiler::readRiscProfilerResults( std::vector control_buffer_reset(kernel_profiler::PROFILER_L1_CONTROL_VECTOR_SIZE, 0); control_buffer_reset[kernel_profiler::DRAM_PROFILER_ADDRESS] = output_dram_buffer->address(); - profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::PROFILER); + profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::PROFILER); tt::llrt::write_hex_vec_to_core( device_id, worker_core, diff --git a/tt_metal/tools/profiler/tt_metal_profiler.cpp b/tt_metal/tools/profiler/tt_metal_profiler.cpp index bf8fcf952c5..40d38a13c62 100644 --- a/tt_metal/tools/profiler/tt_metal_profiler.cpp +++ b/tt_metal/tools/profiler/tt_metal_profiler.cpp @@ -75,12 +75,12 @@ void setControlBuffer(uint32_t device_id, std::vector& control_buffer) if (std::find(ethCores.begin(), ethCores.end(), core.first) == ethCores.end()) { //Tensix - profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalMemAddrType::PROFILER); + profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::PROFILER); } else { //ETH - profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalMemAddrType::PROFILER); + profiler_msg = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::PROFILER); } control_buffer[kernel_profiler::FLAT_ID] = core.second; @@ -137,7 +137,7 @@ void syncDeviceHost(Device *device, CoreCoord logical_core, std::shared_ptr writeTimes(sampleCount); - profiler_msg_t *profiler_msg = device->get_dev_addr(core, HalMemAddrType::PROFILER); + profiler_msg_t *profiler_msg = device->get_dev_addr(core, HalL1MemAddrType::PROFILER); uint64_t control_addr = reinterpret_cast(&profiler_msg->control_vector[kernel_profiler::FW_RESET_L]); for (int i = 0; i < sampleCount; i++) { @@ -417,7 +417,7 @@ void DumpDeviceProfileResults(Device *device, std::vector &worker_cor for (const CoreCoord& core : tt::get_logical_dispatch_cores(device_id, device_num_hw_cqs, dispatch_core_type)) { const auto curr_core = device->physical_core_from_logical_core(core, dispatch_core_type); - profiler_msg_t *profiler_msg = device->get_dev_addr(curr_core, HalMemAddrType::PROFILER); + profiler_msg_t *profiler_msg = device->get_dev_addr(curr_core, HalL1MemAddrType::PROFILER); vector control_buffer = tt::llrt::read_hex_vec_from_core( device_id, curr_core, @@ -437,7 +437,7 @@ void DumpDeviceProfileResults(Device *device, std::vector &worker_cor for (const CoreCoord& core : tt::Cluster::instance().get_soc_desc(device_id).physical_ethernet_cores) { const auto curr_core = device->physical_core_from_logical_core(core, CoreType::ETH); - profiler_msg_t *profiler_msg = device->get_dev_addr(curr_core, HalMemAddrType::PROFILER); + profiler_msg_t *profiler_msg = device->get_dev_addr(curr_core, HalL1MemAddrType::PROFILER); vector control_buffer = tt::llrt::read_hex_vec_from_core( device_id, core, diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp index 42def67b5ad..b8f6cf9786e 100644 --- a/tt_metal/tt_metal.cpp +++ b/tt_metal/tt_metal.cpp @@ -704,7 +704,7 @@ void LaunchProgram(Device *device, Program &program, bool wait_until_cores_done, auto physical_core = device->physical_core_from_logical_core(logical_core, core_type); not_done_cores.insert(physical_core); - tt::llrt::write_launch_msg_to_core(device->id(), physical_core, msg, go_msg, device->get_dev_addr(physical_core, HalMemAddrType::LAUNCH)); + tt::llrt::write_launch_msg_to_core(device->id(), physical_core, msg, go_msg, device->get_dev_addr(physical_core, HalL1MemAddrType::LAUNCH)); } } if (wait_until_cores_done) { @@ -779,7 +779,7 @@ bool ConfigureDeviceWithProgram(Device *device, Program &program, bool fd_bootlo } // PROF_END("CBS") if (cbs_on_core.size()) { - uint64_t kernel_config_base = hal.get_dev_addr(index, HalMemAddrType::KERNEL_CONFIG); + uint64_t kernel_config_base = hal.get_dev_addr(index, HalL1MemAddrType::KERNEL_CONFIG); uint64_t addr = kernel_config_base + program.get_program_config(index).cb_offset; llrt::write_hex_vec_to_core(device_id, physical_core, circular_buffer_config_vec, addr); } From 35ce478515aa95f3c4475544842fbb0a60b850bb Mon Sep 17 00:00:00 2001 From: Virdhatchani Narayanamoorthy <138196495+VirdhatchaniKN@users.noreply.github.com> Date: Wed, 9 Oct 2024 11:51:53 +0700 Subject: [PATCH 22/58] #13517: Remove backward relational ops (#13518) --- .../sweep_tests/pytorch_ops.py | 30 --- .../python_api_testing/sweep_tests/op_map.py | 12 - .../sweep_tests/ttnn_ops.py | 60 ----- .../operations/backward/test_backward_lt.py | 208 ------------------ .../binary_backward/binary_backward.cpp | 93 -------- .../binary_backward/binary_backward.hpp | 4 - .../binary_backward_pybind.hpp | 88 -------- .../eltwise/unary_backward/unary_backward.cpp | 22 -- ttnn/ttnn/operations/binary_backward.py | 18 -- 9 files changed, 535 deletions(-) delete mode 100644 tests/ttnn/unit_tests/operations/backward/test_backward_lt.py diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index d77f952d7fd..7362670f625 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -1843,36 +1843,6 @@ def log_bw(x, y, *args, **kwargs): return in_data.grad -def gt_bw(x, *args, **kwargs): - grad_data = x - - pyt_y = torch.zeros_like(grad_data) - - golden_tensor = pyt_y - - return golden_tensor - - -def lt_bw(x, *args, **kwargs): - grad_data = x - - pyt_y = torch.zeros_like(grad_data) - - golden_tensor = pyt_y - - return golden_tensor - - -def ne_bw(x, *args, **kwargs): - grad_data = x - - pyt_y = torch.zeros_like(grad_data) - - golden_tensor = pyt_y - - return golden_tensor - - def rsub_bw(x, y, z, *args, **kwargs): grad_data = x in_data = y diff --git a/tests/ttnn/python_api_testing/sweep_tests/op_map.py b/tests/ttnn/python_api_testing/sweep_tests/op_map.py index 824cb5a5799..49f08546578 100644 --- a/tests/ttnn/python_api_testing/sweep_tests/op_map.py +++ b/tests/ttnn/python_api_testing/sweep_tests/op_map.py @@ -810,18 +810,6 @@ "tt_op": ttnn_ops.relu_bw, "pytorch_op": pytorch_ops.relu_bw, }, - "gt-bw": { - "tt_op": ttnn_ops.gt_bw, - "pytorch_op": pytorch_ops.gt_bw, - }, - "lt-bw": { - "tt_op": ttnn_ops.gt_bw, - "pytorch_op": pytorch_ops.gt_bw, - }, - "ne-bw": { - "tt_op": ttnn_ops.ne_bw, - "pytorch_op": pytorch_ops.ne_bw, - }, "log10-bw": { "tt_op": ttnn_ops.log10_bw, "pytorch_op": pytorch_ops.log10_bw, diff --git a/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py b/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py index 7606da60477..5388de12fc4 100644 --- a/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py +++ b/tests/ttnn/python_api_testing/sweep_tests/ttnn_ops.py @@ -3791,66 +3791,6 @@ def relu_bw( return ttnn_tensor_to_torch(t2) -def gt_bw( - x, # grad_tensor - y, # input_tensor - *args, - scalar, - device, - dtype, - layout, - input_mem_config, - output_mem_config, - **kwargs, -): - t0 = setup_ttnn_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) - t1 = setup_ttnn_tensor(y, device, layout[1], input_mem_config[1], dtype[1]) - - t2 = ttnn.gt_bw(t0, t1, alpha=scalar, memory_config=output_mem_config)[0] - - return ttnn_tensor_to_torch(t2) - - -def lt_bw( - x, # grad_tensor - y, # input_tensor - *args, - scalar, - device, - dtype, - layout, - input_mem_config, - output_mem_config, - **kwargs, -): - t0 = setup_ttnn_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) - t1 = setup_ttnn_tensor(y, device, layout[1], input_mem_config[1], dtype[1]) - - t2 = ttnn.lt_bw(t0, t1, alpha=scalar, memory_config=output_mem_config)[0] - - return ttnn_tensor_to_torch(t2) - - -def ne_bw( - x, # grad_tensor - y, # input_tensor - *args, - scalar, - device, - dtype, - layout, - input_mem_config, - output_mem_config, - **kwargs, -): - t0 = setup_ttnn_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) - t1 = setup_ttnn_tensor(y, device, layout[1], input_mem_config[1], dtype[1]) - - t2 = ttnn.ne_bw(t0, t1, alpha=scalar, memory_config=output_mem_config)[0] - - return ttnn_tensor_to_torch(t2) - - def log10_bw( x, y, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_lt.py b/tests/ttnn/unit_tests/operations/backward/test_backward_lt.py deleted file mode 100644 index 170a41446d4..00000000000 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_lt.py +++ /dev/null @@ -1,208 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import torch -import pytest -import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc - - -@pytest.mark.parametrize( - "input_shapes", - ( - (torch.Size([1, 1, 32, 32])), - (torch.Size([1, 1, 320, 384])), - (torch.Size([1, 3, 320, 384])), - ), -) -def test_bw_lt(input_shapes, device): - in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) - other_data, other_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) - grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device) - - tt_output_tensor_on_device = ttnn.lt_bw(grad_tensor, input_tensor, other_tensor) - - golden_function = ttnn.get_golden_function(ttnn.lt_bw) - golden_tensor = golden_function(grad_data, in_data, other_data) - - status = compare_pcc(tt_output_tensor_on_device, golden_tensor) - assert status - - -@pytest.mark.parametrize( - "input_shapes", - ( - (torch.Size([1, 1, 32, 32])), - (torch.Size([1, 1, 320, 384])), - (torch.Size([1, 3, 320, 384])), - ), -) -@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]]) -def test_bw_lt_with_opt_output(input_shapes, device, are_required_outputs): - in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) - other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True) - grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device) - input_grad = None - other_grad = None - - if are_required_outputs[0]: - _, input_grad = data_gen_with_range(input_shapes, -1, 1, device) - if are_required_outputs[1]: - _, other_grad = data_gen_with_range(input_shapes, -1, 1, device) - - cq_id = 0 - - pages_before = ttnn._ttnn.reports.get_buffer_pages() - ttnn.lt_bw( - grad_tensor, - input_tensor, - other_tensor, - are_required_outputs=are_required_outputs, - input_grad=input_grad, - other_grad=other_grad, - queue_id=cq_id, - ) - assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages()) - tt_output_tensor_on_device = [input_grad, other_grad] - - golden_function = ttnn.get_golden_function(ttnn.lt_bw) - golden_tensor = golden_function(grad_data, in_data, other_data) - - status = True - for i in range(len(are_required_outputs)): - if are_required_outputs[i]: - status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]]) - assert status - - -@pytest.mark.parametrize( - "input_shapes", - ( - (torch.Size([1, 1, 32, 32])), - (torch.Size([1, 1, 320, 384])), - (torch.Size([1, 3, 320, 384])), - ), -) -@pytest.mark.parametrize("scalar", [1.0, 0.5, 0.035]) -def test_bw_lt_scalar(input_shapes, scalar, device): - in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) - grad_data, grad_tensor = data_gen_with_range(input_shapes, -100, 100, device) - - tt_output_tensor_on_device = ttnn.lt_bw(grad_tensor, input_tensor, scalar) - - golden_function = ttnn.get_golden_function(ttnn.lt_bw) - golden_tensor = golden_function(grad_data, in_data, scalar) - - status = compare_pcc(tt_output_tensor_on_device, golden_tensor) - assert status - - -@pytest.mark.parametrize( - "input_shapes", - ( - (torch.Size([1, 1, 32, 32])), - (torch.Size([1, 1, 320, 384])), - (torch.Size([1, 3, 320, 384])), - ), -) -@pytest.mark.parametrize("scalar", [1.0, 0.5, 0.035]) -def test_bw_lt_with_scalar_opt_output(input_shapes, device, scalar): - in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) - grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device) - input_grad = None - _, input_grad = data_gen_with_range(input_shapes, -1, 1, device) - - cq_id = 0 - - pages_before = ttnn._ttnn.reports.get_buffer_pages() - ttnn.lt_bw( - grad_tensor, - input_tensor, - scalar, - input_grad=input_grad, - queue_id=cq_id, - ) - assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages()) - tt_output_tensor_on_device = [input_grad] - - golden_function = ttnn.get_golden_function(ttnn.lt_bw) - golden_tensor = golden_function(grad_data, in_data, scalar) - - status = compare_pcc(tt_output_tensor_on_device, golden_tensor) - assert status - - -@pytest.mark.parametrize( - "input_shapes", - ( - (torch.Size([1, 1, 32, 32])), - (torch.Size([1, 1, 320, 384])), - (torch.Size([1, 3, 320, 384])), - ), -) -@pytest.mark.parametrize("are_required_outputs", [[True, True], [True, False], [False, True]]) -def test_bw_lt_with_opt_output_opt_qid(input_shapes, device, are_required_outputs): - in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) - other_data, other_tensor = data_gen_with_range(input_shapes, -90, 100, device, True) - grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device) - input_grad = None - other_grad = None - - if are_required_outputs[0]: - _, input_grad = data_gen_with_range(input_shapes, -1, 1, device) - if are_required_outputs[1]: - _, other_grad = data_gen_with_range(input_shapes, -1, 1, device) - - pages_before = ttnn._ttnn.reports.get_buffer_pages() - ttnn.lt_bw( - grad_tensor, - input_tensor, - other_tensor, - are_required_outputs=are_required_outputs, - input_grad=input_grad, - other_grad=other_grad, - ) - assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages()) - tt_output_tensor_on_device = [input_grad, other_grad] - - golden_function = ttnn.get_golden_function(ttnn.lt_bw) - golden_tensor = golden_function(grad_data, in_data, other_data) - - status = True - for i in range(len(are_required_outputs)): - if are_required_outputs[i]: - status = status & compare_pcc([tt_output_tensor_on_device[i]], [golden_tensor[i]]) - assert status - - -@pytest.mark.parametrize( - "input_shapes", - ( - (torch.Size([1, 1, 32, 32])), - (torch.Size([1, 1, 320, 384])), - (torch.Size([1, 3, 320, 384])), - ), -) -@pytest.mark.parametrize("scalar", [1.0, 0.5, 0.035]) -def test_bw_lt_with_scalar_opt_output_opt_qid(input_shapes, device, scalar): - in_data, input_tensor = data_gen_with_range(input_shapes, -100, 100, device, True) - grad_data, grad_tensor = data_gen_with_range(input_shapes, -70, 90, device) - input_grad = None - _, input_grad = data_gen_with_range(input_shapes, -1, 1, device) - - pages_before = ttnn._ttnn.reports.get_buffer_pages() - ttnn.lt_bw( - grad_tensor, - input_tensor, - scalar, - input_grad=input_grad, - ) - assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages()) - tt_output_tensor_on_device = [input_grad] - - golden_function = ttnn.get_golden_function(ttnn.lt_bw) - golden_tensor = golden_function(grad_data, in_data, scalar) - - status = compare_pcc(tt_output_tensor_on_device, golden_tensor) - assert status diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp index 5c96919fd7a..2238f055fa2 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp @@ -402,48 +402,6 @@ std::vector ExecuteBackwardSquaredDifference::invoke( return grad_tensor; } -std::vector> _eq_bw( - uint8_t cq_id, - const Tensor& grad, - const Tensor& input, - const Tensor& other, - const MemoryConfig& output_mem_config, - const std::vector& are_required_outputs, - std::optional input_grad, - std::optional other_grad) { - std::vector> result; - - if (are_required_outputs.at(0)) { - input_grad = ttnn::full_like(input, 0.0f); - result.emplace_back(input_grad); - } else { - result.emplace_back(std::nullopt); - } - if (are_required_outputs.at(1)) { - other_grad = ttnn::full_like(grad, 0.0f); - result.emplace_back(other_grad); - } else { - result.emplace_back(std::nullopt); - } - return result; -} - -std::vector _eq_bw_inter( - const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { - auto result = _eq_bw(0, grad, input, other, output_mem_config, {true, true}, std::nullopt, std::nullopt); - std::vector output_tensors; - output_tensors.reserve(result.size()); - - for (const auto& opt_tensor : result) { - if (opt_tensor) { - output_tensors.emplace_back(*opt_tensor); - } else { - output_tensors.emplace_back(); - } - } - return output_tensors; -} - std::vector> ExecuteBackwardAssign::invoke( uint8_t cq_id, const Tensor& grad, const Tensor& input, const Tensor& other, const std::vector& are_required_outputs, const std::optional& output_mem_config, std::optional input_grad, std::optional other_grad) { @@ -536,15 +494,6 @@ std::vector> ExecuteBackwardConcat::invoke( } -std::vector _binary_comp_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config) { - std::vector grad_tensor; - Tensor zero_grad = ttnn::zeros_like(grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config); - grad_tensor.emplace_back(zero_grad); - Tensor zero_input = ttnn::zeros_like(input, input.get_dtype(), input.get_layout(), std::nullopt, output_mem_config); - grad_tensor.emplace_back(zero_input); - return grad_tensor; -} - std::vector> ExecuteBackwardRsub::invoke( uint8_t queue_id, const Tensor& grad, @@ -603,48 +552,6 @@ std::vector ExecuteBackwardBiasGelu::invoke( return grad_tensor; } -std::vector> ExecuteBackwardLT::invoke( - uint8_t queue_id, const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config, - const std::vector& are_required_outputs, - std::optional input_grad, - std::optional other_grad) { - std::vector> result = {std::nullopt, std::nullopt}; - result[0] = input_grad.has_value() ? ttnn::zeros_like(queue_id, grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config, input_grad) : ttnn::zeros_like(queue_id, grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config); - result[1] = other_grad.has_value() ? ttnn::zeros_like(queue_id, input, input.get_dtype(), input.get_layout(), std::nullopt, output_mem_config, other_grad) : ttnn::zeros_like(queue_id, input, input.get_dtype(), input.get_layout(), std::nullopt, output_mem_config); - return result; -} - -std::vector> ExecuteBackwardLT::invoke( - uint8_t queue_id, const Tensor& grad, const Tensor& input, float other, const std::optional& output_mem_config, - std::optional input_grad) { - std::vector> result = {std::nullopt}; - result[0] = input_grad.has_value() ? ttnn::zeros_like(queue_id, grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config, input_grad) : ttnn::zeros_like(queue_id, grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config); - return result; -} - -std::vector> ExecuteBackwardLT::invoke( - const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config, - const std::vector& are_required_outputs, - std::optional input_grad, - std::optional other_grad) { - return ExecuteBackwardLT::invoke(ttnn::DefaultQueueId, grad, input, other, output_mem_config, are_required_outputs, input_grad, other_grad); -} - -std::vector> ExecuteBackwardLT::invoke( - const Tensor& grad, const Tensor& input, float other, const std::optional& output_mem_config, - std::optional input_grad) { - return ExecuteBackwardLT::invoke(ttnn::DefaultQueueId, grad, input, other, output_mem_config, input_grad); -} - - -std::vector _gt_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const std::optional& output_mem_config) { - return _binary_comp_bw(grad, input, other, output_mem_config); -} - -std::vector _ge_bw(const Tensor& grad, const Tensor& input, const Tensor& other, const MemoryConfig& output_mem_config) { - return _binary_comp_bw(grad, input, other, output_mem_config); -} - // template parameter min_or_max = TRUE for MAX, FALSE for MIN template std::vector _min_or_max_bw( diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp index 2e00c995f24..59c6b4617a1 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.hpp @@ -513,10 +513,6 @@ constexpr auto addalpha_bw = ttnn::register_operation< "ttnn::addalpha_bw", operations::binary_backward::ExecuteAddalphaBW>(); -constexpr auto lt_bw = ttnn::register_operation< - "ttnn::lt_bw", - operations::binary_backward::ExecuteBackwardLT>(); - constexpr auto add_bw = ttnn::register_operation< "ttnn::add_bw", operations::binary_backward::ExecuteBackwardAdd>(); diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp index d327d5564b8..8b8fdfbc29d 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward_pybind.hpp @@ -661,89 +661,6 @@ void bind_binary_bw(py::module& module, const binary_backward_operation_t& opera py::arg("memory_config") = std::nullopt}); } -template -void bind_binary_bw_optional(py::module& module, const binary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") { - auto doc = fmt::format( - R"doc( - - {2} - - Args: - grad_tensor (ttnn.Tensor): the input gradient tensor. - input_tensor (ttnn.Tensor): the input tensor. - other_tensor (ttnn.Tensor or Number): the input tensor. - - Keyword args: - are_required_outputs (List[bool], optional): List of required outputs. Defaults to `[True, True]`. - memory_config (ttnn.MemoryConfig, optional): Memory configuration for the operation. Defaults to `None`. - input_grad (ttnn.Tensor, optional): Preallocated output tensor for gradient of `input_tensor`. Defaults to `None`. - other_grad (ttnn.Tensor, optional): Preallocated output tensor for gradient of `other_tensor`. Defaults to `None`. - queue_id (int, optional): command queue id. Defaults to `0`. - - Note: - {3} - - Note : bfloat8_b/bfloat4_b is only supported on TILE_LAYOUT - - Example: - >>> grad_tensor = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) - >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16)), device) - >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.tensor((0, 1), dtype=torch.bfloat16)), device) - >>> output = {1}(grad_tensor, tensor1, tensor2) - - )doc", - operation.base_name(), - operation.python_fully_qualified_name(), - description, - supported_dtype); - - bind_registered_operation( - module, - operation, - doc, - // tensor and scalar - ttnn::pybind_overload_t{ - [](const binary_backward_operation_t& self, - const ttnn::Tensor& grad_tensor, - const ttnn::Tensor& input_tensor, - float other, - const std::optional& memory_config, - const std::optional& input_grad, - const uint8_t& queue_id) -> std::vector> { - return self(queue_id, grad_tensor, input_tensor, other, memory_config, input_grad); - }, - py::arg("grad_tensor"), - py::arg("input_tensor"), - py::arg("other"), - py::kw_only(), - py::arg("memory_config") = std::nullopt, - py::arg("input_grad") = std::nullopt, - py::arg("queue_id") = 0}, - - // tensor and tensor - ttnn::pybind_overload_t{ - [](const binary_backward_operation_t& self, - const ttnn::Tensor& grad_tensor, - const ttnn::Tensor& input_tensor, - const ttnn::Tensor& other_tensor, - const std::optional& memory_config, - const std::vector& are_required_outputs, - const std::optional& input_grad, - const std::optional& other_grad, - const uint8_t& queue_id) -> std::vector> { - return self(queue_id, grad_tensor, input_tensor, other_tensor, memory_config, are_required_outputs, input_grad, other_grad); - }, - py::arg("grad_tensor"), - py::arg("input_tensor"), - py::arg("other_tensor"), - py::kw_only(), - py::arg("memory_config") = std::nullopt, - py::arg("are_required_outputs") = std::vector{true, true}, - py::arg("input_grad") = std::nullopt, - py::arg("other_grad") = std::nullopt, - py::arg("queue_id") = 0}); -} - template void bind_binary_bw_div(py::module& module, const binary_backward_operation_t& operation, const std::string_view description, const std::string_view supported_dtype = "") { auto doc = fmt::format( @@ -1103,11 +1020,6 @@ void py_module(py::module& module) { )doc"); - detail::bind_binary_bw_optional( - module, - ttnn::lt_bw, - R"doc(Performs backward operations for less than operation of :attr:`input_tensor_a` and :attr:`input_tensor_b` or :attr:`scalar` with given :attr:`grad_tensor`.)doc"); - detail::bind_binary_backward_ops( module, ttnn::atan2_bw, diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp index 4ffd542dcfb..079f756203b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp @@ -260,28 +260,6 @@ std::vector ExecuteUnaryBackwardMultigammaln::invoke(const Tensor& grad, } -std::vector _unary_comp_bw(const Tensor& grad, const std::optional& output_mem_config) { - std::vector grad_tensor; - Tensor zero_grad = ttnn::zeros_like(grad, grad.get_dtype(), grad.get_layout(), std::nullopt, output_mem_config); - grad_tensor.emplace_back(zero_grad); - return grad_tensor; -} - -std::vector _eq_bw( - const Tensor& grad, const Tensor& input, float other, const std::optional& output_mem_config) { - return _unary_comp_bw(grad, output_mem_config); -} - -std::vector _gt_bw( - const Tensor& grad, const Tensor& input, float other, const std::optional& output_mem_config) { - return _unary_comp_bw(grad, output_mem_config); -} - -std::vector _ge_bw( - const Tensor& grad, const Tensor& input, float other, const std::optional& output_mem_config) { - return _unary_comp_bw(grad, output_mem_config); -} - std::vector ExecuteUnaryBackwardLgamma::invoke(const Tensor& grad, const Tensor& input, const std::optional& output_mem_config) { auto output_memory_config = output_mem_config.value_or(input.memory_config()); //TODO: Remove after ternary forward ops migration is completed std::vector grad_tensor; diff --git a/ttnn/ttnn/operations/binary_backward.py b/ttnn/ttnn/operations/binary_backward.py index 5c340be0692..ac833338f2f 100644 --- a/ttnn/ttnn/operations/binary_backward.py +++ b/ttnn/ttnn/operations/binary_backward.py @@ -70,14 +70,6 @@ def _golden_function_backward_overload(torch_op, grad_tensor, input_tensor_a, in return golden_tensor -ttnn.attach_golden_function( - ttnn.lt_bw, - golden_function=lambda grad, a, b, *args, **kwargs: _golden_function_comparison_ops( - torch.lt, grad, a, b, *args, **kwargs - ), -) - - def _golden_function_backward_with_dim( torch_op, grad_tensor, input_tensor_a, input_tensor_b, dimension=None, *args, **kwargs ): @@ -148,16 +140,6 @@ def _golden_function_backward_with_string( return golden_tensor -def _golden_function_comparison_ops(torch_op, grad_tensor, input_tensor_a, input_tensor_b, *args, **kwargs): - import torch - - if isinstance(input_tensor_b, (float, int)): - golden_tensor = [torch.zeros_like(input_tensor_a)] - else: - golden_tensor = [torch.zeros_like(input_tensor_a), torch.zeros_like(input_tensor_b)] - return golden_tensor - - ttnn.attach_golden_function( ttnn.sub_bw, golden_function=lambda grad, a, b, *args, **kwargs: _golden_function_backward( From 6a6ba38f42c94e233cf8412e3f24708e1d23a90b Mon Sep 17 00:00:00 2001 From: o2buzzle <76864037+o2buzzle@users.noreply.github.com> Date: Tue, 10 Sep 2024 09:09:09 +0000 Subject: [PATCH 23/58] #12554: port moreh_matmul_backward --- .../operations/test_moreh_matmul.py | 75 +++++++++++++++++ ttnn/CMakeLists.txt | 3 + .../examples/example/example_pybind.hpp | 14 ++-- .../device/moreh_matmul_program_factory.cpp | 2 +- .../moreh_matmul_backward.cpp | 81 +++++++++++++++++++ .../moreh_matmul_backward.hpp | 26 ++++++ .../moreh_matmul_backward_pybind.cpp | 28 +++++++ .../moreh_matmul_backward_pybind.hpp | 13 +++ .../ttnn/operations/moreh/moreh_pybind.cpp | 2 + ttnn/ttnn/operations/moreh.py | 1 + 10 files changed, 238 insertions(+), 7 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.hpp diff --git a/tests/ttnn/unit_tests/operations/test_moreh_matmul.py b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py index 3dcff16cfd4..96191571dd7 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py @@ -306,3 +306,78 @@ def test_moreh_matmul_1d(input_shape, device): logger.debug(f"Output pcc={output_pcc}") assert passing + + +@pytest.mark.parametrize( + "params", + ( + # input, other, output shape + ([3, 128, 96], [3, 4, 1, 96, 256], [3, 4, 3, 128, 256]), + ([3, 3, 313, 511], [3, 3, 511, 765], [3, 3, 313, 765]), + ([3, 1, 2, 1, 4, 1, 319, 95], [4, 2, 95, 470], [3, 1, 2, 1, 4, 2, 319, 470]), + ([3, 2, 1, 470, 95], [2, 1, 3, 1, 2, 2, 95, 319], [2, 1, 3, 3, 2, 2, 470, 319]), + ), +) +@pytest.mark.parametrize( + "requires_grad", + ( + (True, False), + (False, True), + (True, True), + ), +) +def test_moreh_matmul_backward(params, requires_grad, device): + torch.manual_seed(3072) + input_shape, other_shape, output_shape = params + require_input_grad, require_other_grad = requires_grad + + # get tensors + ( + tt_input, + tt_other, + _, + tt_output_grad, + tt_input_grad, + tt_other_grad, + torch_input, + torch_other, + torch_output_grad, + ) = get_tensors(input_shape, other_shape, output_shape, require_input_grad, require_other_grad, False, device) + + # torch matmul + torch_out = torch.matmul( + torch_input.requires_grad_(require_input_grad), torch_other.requires_grad_(require_other_grad) + ) + torch_out.backward(torch_output_grad) + + # tt matmul backward + tt_input_grad, tt_other_grad = ttnn.operations.moreh.matmul_backward( + tt_output_grad, + tt_input, + tt_other, + are_required_outputs=(require_input_grad, require_other_grad), + input_a_grad=tt_input_grad, + input_b_grad=tt_other_grad, + ) + + # test for equivalance + rtol = atol = 0.1 + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + if require_input_grad: + ttcpu_input_grad = tt_input_grad.cpu().to(cpu_layout).unpad_from_tile(input_shape).to_torch() + passing, output_pcc = comp_allclose_and_pcc(torch_input.grad, ttcpu_input_grad, pcc=0.999, rtol=rtol, atol=atol) + logger.debug(f"input_grad passing={passing}") + logger.debug(f"input_grad pcc={output_pcc}") + assert passing + else: + assert tt_input_grad is None + + if require_other_grad: + ttcpu_other_grad = tt_other_grad.cpu().to(cpu_layout).unpad_from_tile(other_shape).to_torch() + passing, output_pcc = comp_allclose_and_pcc(torch_other.grad, ttcpu_other_grad, pcc=0.999, rtol=rtol, atol=atol) + logger.debug(f"other_grad passing={passing}") + logger.debug(f"other_grad pcc={output_pcc}") + assert passing + else: + assert tt_other_grad is None + diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 4a18128dc8e..e9591d11070 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -421,6 +421,8 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul/moreh_matmul.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward_pybind.cpp @@ -497,6 +499,7 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_sum/moreh_sum.cpp + ) #Split src and python bindings diff --git a/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp b/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp index 1d8d5febffa..732ad255e4e 100644 --- a/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp @@ -16,7 +16,6 @@ namespace py = pybind11; namespace ttnn::operations::examples { void bind_example_operation(py::module& module) { - bind_registered_operation( module, ttnn::prim::example, @@ -25,10 +24,12 @@ void bind_example_operation(py::module& module) { // Add pybind overloads for the C++ APIs that should be exposed to python // There should be no logic here, just a call to `self` with the correct arguments // The overload with `queue_id` argument will be added automatically for primitive operations - // This specific function can be called from python as `ttnn.prim.example(input_tensor)` or `ttnn.prim.example(input_tensor, queue_id=queue_id)` + // This specific function can be called from python as `ttnn.prim.example(input_tensor)` or + // `ttnn.prim.example(input_tensor, queue_id=queue_id)` ttnn::pybind_overload_t{ - [](const decltype(ttnn::prim::example)& self, const ttnn::Tensor& input_tensor) - -> ttnn::Tensor { return self(input_tensor); }, + [](const decltype(ttnn::prim::example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor { + return self(input_tensor); + }, py::arg("input_tensor")}); bind_registered_operation( @@ -39,8 +40,9 @@ void bind_example_operation(py::module& module) { // Add pybind overloads for the C++ APIs that should be exposed to python // There should be no logic here, just a call to `self` with the correct arguments ttnn::pybind_overload_t{ - [](const decltype(ttnn::composite_example)& self, const ttnn::Tensor& input_tensor) - -> ttnn::Tensor { return self(input_tensor); }, + [](const decltype(ttnn::composite_example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor { + return self(input_tensor); + }, py::arg("input_tensor")}); } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp index 8976019c2ac..987c7a5ce04 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp @@ -40,7 +40,7 @@ std::vector find_reduce_dim( // batch dims for (int i = 0; i < rank - 2; ++i) { int idx = rank - 1 - i; - TT_ASSERT(idx >= 0); + TT_ASSERT(idx >= 0, "idx < 0"); if (a_dim[idx] != b_dim[idx]) { dims.push_back(i); log_debug(tt::LogOp, "find_reduce_dim :{} push {} dim", __LINE__, i); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp new file mode 100644 index 00000000000..bf205bbe7c5 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp @@ -0,0 +1,81 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_matmul_backward.hpp" + +#include "ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp" +#include "ttnn/operations/moreh/moreh_matmul/moreh_matmul.hpp" +#include "ttnn/operations/moreh/moreh_sum/moreh_sum.hpp" + +namespace ttnn::operations::moreh::moreh_matmul_backward { +///////////////////////////////////////// +std::vector> MorehMatmulBackward::invoke( + const Tensor& output_grad, + const Tensor& input, + const Tensor& other, + const std::vector& are_required_outputs, + const std::optional& input_grad, + const std::optional& other_grad, + const std::optional& output_mem_config, + const std::optional compute_kernel_config) { + std::vector> outputs(2); + outputs.reserve(2); + + const bool input_requires_grad = are_required_outputs.at(0); + const bool other_requires_grad = are_required_outputs.at(1); + + if (input_requires_grad) { + TT_ASSERT(input_grad.has_value()); + const auto& input_grad_tensor = input_grad.value(); + if (moreh_matmul::is_same_batch_dim(output_grad, input_grad_tensor)) { + const auto& input_grad_shape = input_grad_tensor.get_legacy_shape().without_padding(); + const auto& output_grad_shape = output_grad.get_legacy_shape().without_padding(); + ttnn::moreh_matmul( + output_grad, + other, + false, + true, + input_grad_tensor, + std::nullopt, + output_mem_config, + compute_kernel_config); + } else { + const auto& input_shape = input.get_legacy_shape().without_padding(); + const auto& temp_input_grad = ttnn::moreh_matmul( + output_grad, other, false, true, std::nullopt, std::nullopt, output_mem_config, compute_kernel_config); + auto reduce_dims = + moreh_matmul::find_reduce_dim(temp_input_grad.get_legacy_shape(), input_grad_tensor.get_legacy_shape()); + ttnn::moreh_sum( + temp_input_grad, reduce_dims, true, input_grad_tensor, output_mem_config, compute_kernel_config); + } + outputs[0] = input_grad_tensor; + } + + if (other_requires_grad) { + TT_ASSERT(other_grad.has_value()); + const auto& other_grad_tensor = other_grad.value(); + if (moreh_matmul::is_same_batch_dim(output_grad, other_grad_tensor)) { + ttnn::moreh_matmul( + input, + output_grad, + true, + false, + other_grad_tensor, + std::nullopt, + output_mem_config, + compute_kernel_config); + } else { + const auto& temp_other_grad = ttnn::moreh_matmul( + input, output_grad, true, false, std::nullopt, std::nullopt, output_mem_config, compute_kernel_config); + auto reduce_dims = + moreh_matmul::find_reduce_dim(temp_other_grad.get_legacy_shape(), other_grad_tensor.get_legacy_shape()); + ttnn::moreh_sum( + temp_other_grad, reduce_dims, true, other_grad_tensor, output_mem_config, compute_kernel_config); + } + outputs[1] = other_grad_tensor; + } + + return outputs; +} +} // namespace ttnn::operations::moreh::moreh_matmul_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp new file mode 100644 index 00000000000..e732a0b945d --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp @@ -0,0 +1,26 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once +#include "ttnn/decorators.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" +namespace ttnn::operations::moreh::moreh_matmul_backward { +struct MorehMatmulBackward { + static std::vector> invoke( + const Tensor& output_grad, + const Tensor& input, + const Tensor& other, + const std::vector& are_required_outputs, + const std::optional& input_grad, + const std::optional& other_grad, + const std::optional& output_mem_config, + const std::optional compute_kernel_config); +}; +} // namespace ttnn::operations::moreh::moreh_matmul_backward + +namespace ttnn { +constexpr auto moreh_matmul_backward = ttnn::register_operation< + "ttnn::moreh_matmul_backward", + ttnn::operations::moreh::moreh_matmul_backward::MorehMatmulBackward>(); +} diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp new file mode 100644 index 00000000000..071a0df1a3b --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp @@ -0,0 +1,28 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_matmul_backward_pybind.hpp" + +#include "pybind11/cast.h" +#include "pybind11/decorators.hpp" +#include "ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp" + +namespace ttnn::operations::moreh::moreh_matmul_backward { +void bind_moreh_matmul_backward_operation(py::module& module) { + bind_registered_operation( + module, + ttnn::moreh_matmul_backward, + "Moreh moreh_matmul_backward Operation", + ttnn::pybind_arguments_t{ + py::arg("output_grad"), + py::arg("input_a"), + py::arg("input_b"), + py::kw_only(), + py::arg("are_required_outputs") = std::vector{true, true}, + py::arg("input_a_grad") = std::nullopt, + py::arg("input_b_grad") = std::nullopt, + py::arg("output_mem_config") = std::nullopt, + py::arg("compute_kernel_config") = std::nullopt}); +} +} // namespace ttnn::operations::moreh::moreh_matmul_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.hpp new file mode 100644 index 00000000000..0f3518e966d --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.hpp @@ -0,0 +1,13 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::moreh::moreh_matmul_backward { +void bind_moreh_matmul_backward_operation(py::module& module); +} // namespace ttnn::operations::moreh::moreh_matmul_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp index 65eb1ade7b7..9be06cdfc80 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp @@ -19,6 +19,7 @@ #include "ttnn/operations/moreh/moreh_linear/moreh_linear_pybind.hpp" #include "ttnn/operations/moreh/moreh_linear_backward/moreh_linear_backward_pybind.hpp" #include "ttnn/operations/moreh/moreh_matmul/moreh_matmul_pybind.hpp" +#include "ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.hpp" #include "ttnn/operations/moreh/moreh_mean/moreh_mean_pybind.hpp" #include "ttnn/operations/moreh/moreh_mean_backward/moreh_mean_backward_pybind.hpp" #include "ttnn/operations/moreh/moreh_nll_loss/moreh_nll_loss_pybind.hpp" @@ -49,6 +50,7 @@ void bind_moreh_operations(py::module &module) { moreh_linear_backward::bind_moreh_linear_backward_operation(module); moreh_linear::bind_moreh_linear_operation(module); moreh_matmul::bind_moreh_matmul_operation(module); + moreh_matmul_backward::bind_moreh_matmul_backward_operation(module); moreh_mean_backward::bind_moreh_mean_backward_operation(module); moreh_mean::bind_moreh_mean_operation(module); moreh_nll_loss_backward::bind_moreh_nll_loss_backward_operation(module); diff --git a/ttnn/ttnn/operations/moreh.py b/ttnn/ttnn/operations/moreh.py index 090f09e1ecd..df02441bc28 100644 --- a/ttnn/ttnn/operations/moreh.py +++ b/ttnn/ttnn/operations/moreh.py @@ -20,6 +20,7 @@ logsoftmax = ttnn._ttnn.operations.moreh.moreh_logsoftmax logsoftmax_backward = ttnn._ttnn.operations.moreh.moreh_logsoftmax_backward matmul = ttnn._ttnn.operations.moreh.moreh_matmul +matmul_backward = ttnn._ttnn.operations.moreh.moreh_matmul_backward mean = ttnn._ttnn.operations.moreh.moreh_mean mean_backward = ttnn._ttnn.operations.moreh.moreh_mean_backward nll_loss = ttnn._ttnn.operations.moreh.moreh_nll_loss From dc2c1643bd44656082b58301897ee911c999e4fd Mon Sep 17 00:00:00 2001 From: o2buzzle <76864037+o2buzzle@users.noreply.github.com> Date: Fri, 13 Sep 2024 09:57:01 +0000 Subject: [PATCH 24/58] #12554: fixes and integration with moreh_dot_backward --- .../operations/test_moreh_dot_backward.py | 4 +- .../operations/test_moreh_matmul.py | 77 ++++++++++++++++++- .../moreh_dot_backward_device_operation.cpp | 46 ++++++----- .../moreh_dot_backward_device_operation.hpp | 10 +-- .../moreh_dot_backward_program_factory.cpp | 17 ++-- .../moreh_dot_backward.cpp | 6 +- .../moreh_dot_backward.hpp | 6 +- .../moreh_matmul_backward.cpp | 18 ++++- ttnn/ttnn/operations/moreh.py | 1 + 9 files changed, 141 insertions(+), 44 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_moreh_dot_backward.py b/tests/ttnn/unit_tests/operations/test_moreh_dot_backward.py index 9bfc65aaf3f..2fa86da612c 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_dot_backward.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_dot_backward.py @@ -117,8 +117,8 @@ def test_moreh_matmul_1d_backward(input_shape, requires_grad, device): torch_out.backward(torch_output_grad) # tt matmul backward - ttnn.experimental.operations.primary.moreh_matmul_backward( - tt_output_grad, tt_input, tt_other, (require_input_grad, require_other_grad), tt_input_grad, tt_other_grad + ttnn.operations.moreh.dot_backward( + tt_output_grad, tt_input, tt_other, input_grad=tt_input_grad, other_grad=tt_other_grad ) # test for equivalance diff --git a/tests/ttnn/unit_tests/operations/test_moreh_matmul.py b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py index 96191571dd7..f135ff0c771 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py @@ -306,8 +306,8 @@ def test_moreh_matmul_1d(input_shape, device): logger.debug(f"Output pcc={output_pcc}") assert passing - - + + @pytest.mark.parametrize( "params", ( @@ -381,3 +381,76 @@ def test_moreh_matmul_backward(params, requires_grad, device): else: assert tt_other_grad is None + +@pytest.mark.parametrize( + "input_shape", + ( + [1, 1, 1, 10], # test not mutiple of 32 case + [1, 1, 1, 32], # test single tile + [1, 1, 1, 352], # test multiple tiles + [1, 1, 1, 323], # test multiple tiles, not a multiple of 32 + ), +) +@pytest.mark.parametrize( + "requires_grad", + ( + (True, False), + (False, True), + (True, True), + ), +) +def test_moreh_matmul_1d_backward(input_shape, requires_grad, device): + torch.manual_seed(3072) + require_input_grad, require_other_grad = requires_grad + output_shape = [1, 1, 1, 1] + # get tensors + ( + tt_input, + tt_other, + _, + tt_output_grad, + tt_input_grad, + tt_other_grad, + torch_input, + torch_other, + torch_output_grad, + ) = get_tensors(input_shape, input_shape, output_shape, require_input_grad, require_other_grad, True, device) + + # torch matmul + torch_out = torch.matmul( + torch_input.requires_grad_(require_input_grad), torch_other.requires_grad_(require_other_grad) + ) + torch_out.backward(torch_output_grad) + + # tt matmul backward + ttnn.operations.moreh.matmul_backward( + tt_output_grad, + tt_input, + tt_other, + are_required_outputs=(require_input_grad, require_other_grad), + input_a_grad=tt_input_grad, + input_b_grad=tt_other_grad, + ) + + # test for equivalance + rtol = atol = 0.1 + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + if require_input_grad: + ttcpu_input_grad = tt_input_grad.cpu().to(cpu_layout).unpad_from_tile(input_shape).to_torch() + + passing, output_pcc = comp_allclose_and_pcc( + torch_input.grad, ttcpu_input_grad.reshape(-1), pcc=0.999, rtol=rtol, atol=atol + ) + logger.debug(f"input_grad passing={passing}") + logger.debug(f"input_grad pcc={output_pcc}") + assert passing + + if require_other_grad: + ttcpu_other_grad = tt_other_grad.cpu().to(cpu_layout).unpad_from_tile(input_shape).to_torch() + + passing, output_pcc = comp_allclose_and_pcc( + torch_other.grad, ttcpu_other_grad.reshape(-1), pcc=0.999, rtol=rtol, atol=atol + ) + logger.debug(f"other_grad passing={passing}") + logger.debug(f"other_grad pcc={output_pcc}") + assert passing diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.cpp index 9d3c46cd840..80180448390 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.cpp @@ -10,31 +10,21 @@ namespace ttnn::operations::moreh::moreh_dot_backward { MorehDotBackwardOperation::program_factory_t MorehDotBackwardOperation::select_program_factory( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - // For now we litteraly don't care and return a single factory. Whatever return SingleCore{}; } -void MorehDotBackwardOperation::validate_on_program_cache_miss( - const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - // TT_FATAL("INTENTIONAL: validate_on_program_cache_miss"); -} - -void MorehDotBackwardOperation::validate_on_program_cache_hit( - const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - // TT_FATAL("INTENTIONAL: validate_on_program_cache_hit"); -} - void grad_tensor_validate(const Tensor& tensor, const Tensor& grad_tensor) { - const auto& tensor_shape = tensor.get_shape().value.without_padding(); - const auto& grad_tensor_shape = grad_tensor.get_shape().value.without_padding(); + const auto& tensor_shape = tensor.get_legacy_shape().without_padding(); + const auto& grad_tensor_shape = grad_tensor.get_legacy_shape().without_padding(); TT_FATAL(tensor_shape == grad_tensor_shape, "Tensor shape and grad tensor shape should be the same."); TT_FATAL(grad_tensor.storage_type() == StorageType::DEVICE, "Operands to dot backward need to be on device!"); TT_FATAL(grad_tensor.device() == tensor.device(), "Operands to dot backward need to be on the same device!"); TT_FATAL(grad_tensor.buffer() != nullptr, "Operands to dot backward need to be allocated in buffers on device!"); } -void MorehDotBackwardOperation::validate( - const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { +void validate_tensors( + const MorehDotBackwardOperation::operation_attributes_t& operation_attributes, + const MorehDotBackwardOperation::tensor_args_t& tensor_args) { const auto& output_grad = tensor_args.output_grad; const auto& input = tensor_args.input; const auto& other = tensor_args.other; @@ -57,24 +47,38 @@ void MorehDotBackwardOperation::validate( output_grad.buffer() != nullptr and input.buffer() != nullptr and other.buffer() != nullptr, "Operands to dot backward need to be allocated in buffers on device!"); - const auto& input_grad = tensor_args.input_grad; - const auto& other_grad = tensor_args.other_grad; - if (input_grad) { + // validate optional inputs + const auto& input_grad = tensor_args.output_tensors.at(0); + const auto& other_grad = tensor_args.output_tensors.at(1); + if (input_grad.has_value()) { grad_tensor_validate(input, input_grad.value()); } - if (other_grad) { + + if (other_grad.has_value()) { grad_tensor_validate(other, other_grad.value()); } } +void MorehDotBackwardOperation::validate_on_program_cache_miss( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_tensors(operation_attributes, tensor_args); +} + +void MorehDotBackwardOperation::validate_on_program_cache_hit( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_tensors(operation_attributes, tensor_args); +} + MorehDotBackwardOperation::shape_return_value_t MorehDotBackwardOperation::compute_output_shapes( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + TT_FATAL(false, "This operation is in place, and as such, should not be computing output shapes."); return {}; } MorehDotBackwardOperation::tensor_return_value_t MorehDotBackwardOperation::create_output_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return {}; + TT_FATAL(tensor_args.output_tensors.size() > 0, "Invalid number of output tensors."); + return tensor_args.output_tensors; } std::tuple @@ -87,7 +91,7 @@ MorehDotBackwardOperation::invoke( const std::optional& memory_config) { return { operation_attributes_t{memory_config.value_or(input.memory_config())}, - tensor_args_t{output_grad, input, other, input_grad, other_grad}}; + tensor_args_t{output_grad, input, other, {input_grad, other_grad}}}; } } // namespace ttnn::operations::moreh::moreh_dot_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp index f693416cb41..d7185780040 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp @@ -21,12 +21,13 @@ struct MorehDotBackwardOperation { const Tensor &output_grad; const Tensor &input; const Tensor &other; - const std::optional &input_grad; - const std::optional &other_grad; + + // (o2buzzle): May I present: thanhnguyen's mistake that costed me 3 hours. + const std::vector> output_tensors; }; - using shape_return_value_t = std::vector; - using tensor_return_value_t = std::vector; + using shape_return_value_t = std::vector>; + using tensor_return_value_t = std::vector>; struct SingleCore { struct shared_variables_t { @@ -52,7 +53,6 @@ struct MorehDotBackwardOperation { static program_factory_t select_program_factory(const operation_attributes_t &, const tensor_args_t &); static void validate_on_program_cache_miss(const operation_attributes_t &, const tensor_args_t &); static void validate_on_program_cache_hit(const operation_attributes_t &, const tensor_args_t &); - static void validate(const operation_attributes_t &, const tensor_args_t &); static shape_return_value_t compute_output_shapes(const operation_attributes_t &, const tensor_args_t &); static tensor_return_value_t create_output_tensors(const operation_attributes_t &, const tensor_args_t &); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_program_factory.cpp index 6851b15e07a..5031d25ce46 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_program_factory.cpp @@ -2,6 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 +#include + #include "moreh_dot_backward_device_operation.hpp" #include "tt_metal/detail/util.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" @@ -17,9 +19,8 @@ MorehDotBackwardOperation::SingleCore::cached_program_t MorehDotBackwardOperatio const auto& output_grad = tensor_args.output_grad; const auto& input = tensor_args.input; const auto& other = tensor_args.other; - const auto& input_grad = tensor_args.input_grad; - const auto& other_grad = tensor_args.other_grad; - + const auto& input_grad = tensor_return_value.at(0); + const auto& other_grad = tensor_return_value.at(1); Program program{}; CoreCoord core = {0, 0}; const uint32_t core_num = 1; @@ -149,8 +150,8 @@ void MorehDotBackwardOperation::SingleCore::override_runtime_arguments( const auto& output_grad_buffer = tensor_args.output_grad.buffer(); const auto& input_buffer = tensor_args.input.buffer(); const auto& other_buffer = tensor_args.other.buffer(); - const auto& input_grad_buffer = tensor_return_value.at(0).buffer(); - const auto& other_grad_buffer = tensor_return_value.at(1).buffer(); + const auto input_grad_buffer = tensor_return_value.at(0); + const auto other_grad_buffer = tensor_return_value.at(1); { auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_reader_kernel_id, CoreCoord{0, 0}); @@ -161,8 +162,10 @@ void MorehDotBackwardOperation::SingleCore::override_runtime_arguments( { auto& runtime_args = tt::tt_metal::GetRuntimeArgs(program, unary_writer_kernel_id, CoreCoord{0, 0}); - runtime_args[2] = input_grad_buffer->address(); - runtime_args[3] = other_grad_buffer->address(); + if (input_grad_buffer.has_value()) + runtime_args[2] = input_grad_buffer.value().buffer()->address(); + if (other_grad_buffer.has_value()) + runtime_args[3] = other_grad_buffer.value().buffer()->address(); } } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.cpp index d03bfc4694f..4ce365068f0 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.cpp @@ -7,12 +7,12 @@ #include "ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp" namespace ttnn::operations::moreh::moreh_dot_backward { -std::vector MorehDotBackward::invoke( +std::vector> MorehDotBackward::invoke( const Tensor &output_grad, const Tensor &input, const Tensor &other, - std::optional input_grad, - std::optional other_grad, + const std::optional &input_grad, + const std::optional &other_grad, const std::optional &mem_config) { return ttnn::prim::moreh_dot_backward(output_grad, input, other, input_grad, other_grad, mem_config); } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.hpp index 2514db9873d..1dbf6129f56 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.hpp @@ -6,12 +6,12 @@ #include "ttnn/decorators.hpp" namespace ttnn::operations::moreh::moreh_dot_backward { struct MorehDotBackward { - static std::vector invoke( + static std::vector> invoke( const Tensor &output_grad, const Tensor &input, const Tensor &other, - std::optional input_grad, - std::optional other_grad, + const std::optional &input_grad, + const std::optional &other_grad, const std::optional &mem_config); }; } // namespace ttnn::operations::moreh::moreh_dot_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp index bf205bbe7c5..0106deb42bc 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp @@ -4,12 +4,24 @@ #include "moreh_matmul_backward.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" +#include "ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.hpp" #include "ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp" #include "ttnn/operations/moreh/moreh_matmul/moreh_matmul.hpp" #include "ttnn/operations/moreh/moreh_sum/moreh_sum.hpp" namespace ttnn::operations::moreh::moreh_matmul_backward { -///////////////////////////////////////// + +inline bool is_dot_backward(const Tensor& output_grad, const Tensor& input, const Tensor& other) { + // TODO: non-4d support for dot backward. + if (output_grad.get_legacy_shape().rank() != 4 || input.get_legacy_shape().rank() != 4 || + other.get_legacy_shape().rank() != 4) { + return false; + } + return tt::operations::primary::is_scalar(output_grad) && tt::operations::primary::is_1d_tensor(input) && + tt::operations::primary::is_1d_tensor(other) && tt::operations::primary::is_same_shape(input, other); +} + std::vector> MorehMatmulBackward::invoke( const Tensor& output_grad, const Tensor& input, @@ -19,6 +31,10 @@ std::vector> MorehMatmulBackward::invoke( const std::optional& other_grad, const std::optional& output_mem_config, const std::optional compute_kernel_config) { + if (is_dot_backward(output_grad, input, other)) { + return ttnn::moreh_dot_backward(output_grad, input, other, input_grad, other_grad, output_mem_config); + } + std::vector> outputs(2); outputs.reserve(2); diff --git a/ttnn/ttnn/operations/moreh.py b/ttnn/ttnn/operations/moreh.py index df02441bc28..b4eff014932 100644 --- a/ttnn/ttnn/operations/moreh.py +++ b/ttnn/ttnn/operations/moreh.py @@ -10,6 +10,7 @@ bmm = ttnn._ttnn.operations.moreh.moreh_bmm bmm_backward = ttnn._ttnn.operations.moreh.moreh_bmm_backward dot = ttnn._ttnn.operations.moreh.moreh_dot +dot_backward = ttnn._ttnn.operations.moreh.moreh_dot_backward getitem = ttnn._ttnn.operations.moreh.moreh_getitem group_norm = ttnn._ttnn.operations.moreh.moreh_group_norm group_norm_backward = ttnn._ttnn.operations.moreh.moreh_group_norm_backward From 21f770d17353bb87610dd7a6300a502638f90b87 Mon Sep 17 00:00:00 2001 From: o2buzzle <76864037+o2buzzle@users.noreply.github.com> Date: Wed, 9 Oct 2024 02:33:07 +0000 Subject: [PATCH 25/58] #12254: cleanup refactored code --- .../operations/test_moreh_matmul.py | 17 +++--- .../examples/example/example_pybind.hpp | 14 ++--- .../moreh_dot_backward_program_factory.cpp | 6 +- .../moreh_dot_backward.cpp | 23 +++++++ .../moreh_dot_backward.hpp | 16 ++++- .../device/moreh_matmul_device_operation.cpp | 2 +- .../device/moreh_matmul_device_operation.hpp | 5 +- .../device/moreh_matmul_program_factory.cpp | 2 +- .../moreh_matmul_backward.cpp | 61 +++++++++++-------- .../moreh_matmul_backward.hpp | 17 +++++- .../moreh_matmul_backward_pybind.cpp | 2 +- 11 files changed, 112 insertions(+), 53 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_moreh_matmul.py b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py index f135ff0c771..7c2cdf97447 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py @@ -423,14 +423,15 @@ def test_moreh_matmul_1d_backward(input_shape, requires_grad, device): torch_out.backward(torch_output_grad) # tt matmul backward - ttnn.operations.moreh.matmul_backward( - tt_output_grad, - tt_input, - tt_other, - are_required_outputs=(require_input_grad, require_other_grad), - input_a_grad=tt_input_grad, - input_b_grad=tt_other_grad, - ) + for _ in range(2): + ttnn.operations.moreh.matmul_backward( + tt_output_grad, + tt_input, + tt_other, + are_required_outputs=(require_input_grad, require_other_grad), + input_a_grad=tt_input_grad, + input_b_grad=tt_other_grad, + ) # test for equivalance rtol = atol = 0.1 diff --git a/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp b/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp index 732ad255e4e..1d8d5febffa 100644 --- a/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/examples/example/example_pybind.hpp @@ -16,6 +16,7 @@ namespace py = pybind11; namespace ttnn::operations::examples { void bind_example_operation(py::module& module) { + bind_registered_operation( module, ttnn::prim::example, @@ -24,12 +25,10 @@ void bind_example_operation(py::module& module) { // Add pybind overloads for the C++ APIs that should be exposed to python // There should be no logic here, just a call to `self` with the correct arguments // The overload with `queue_id` argument will be added automatically for primitive operations - // This specific function can be called from python as `ttnn.prim.example(input_tensor)` or - // `ttnn.prim.example(input_tensor, queue_id=queue_id)` + // This specific function can be called from python as `ttnn.prim.example(input_tensor)` or `ttnn.prim.example(input_tensor, queue_id=queue_id)` ttnn::pybind_overload_t{ - [](const decltype(ttnn::prim::example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor { - return self(input_tensor); - }, + [](const decltype(ttnn::prim::example)& self, const ttnn::Tensor& input_tensor) + -> ttnn::Tensor { return self(input_tensor); }, py::arg("input_tensor")}); bind_registered_operation( @@ -40,9 +39,8 @@ void bind_example_operation(py::module& module) { // Add pybind overloads for the C++ APIs that should be exposed to python // There should be no logic here, just a call to `self` with the correct arguments ttnn::pybind_overload_t{ - [](const decltype(ttnn::composite_example)& self, const ttnn::Tensor& input_tensor) -> ttnn::Tensor { - return self(input_tensor); - }, + [](const decltype(ttnn::composite_example)& self, const ttnn::Tensor& input_tensor) + -> ttnn::Tensor { return self(input_tensor); }, py::arg("input_tensor")}); } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_program_factory.cpp index 5031d25ce46..150f9d7011a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_program_factory.cpp @@ -96,9 +96,9 @@ MorehDotBackwardOperation::SingleCore::cached_program_t MorehDotBackwardOperatio }; const auto reader_kernel_file = - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/kernels/reader_moreh_dot_backward.cpp"; + "ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/kernels/reader_moreh_dot_backward.cpp"; const auto writer_kernel_file = - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/kernels/writer_moreh_dot_backward.cpp"; + "ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/kernels/writer_moreh_dot_backward.cpp"; const auto reader_kernel_id = tt::operations::primary::CreateReadKernel(program, reader_kernel_file, core, reader_compile_time_args); @@ -109,7 +109,7 @@ MorehDotBackwardOperation::SingleCore::cached_program_t MorehDotBackwardOperatio std::map compute_defines; const auto compute_kernel_file = - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/kernels/moreh_dot_backward.cpp"; + "ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/kernels/moreh_dot_backward.cpp"; const auto compute_kernel_id = tt::operations::primary::CreateComputeKernel( program, compute_kernel_file, {core, core_num, compute_kernel_args}, compute_defines); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.cpp index 4ce365068f0..282ba0dc7ed 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.cpp @@ -16,4 +16,27 @@ std::vector> MorehDotBackward::invoke( const std::optional &mem_config) { return ttnn::prim::moreh_dot_backward(output_grad, input, other, input_grad, other_grad, mem_config); } + +std::vector MorehDotBackward::create_async_output_tensors( + const std::vector &input_tensors, const std::vector> &optional_inputs) { + auto output_grad = input_tensors.at(0); + auto input = input_tensors.at(1); + auto other = input_tensors.at(2); + + return { + Tensor(operation::get_workers_for_op_output({output_grad, input, other})), + Tensor(operation::get_workers_for_op_output({output_grad, input, other})), + }; +} + +std::vector MorehDotBackward::create_async_return_flag( + const Tensor &output_grad, + const Tensor &input, + const Tensor &other, + const std::optional &input_grad, + const std::optional &other_grad, + const std::optional &mem_config) { + return {input_grad.has_value(), other_grad.has_value()}; +} + } // namespace ttnn::operations::moreh::moreh_dot_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.hpp index 1dbf6129f56..15c009bc70e 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward.hpp @@ -13,10 +13,22 @@ struct MorehDotBackward { const std::optional &input_grad, const std::optional &other_grad, const std::optional &mem_config); + + static std::vector create_async_output_tensors( + const std::vector &input_tensors, const std::vector> &optional_inputs); + + static std::vector create_async_return_flag( + const Tensor &output_grad, + const Tensor &input, + const Tensor &other, + const std::optional &input_grad, + const std::optional &other_grad, + const std::optional &mem_config); }; } // namespace ttnn::operations::moreh::moreh_dot_backward namespace ttnn { -constexpr auto moreh_dot_backward = ttnn:: - register_operation<"ttnn::moreh_dot_backward", ttnn::operations::moreh::moreh_dot_backward::MorehDotBackward>(); +constexpr auto moreh_dot_backward = ttnn::register_operation_with_auto_launch_op< + "ttnn::moreh_dot_backward", + ttnn::operations::moreh::moreh_dot_backward::MorehDotBackward>(); } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp index ff0e978cd6e..4116979b684 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.cpp @@ -237,7 +237,7 @@ MorehMatmulOperation::invoke( transpose_input, transpose_other, output_memory_config.value_or(input.memory_config()), - compute_kernel_config}, + init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config)}, MorehMatmulOperation::tensor_args_t{input, other, output, bias}}; } } // namespace ttnn::operations::moreh::moreh_matmul diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp index 1ce66393aed..7e3734fafe5 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp @@ -17,7 +17,7 @@ struct MorehMatmulOperation { bool transpose_other; const MemoryConfig output_memory_config; - const std::optional compute_kernel_config; + const DeviceComputeKernelConfig compute_kernel_config; }; struct tensor_args_t { @@ -73,7 +73,8 @@ struct MorehMatmulOperation { }; void get_tensor_dim(std::vector& dim, const tt::tt_metal::LegacyShape& shape); -std::vector find_reduce_dim(const tt::tt_metal::LegacyShape& a_shape, const tt::tt_metal::LegacyShape& b_shape); +std::vector find_reduce_dim( + const tt::tt_metal::LegacyShape& a_shape, const tt::tt_metal::LegacyShape& b_shape); bool is_same_batch_dim(const Tensor& tensor_a, const Tensor& tensor_b); } // namespace ttnn::operations::moreh::moreh_matmul diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp index 987c7a5ce04..1566ba7a3c9 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_program_factory.cpp @@ -40,7 +40,7 @@ std::vector find_reduce_dim( // batch dims for (int i = 0; i < rank - 2; ++i) { int idx = rank - 1 - i; - TT_ASSERT(idx >= 0, "idx < 0"); + TT_FATAL(idx >= 0, "idx < 0"); if (a_dim[idx] != b_dim[idx]) { dims.push_back(i); log_debug(tt::LogOp, "find_reduce_dim :{} push {} dim", __LINE__, i); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp index 0106deb42bc..65ac0cb6f36 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.cpp @@ -9,6 +9,7 @@ #include "ttnn/operations/moreh/moreh_matmul/device/moreh_matmul_device_operation.hpp" #include "ttnn/operations/moreh/moreh_matmul/moreh_matmul.hpp" #include "ttnn/operations/moreh/moreh_sum/moreh_sum.hpp" +#include "ttnn/run_operation.hpp" namespace ttnn::operations::moreh::moreh_matmul_backward { @@ -29,69 +30,79 @@ std::vector> MorehMatmulBackward::invoke( const std::vector& are_required_outputs, const std::optional& input_grad, const std::optional& other_grad, - const std::optional& output_mem_config, + const std::optional& memory_config, const std::optional compute_kernel_config) { if (is_dot_backward(output_grad, input, other)) { - return ttnn::moreh_dot_backward(output_grad, input, other, input_grad, other_grad, output_mem_config); + return ttnn::moreh_dot_backward(output_grad, input, other, input_grad, other_grad, memory_config); } std::vector> outputs(2); - outputs.reserve(2); const bool input_requires_grad = are_required_outputs.at(0); const bool other_requires_grad = are_required_outputs.at(1); if (input_requires_grad) { - TT_ASSERT(input_grad.has_value()); + TT_FATAL(input_grad.has_value(), "Input gradient is marked required but not provided."); const auto& input_grad_tensor = input_grad.value(); if (moreh_matmul::is_same_batch_dim(output_grad, input_grad_tensor)) { const auto& input_grad_shape = input_grad_tensor.get_legacy_shape().without_padding(); const auto& output_grad_shape = output_grad.get_legacy_shape().without_padding(); ttnn::moreh_matmul( - output_grad, - other, - false, - true, - input_grad_tensor, - std::nullopt, - output_mem_config, - compute_kernel_config); + output_grad, other, false, true, input_grad_tensor, std::nullopt, memory_config, compute_kernel_config); } else { const auto& input_shape = input.get_legacy_shape().without_padding(); const auto& temp_input_grad = ttnn::moreh_matmul( - output_grad, other, false, true, std::nullopt, std::nullopt, output_mem_config, compute_kernel_config); + output_grad, other, false, true, std::nullopt, std::nullopt, memory_config, compute_kernel_config); auto reduce_dims = moreh_matmul::find_reduce_dim(temp_input_grad.get_legacy_shape(), input_grad_tensor.get_legacy_shape()); ttnn::moreh_sum( - temp_input_grad, reduce_dims, true, input_grad_tensor, output_mem_config, compute_kernel_config); + temp_input_grad, reduce_dims, true, input_grad_tensor, memory_config, compute_kernel_config); } outputs[0] = input_grad_tensor; } if (other_requires_grad) { - TT_ASSERT(other_grad.has_value()); + TT_FATAL(other_grad.has_value(), "Other gradient is marked required but not provided."); const auto& other_grad_tensor = other_grad.value(); if (moreh_matmul::is_same_batch_dim(output_grad, other_grad_tensor)) { ttnn::moreh_matmul( - input, - output_grad, - true, - false, - other_grad_tensor, - std::nullopt, - output_mem_config, - compute_kernel_config); + input, output_grad, true, false, other_grad_tensor, std::nullopt, memory_config, compute_kernel_config); } else { const auto& temp_other_grad = ttnn::moreh_matmul( - input, output_grad, true, false, std::nullopt, std::nullopt, output_mem_config, compute_kernel_config); + input, output_grad, true, false, std::nullopt, std::nullopt, memory_config, compute_kernel_config); auto reduce_dims = moreh_matmul::find_reduce_dim(temp_other_grad.get_legacy_shape(), other_grad_tensor.get_legacy_shape()); ttnn::moreh_sum( - temp_other_grad, reduce_dims, true, other_grad_tensor, output_mem_config, compute_kernel_config); + temp_other_grad, reduce_dims, true, other_grad_tensor, memory_config, compute_kernel_config); } outputs[1] = other_grad_tensor; } return outputs; } + +std::vector MorehMatmulBackward::create_async_output_tensors( + const std::vector& input_tensors, const std::vector>& optional_inputs) { + const auto& output_grad = input_tensors.at(0); + const auto& input = input_tensors.at(1); + const auto& other = input_tensors.at(2); + + return { + Tensor(operation::get_workers_for_op_output({output_grad, input, other})), + Tensor(operation::get_workers_for_op_output({output_grad, input, other})), + }; +} + +std::vector MorehMatmulBackward::create_async_return_flag( + const Tensor& output_grad, + const Tensor& input, + const Tensor& other, + const std::vector& are_required_outputs, + const std::optional& input_grad, + const std::optional& other_grad, + const std::optional& memory_config, + const std::optional compute_kernel_config) { + return {are_required_outputs.at(0), are_required_outputs.at(1)}; +} + } // namespace ttnn::operations::moreh::moreh_matmul_backward diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp index e732a0b945d..e96a5110c89 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward.hpp @@ -14,13 +14,26 @@ struct MorehMatmulBackward { const std::vector& are_required_outputs, const std::optional& input_grad, const std::optional& other_grad, - const std::optional& output_mem_config, + const std::optional& memory_config, + const std::optional compute_kernel_config); + + static std::vector create_async_output_tensors( + const std::vector& input_tensors, const std::vector>& optional_inputs); + + static std::vector create_async_return_flag( + const Tensor& output_grad, + const Tensor& input, + const Tensor& other, + const std::vector& are_required_outputs, + const std::optional& input_grad, + const std::optional& other_grad, + const std::optional& memory_config, const std::optional compute_kernel_config); }; } // namespace ttnn::operations::moreh::moreh_matmul_backward namespace ttnn { -constexpr auto moreh_matmul_backward = ttnn::register_operation< +constexpr auto moreh_matmul_backward = ttnn::register_operation_with_auto_launch_op< "ttnn::moreh_matmul_backward", ttnn::operations::moreh::moreh_matmul_backward::MorehMatmulBackward>(); } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp index 071a0df1a3b..bb181a49699 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_matmul_backward/moreh_matmul_backward_pybind.cpp @@ -22,7 +22,7 @@ void bind_moreh_matmul_backward_operation(py::module& module) { py::arg("are_required_outputs") = std::vector{true, true}, py::arg("input_a_grad") = std::nullopt, py::arg("input_b_grad") = std::nullopt, - py::arg("output_mem_config") = std::nullopt, + py::arg("memory_config") = std::nullopt, py::arg("compute_kernel_config") = std::nullopt}); } } // namespace ttnn::operations::moreh::moreh_matmul_backward From a62720d8a27e9e96ca66a1cf08774d4a4d35e223 Mon Sep 17 00:00:00 2001 From: Anasuya G Nair Date: Wed, 9 Oct 2024 12:09:54 +0530 Subject: [PATCH 26/58] #10347: Update documentation and sweep config for rsub and rdiv (#13501) #10347: Update documentation and add golden function --- .../sweeps/eltwise/unary/rdiv/rdiv.py | 14 ++++++-- .../sweeps/eltwise/unary/rsub/rsub.py | 14 ++++++-- .../operations/eltwise/unary/unary_pybind.hpp | 36 ++++++++++++++++--- 3 files changed, 56 insertions(+), 8 deletions(-) diff --git a/tests/sweep_framework/sweeps/eltwise/unary/rdiv/rdiv.py b/tests/sweep_framework/sweeps/eltwise/unary/rdiv/rdiv.py index 10bf13e50aa..138fe8929f7 100644 --- a/tests/sweep_framework/sweeps/eltwise/unary/rdiv/rdiv.py +++ b/tests/sweep_framework/sweeps/eltwise/unary/rdiv/rdiv.py @@ -28,7 +28,7 @@ "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 64), "exclude_range": [[-1, 1]], "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], - "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], }, @@ -43,6 +43,15 @@ } +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT: + return True, "Row Major layout is not supported" + return False, None + + # This is the run instructions for the test, defined by the developer. # The run function must take the above-defined parameters as inputs. # The runner will call this run function with each test vector, and the returned results from this function will be stored. @@ -66,7 +75,8 @@ def run( factor = torch.tensor(1, dtype=torch.bfloat16).uniform_(0.1, 10.0).item() - torch_output_tensor = torch.div(factor, torch_input_tensor_a) + golden_function = ttnn.get_golden_function(ttnn.rdiv) + torch_output_tensor = golden_function(torch_input_tensor_a, factor) input_tensor_a = ttnn.from_torch( torch_input_tensor_a, diff --git a/tests/sweep_framework/sweeps/eltwise/unary/rsub/rsub.py b/tests/sweep_framework/sweeps/eltwise/unary/rsub/rsub.py index ab5b225d7ad..de02011284e 100644 --- a/tests/sweep_framework/sweeps/eltwise/unary/rsub/rsub.py +++ b/tests/sweep_framework/sweeps/eltwise/unary/rsub/rsub.py @@ -28,13 +28,22 @@ "nightly": { "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 64), "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], - "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], }, } +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + if test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT: + return True, "Row Major layout is not supported" + return False, None + + # This is the run instructions for the test, defined by the developer. # The run function must take the above-defined parameters as inputs. # The runner will call this run function with each test vector, and the returned results from this function will be stored. @@ -57,7 +66,8 @@ def run( factor = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item() - torch_output_tensor = torch.sub(factor, torch_input_tensor_a) + golden_function = ttnn.get_golden_function(ttnn.rsub) + torch_output_tensor = golden_function(torch_input_tensor_a, factor) input_tensor_a = ttnn.from_torch( torch_input_tensor_a, diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp index 031c92beb62..763767c0bb7 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp @@ -424,7 +424,7 @@ void bind_unary_operation_with_dim_parameter( } template -void bind_unary_rdiv(py::module& module, const unary_operation_t& operation, const std::string& parameter_name_a, const std::string& parameter_a_doc, const std::string& parameter_name_b, const std::string& parameter_b_doc, const std::string parameter_b_value, const std::string& description) { +void bind_unary_rdiv(py::module& module, const unary_operation_t& operation, const std::string& parameter_name_a, const std::string& parameter_a_doc, const std::string& parameter_name_b, const std::string& parameter_b_doc, const std::string parameter_b_value, const std::string& description, const std::string& note = " ") { auto doc = fmt::format( R"doc( {7} @@ -442,6 +442,9 @@ void bind_unary_rdiv(py::module& module, const unary_operation_t& operation, con Returns: ttnn.Tensor: the output tensor. + Note: + {8} + Example: >>> tensor = ttnn.from_torch(torch.tensor((1, 2), dtype=torch.bfloat16), device=device) >>> output = {1}(tensor, {2}, {4} = {6}) @@ -453,7 +456,8 @@ void bind_unary_rdiv(py::module& module, const unary_operation_t& operation, con parameter_name_b, parameter_b_doc, parameter_b_value, - description); + description, + note); bind_registered_operation( module, @@ -1431,7 +1435,19 @@ void py_module(py::module& module) { +----------------------------+---------------------------------+-------------------+ )doc"); - detail::bind_unary_operation_with_float_parameter(module, ttnn::rsub, "value", "subtrahent value which is actually calculated as minuend", "Returns tensor with respective elements of the input tensor subtracted from the value."); + detail::bind_unary_operation_with_float_parameter(module, ttnn::rsub, "value", "subtrahent value which is actually calculated as minuend", "Returns tensor with respective elements of the input tensor subtracted from the value.", + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + System memory is not supported. + + )doc"); + detail::bind_unary_operation_with_float_parameter(module, ttnn::heaviside, "value", "The value parameter for the Heaviside function", ""); detail::bind_unary_operation_with_float_parameter(module, ttnn::leaky_relu, "slope", "The slope parameter for the Leaky ReLU function", ""); detail::bind_unary_operation_with_float_parameter(module, ttnn::relu_max, "upper_limit", "The max value for ReLU function", "This function caps off the input to a max value and a min value of 0"); @@ -1586,7 +1602,19 @@ void py_module(py::module& module) { Input tensor must have BFLOAT16 data type. - Output tensor will have BFLOAT16 data type.)doc"); + Output tensor will have BFLOAT16 data type.)doc", + + R"doc(Supported dtypes, layouts, and ranks: + + +----------------------------+---------------------------------+-------------------+ + | Dtypes | Layouts | Ranks | + +----------------------------+---------------------------------+-------------------+ + | BFLOAT16, BFLOAT8_B | TILE | 2, 3, 4 | + +----------------------------+---------------------------------+-------------------+ + + System memory is not supported. + + )doc"); } From 4b0cd5e5a65d52c0653471a53293628f65674322 Mon Sep 17 00:00:00 2001 From: Andrija Malbasa Date: Wed, 9 Oct 2024 09:18:14 +0200 Subject: [PATCH 27/58] Add sweeps for unary_backward ops (#13539) * #11512: Add sweeps for backward_ops: clamp_bw, hardtanh_bw, mul_bw, softplus_bw, threshold_bw * #11512: Minor fix * #11512: Added mul_bw sweep, modified div sweep * Update ttnn-run-sweeps.yaml --- .github/workflows/ttnn-run-sweeps.yaml | 6 + .../sweeps/eltwise/binary/div/div.py | 15 +-- .../sweeps/eltwise/unary/clamp/clamp.py | 6 +- .../unary_backward/clamp_bw/clamp_bw.py | 105 ++++++++++++++++ .../eltwise/unary_backward/div_bw/div_bw.py | 111 +++++++++++++++++ .../unary_backward/hardtanh_bw/hardtanh_bw.py | 96 +++++++++++++++ .../eltwise/unary_backward/mul_bw/mul_bw.py | 113 ++++++++++++++++++ .../unary_backward/softplus_bw/softplus_bw.py | 104 ++++++++++++++++ .../threshold_bw/threshold_bw.py | 101 ++++++++++++++++ 9 files changed, 647 insertions(+), 10 deletions(-) create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_backward/clamp_bw/clamp_bw.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_backward/div_bw/div_bw.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_backward/hardtanh_bw/hardtanh_bw.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_backward/mul_bw/mul_bw.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_backward/softplus_bw/softplus_bw.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary_backward/threshold_bw/threshold_bw.py diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml index 3e5b58bddc4..8513a2b9a82 100644 --- a/.github/workflows/ttnn-run-sweeps.yaml +++ b/.github/workflows/ttnn-run-sweeps.yaml @@ -75,6 +75,12 @@ on: - eltwise.unary.relu_min.relu_min - eltwise.unary.relu_max.relu_max - eltwise.unary.softplus.softplus + - eltwise.unary_backward.clamp_bw.clamp_bw + - eltwise.unary_backward.hardtanh_bw.hardtanh_bw + - eltwise.unary_backward.mul_bw.mul_bw + - eltwise.unary_backward.softplus_bw.softplus_bw + - eltwise.unary_backward.threshold_bw.threshold_bw + - eltwise.unary_backward.div_bw.div_bw - eltwise.unary_backward.log_bw.log_bw - eltwise.unary_backward.relu6_bw.relu6_bw - eltwise.binary.subtract.subtract diff --git a/tests/sweep_framework/sweeps/eltwise/binary/div/div.py b/tests/sweep_framework/sweeps/eltwise/binary/div/div.py index c45a8f4375f..0a8605b5de0 100644 --- a/tests/sweep_framework/sweeps/eltwise/binary/div/div.py +++ b/tests/sweep_framework/sweeps/eltwise/binary/div/div.py @@ -26,10 +26,11 @@ # Developers can create their own generator functions and pass them to the parameters as inputs. parameters = { "nightly": { - "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) - + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) - + gen_shapes([32, 32], [256, 256], [32, 32], 16), + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 8) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 8) + + gen_shapes([32, 32], [256, 256], [32, 32], 8), "accurate_mode": [True, False], + "round_mode": ["None", "floor", "trunc"], "round_mode": [None], "input_a_dtype": [ttnn.bfloat16], "input_b_dtype": [ttnn.bfloat16], @@ -40,11 +41,11 @@ "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], }, "xfail": { - "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) - + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) - + gen_shapes([32, 32], [256, 256], [32, 32], 16), + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 4) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 4) + + gen_shapes([32, 32], [256, 256], [32, 32], 4), "accurate_mode": [True, False], - "round_mode": [None], + "round_mode": ["None", "floor", "trunc"], "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], "input_a_layout": [ttnn.TILE_LAYOUT], diff --git a/tests/sweep_framework/sweeps/eltwise/unary/clamp/clamp.py b/tests/sweep_framework/sweeps/eltwise/unary/clamp/clamp.py index 7b63fcd562d..0261f6b6758 100644 --- a/tests/sweep_framework/sweeps/eltwise/unary/clamp/clamp.py +++ b/tests/sweep_framework/sweeps/eltwise/unary/clamp/clamp.py @@ -27,7 +27,7 @@ parameters = { "nightly": { "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 32), - "mode": [None, "min", "max"], + "mode": ["both", "min", "max"], "input_a_dtype": [ttnn.bfloat16], "input_a_layout": [ttnn.TILE_LAYOUT], "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], @@ -60,9 +60,9 @@ def run( low, high = gen_low_high_scalars() if mode == "min": - low, high = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item(), None + high = None elif mode == "max": - low, high = None, torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item() + low = None torch_output_tensor = torch.clamp(torch_input_tensor_a, low, high) diff --git a/tests/sweep_framework/sweeps/eltwise/unary_backward/clamp_bw/clamp_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_backward/clamp_bw/clamp_bw.py new file mode 100644 index 00000000000..9d9dcfda635 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_backward/clamp_bw/clamp_bw.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes, gen_low_high_scalars +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 8) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 8) + + gen_shapes([32, 32], [256, 256], [32, 32], 8), + "mode": ["both", "min", "max"], + "grad_dtype": [ttnn.bfloat16], + "input_a_dtype": [ttnn.bfloat16], + "grad_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT], + "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + mode, + grad_dtype, + input_a_dtype, + grad_layout, + input_a_layout, + grad_memory_config, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_grad_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), grad_dtype + )(input_shape) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_a.requires_grad = True + torch_input_tensor_a.retain_grad() + + low, high = gen_low_high_scalars() + + if mode == "min": + high = None + elif mode == "max": + low = None + + intermediate_result = torch.clamp(torch_input_tensor_a, low, high) + intermediate_result.backward(gradient=torch_grad_tensor) + torch_output_tensor = torch_input_tensor_a.grad + + grad_tensor = ttnn.from_torch( + torch_grad_tensor, + dtype=grad_dtype, + layout=grad_layout, + device=device, + memory_config=grad_memory_config, + ) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a.detach().clone(), + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.clamp_bw(grad_tensor, input_tensor_a, min=low, max=high, memory_config=output_memory_config)[0] + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary_backward/div_bw/div_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_backward/div_bw/div_bw.py new file mode 100644 index 00000000000..6e45912b22e --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_backward/div_bw/div_bw.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 8) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 8) + + gen_shapes([32, 32], [256, 256], [32, 32], 8), + "grad_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "grad_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT], + "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# TO-DO: Create an issue on this, since these constrictions are not mentioned in the documentation +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + # In the documentation stands that this op supports ROW_MAJOR + # TO-DO: create an issue on this matter + if test_vector["grad_layout"] == ttnn.ROW_MAJOR_LAYOUT or test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT: + return True, "Inputs to eltwise binary must be tilized" + if test_vector["input_a_dtype"] == ttnn.bfloat8_b and test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT: + return True, "bfloat8_b is only supported on tiled layout" + if test_vector["grad_dtype"] == ttnn.bfloat8_b and test_vector["grad_layout"] == ttnn.ROW_MAJOR_LAYOUT: + return True, "bfloat8_b is only supported on tiled layout" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + grad_dtype, + input_a_dtype, + grad_layout, + input_a_layout, + grad_memory_config, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_grad_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), grad_dtype + )(input_shape) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_a.requires_grad = True + torch_input_tensor_a.retain_grad() + + scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item() + + itermediate_result = torch.div(torch_input_tensor_a, scalar) + itermediate_result.backward(gradient=torch_grad_tensor) + torch_output_tensor = torch_input_tensor_a.grad + + grad_tensor = ttnn.from_torch( + torch_grad_tensor, + dtype=grad_dtype, + layout=grad_layout, + device=device, + memory_config=grad_memory_config, + ) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a.detach().clone(), + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.div_bw(grad_tensor, input_tensor_a, scalar, memory_config=output_memory_config)[0] + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary_backward/hardtanh_bw/hardtanh_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_backward/hardtanh_bw/hardtanh_bw.py new file mode 100644 index 00000000000..0b52a8f1970 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_backward/hardtanh_bw/hardtanh_bw.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 16), + "grad_dtype": [ttnn.bfloat16], + "input_a_dtype": [ttnn.bfloat16], + "grad_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT], + "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + grad_dtype, + input_a_dtype, + grad_layout, + input_a_layout, + grad_memory_config, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_grad_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), grad_dtype + )(input_shape) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_a.requires_grad = True + torch_input_tensor_a.retain_grad() + + intermediate_result = torch.nn.functional.hardtanh(torch_input_tensor_a) + intermediate_result.backward(gradient=torch_grad_tensor) + torch_output_tensor = torch_input_tensor_a.grad + + grad_tensor = ttnn.from_torch( + torch_grad_tensor, + dtype=grad_dtype, + layout=grad_layout, + device=device, + memory_config=grad_memory_config, + ) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a.detach().clone(), + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.hardtanh_bw(grad_tensor, input_tensor_a, memory_config=output_memory_config)[0] + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary_backward/mul_bw/mul_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_backward/mul_bw/mul_bw.py new file mode 100644 index 00000000000..b8d8daf2462 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_backward/mul_bw/mul_bw.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes, tensor_to_dtype +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 4) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 4) + + gen_shapes([32, 32], [256, 256], [32, 32], 4), + "grad_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "grad_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT, ttnn.ROW_MAJOR_LAYOUT], + "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# Invalidate vector is called during the generation phase where each vector will be passed in. +# If invalidated, the vector will still be stored but will be skipped. +# Returns False, None if the vector is valid, and True, str with a reason for invalidation if it is invalid. +def invalidate_vector(test_vector) -> Tuple[bool, Optional[str]]: + # In the documentation stands that this op supports ROW_MAJOR + # TO-DO: create an issue on this matter + if test_vector["grad_layout"] == ttnn.ROW_MAJOR_LAYOUT or test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT: + return True, "Inputs to eltwise binary must be tilized" + if test_vector["input_a_layout"] == ttnn.ROW_MAJOR_LAYOUT and test_vector["input_a_dtype"] == ttnn.bfloat8_b: + return True, "bfloat8_b is not supported on row major layout" + if test_vector["grad_layout"] == ttnn.ROW_MAJOR_LAYOUT and test_vector["grad_dtype"] == ttnn.bfloat8_b: + return True, "bfloat8_b is not supported on row major layout" + return False, None + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + grad_dtype, + input_a_dtype, + grad_layout, + input_a_layout, + grad_memory_config, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_grad_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), grad_dtype + )(input_shape) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_a.requires_grad = True + torch_input_tensor_a.retain_grad() + + scalar = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item() + + itermediate_result = torch.mul(torch_input_tensor_a, scalar) + itermediate_result.backward(gradient=torch_grad_tensor) + torch_output_tensor = torch_input_tensor_a.grad + + grad_tensor = ttnn.from_torch( + torch_grad_tensor, + dtype=grad_dtype, + layout=grad_layout, + device=device, + memory_config=grad_memory_config, + ) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a.detach().clone(), + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.mul_bw(grad_tensor, input_tensor_a, scalar, memory_config=output_memory_config)[0] + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary_backward/softplus_bw/softplus_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_backward/softplus_bw/softplus_bw.py new file mode 100644 index 00000000000..cda40ebe077 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_backward/softplus_bw/softplus_bw.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "xfail": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 16), + "grad_dtype": [ttnn.bfloat16], + "input_a_dtype": [ttnn.bfloat16], + "grad_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT], + "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + grad_dtype, + input_a_dtype, + grad_layout, + input_a_layout, + grad_memory_config, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_grad_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), grad_dtype + )(input_shape) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_a.requires_grad = True + torch_input_tensor_a.retain_grad() + + beta = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item() + threshold = torch.tensor(1, dtype=torch.bfloat16).uniform_(100, 100).item() + while beta == 0.0 and threshold > 0.0: + beta = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item() + threshold = torch.tensor(1, dtype=torch.bfloat16).uniform_(100, 100).item() + + intermediate_result = torch.nn.functional.softplus(torch_input_tensor_a, beta=beta, threshold=threshold) + intermediate_result.backward(gradient=torch_grad_tensor) + torch_output_tensor = torch_input_tensor_a.grad + + grad_tensor = ttnn.from_torch( + torch_grad_tensor, + dtype=grad_dtype, + layout=grad_layout, + device=device, + memory_config=grad_memory_config, + ) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a.detach().clone(), + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.softplus_bw( + grad_tensor, input_tensor_a, beta=beta, threshold=threshold, memory_config=output_memory_config + )[0] + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary_backward/threshold_bw/threshold_bw.py b/tests/sweep_framework/sweeps/eltwise/unary_backward/threshold_bw/threshold_bw.py new file mode 100644 index 00000000000..3a24af1d54c --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary_backward/threshold_bw/threshold_bw.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) + + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) + + gen_shapes([32, 32], [256, 256], [32, 32], 16), + "grad_dtype": [ttnn.bfloat16], + "input_a_dtype": [ttnn.bfloat16], + "grad_layout": [ttnn.TILE_LAYOUT], + "input_a_layout": [ttnn.TILE_LAYOUT], + "grad_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + grad_dtype, + input_a_dtype, + grad_layout, + input_a_layout, + grad_memory_config, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_grad_tensor = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), grad_dtype + )(input_shape) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_input_tensor_a.requires_grad = True + torch_input_tensor_a.retain_grad() + + threshold = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item() + value = torch.tensor(1, dtype=torch.bfloat16).uniform_(-100, 100).item() + + intermediate_result = torch.threshold(torch_input_tensor_a, threshold, value) + intermediate_result.backward(gradient=torch_grad_tensor) + torch_output_tensor = torch_input_tensor_a.grad + + grad_tensor = ttnn.from_torch( + torch_grad_tensor, + dtype=grad_dtype, + layout=grad_layout, + device=device, + memory_config=grad_memory_config, + ) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a.detach().clone(), + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.threshold_bw( + grad_tensor, input_tensor_a, threshold, value, memory_config=output_memory_config + )[0] + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] From 969171fb4545b3f30356b75d120a53957087b771 Mon Sep 17 00:00:00 2001 From: Nemanja Grujic <109360083+nemanjagrujic@users.noreply.github.com> Date: Wed, 9 Oct 2024 09:38:45 +0200 Subject: [PATCH 28/58] Ngrujic/sweep tests 1 (#13281) * #11512: Add sweep for ttnn.isfinite * #11512: Add sweep for ttnn.isinf, isnan, isposinf and isneginf * #11512: Add sweeps for logit, lgamma, mish and multigammaln * #11512: Add sweeps for lerp --- .github/workflows/ttnn-run-sweeps.yaml | 11 +- .../sweeps/eltwise/ternary/lerp.py | 112 ++++++++++++++++++ .../sweeps/eltwise/unary/isfinite.py | 72 +++++++++++ .../sweeps/eltwise/unary/isinf.py | 73 ++++++++++++ .../sweeps/eltwise/unary/isnan.py | 73 ++++++++++++ .../sweeps/eltwise/unary/isneginf.py | 73 ++++++++++++ .../sweeps/eltwise/unary/isposinf.py | 73 ++++++++++++ .../eltwise/unary/{lgamma => }/lgamma.py | 21 +++- .../sweeps/eltwise/unary/logit.py | 87 ++++++++++++++ .../sweeps/eltwise/unary/mish.py | 86 ++++++++++++++ .../sweeps/eltwise/unary/multigammaln.py | 84 +++++++++++++ 11 files changed, 759 insertions(+), 6 deletions(-) create mode 100644 tests/sweep_framework/sweeps/eltwise/ternary/lerp.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/isfinite.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/isinf.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/isnan.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/isneginf.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/isposinf.py rename tests/sweep_framework/sweeps/eltwise/unary/{lgamma => }/lgamma.py (74%) create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/logit.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/mish.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/multigammaln.py diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml index 8513a2b9a82..3c40c3a2518 100644 --- a/.github/workflows/ttnn-run-sweeps.yaml +++ b/.github/workflows/ttnn-run-sweeps.yaml @@ -56,7 +56,6 @@ on: - eltwise.unary.i0.i0 - eltwise.unary.silu.silu - eltwise.unary.glu.glu - - eltwise.unary.lgamma.lgamma - eltwise.unary.sigmoid.sigmoid - eltwise.unary.sigmoid_accurate.sigmoid_accurate - eltwise.unary.tril.tril @@ -83,6 +82,15 @@ on: - eltwise.unary_backward.div_bw.div_bw - eltwise.unary_backward.log_bw.log_bw - eltwise.unary_backward.relu6_bw.relu6_bw + - eltwise.unary.lgamma + - eltwise.unary.logit + - eltwise.unary.mish + - eltwise.unary.multigammaln + - eltwise.unary.isfinite + - eltwise.unary.isinf + - eltwise.unary.isnan + - eltwise.unary.isneginf + - eltwise.unary.isposinf - eltwise.binary.subtract.subtract - eltwise.binary.multiply.multiply - eltwise.binary.div.div @@ -113,6 +121,7 @@ on: - eltwise.ternary.addcmul.addcmul - eltwise.ternary.addcdiv.addcdiv - eltwise.ternary.mac.mac + - eltwise.ternary.lerp - eltwise.ternary.where.where - matmul.full.matmul_default_block_sharded - matmul.full.matmul_default_height_sharded diff --git a/tests/sweep_framework/sweeps/eltwise/ternary/lerp.py b/tests/sweep_framework/sweeps/eltwise/ternary/lerp.py new file mode 100644 index 00000000000..c69560f3474 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/ternary/lerp.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 360 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 8) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 8) + + gen_shapes([1, 1], [256, 256], [1, 1], 8), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_c_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_c_layout": [ttnn.TILE_LAYOUT], + "input_c_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + input_b_dtype, + input_b_layout, + input_b_memory_config, + input_c_dtype, + input_c_layout, + input_c_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + torch_input_tensor_b = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype + )(input_shape) + + torch_input_tensor_c = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_c_dtype + )(input_shape) + + torch_output_tensor = torch.lerp(torch_input_tensor_a, torch_input_tensor_b, torch_input_tensor_c) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_b_dtype, + layout=input_b_layout, + device=device, + memory_config=input_b_memory_config, + ) + + input_tensor_c = ttnn.from_torch( + torch_input_tensor_c, + dtype=input_c_dtype, + layout=input_c_layout, + device=device, + memory_config=input_c_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.lerp(input_tensor_a, input_tensor_b, input_tensor_c, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) + # print(f"pcc {pcc}") + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/isfinite.py b/tests/sweep_framework/sweeps/eltwise/unary/isfinite.py new file mode 100644 index 00000000000..712e704f0a1 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/isfinite.py @@ -0,0 +1,72 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt, gen_rand_inf + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random, is_wormhole_b0 + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 64), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_rand_inf(input_shape, low=-100, high=100) + torch_output_tensor = torch.isfinite(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.isfinite(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) + # print(pcc) + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/isinf.py b/tests/sweep_framework/sweeps/eltwise/unary/isinf.py new file mode 100644 index 00000000000..34b680170d4 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/isinf.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import os +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt, gen_rand_inf + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random, is_wormhole_b0 + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 64), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_rand_inf(input_shape, low=-100, high=100) + torch_output_tensor = torch.isinf(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.isinf(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) + # print(pcc) + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/isnan.py b/tests/sweep_framework/sweeps/eltwise/unary/isnan.py new file mode 100644 index 00000000000..765a0c88729 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/isnan.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import os +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt, gen_rand_inf + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random, is_wormhole_b0 + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 64), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_rand_inf(input_shape, low=-100, high=100) + torch_output_tensor = torch.isnan(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.isnan(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) + # print(pcc) + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/isneginf.py b/tests/sweep_framework/sweeps/eltwise/unary/isneginf.py new file mode 100644 index 00000000000..5d9eca50a24 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/isneginf.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import os +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt, gen_rand_inf + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random, is_wormhole_b0 + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 64), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_rand_inf(input_shape, low=-100, high=100) + torch_output_tensor = torch.isneginf(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.isneginf(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) + # print(pcc) + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/isposinf.py b/tests/sweep_framework/sweeps/eltwise/unary/isposinf.py new file mode 100644 index 00000000000..f4d825a4812 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/isposinf.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import os +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt, gen_rand_inf + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random, is_wormhole_b0 + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 64), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_rand_inf(input_shape, low=-100, high=100) + torch_output_tensor = torch.isposinf(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.isposinf(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) + # print(pcc) + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/lgamma/lgamma.py b/tests/sweep_framework/sweeps/eltwise/unary/lgamma.py similarity index 74% rename from tests/sweep_framework/sweeps/eltwise/unary/lgamma/lgamma.py rename to tests/sweep_framework/sweeps/eltwise/unary/lgamma.py index 51597d92f37..d25d2473319 100644 --- a/tests/sweep_framework/sweeps/eltwise/unary/lgamma/lgamma.py +++ b/tests/sweep_framework/sweeps/eltwise/unary/lgamma.py @@ -25,14 +25,23 @@ # Developers can create their own generator functions and pass them to the parameters as inputs. parameters = { "nightly": { - "input_shape": gen_shapes([1, 1, 32, 32], [6, 12, 256, 256], [1, 1, 32, 32], 16) - + gen_shapes([1, 32, 32], [12, 256, 256], [1, 32, 32], 16) - + gen_shapes([32, 32], [256, 256], [32, 32], 32), + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 32) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 32) + + gen_shapes([1, 1], [256, 256], [1, 1], 32), "input_a_dtype": [ttnn.bfloat16], "input_a_layout": [ttnn.TILE_LAYOUT], "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], }, + "xfail": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 4) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 4) + + gen_shapes([1, 1], [256, 256], [1, 1], 4), + "input_a_dtype": [ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, } @@ -53,7 +62,7 @@ def run( torch.manual_seed(data_seed) torch_input_tensor_a = gen_func_with_cast_tt( - partial(torch_random, low=0.1, high=1000, dtype=torch.float32), input_a_dtype + partial(torch_random, low=0.0001, high=100, dtype=torch.float32), input_a_dtype )(input_shape) torch_output_tensor = torch.lgamma(torch_input_tensor_a) @@ -70,4 +79,6 @@ def run( output_tensor = ttnn.to_torch(result) e2e_perf = stop_measuring_time(start_time) - return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) + # print(f"pcc {pcc}") + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/logit.py b/tests/sweep_framework/sweeps/eltwise/unary/logit.py new file mode 100644 index 00000000000..b4cdfb87a40 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/logit.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 16) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 16) + + gen_shapes([1, 1], [256, 256], [1, 1], 16), + "eps": [0, 10e-6, 10e-5, 10e-4, 10e-3, 10e-2, 10e-1], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, + "xfail": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 1) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 1) + + gen_shapes([1, 1], [256, 256], [1, 1], 1), + "eps": [0, 10e-6, 10e-5, 10e-4, 10e-3, 10e-2, 10e-1], + "input_a_dtype": [ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + eps, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_output_tensor = torch.logit(torch_input_tensor_a, eps) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.logit(input_tensor_a, eps=eps, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.99) + # print(f"eps {eps} pcc {pcc}") + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/mish.py b/tests/sweep_framework/sweeps/eltwise/unary/mish.py new file mode 100644 index 00000000000..a9708746c56 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/mish.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 16) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 16) + + gen_shapes([1, 1], [256, 256], [1, 1], 16), + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# def mesh_device_fixture(): +# device = ttnn.open_device(device_id=0) +# assert ttnn.device.is_grayskull(device), "This op is not supported on Grayskull" +# device_name = os.environ.get("ARCH_NAME", os.environ.get("TT_ARCH_NAME", "default")).lower() +# yield (device, device_name) +# ttnn.close_device(device) +# del device + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + mish = torch.nn.Mish() + torch_output_tensor = mish(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.mish(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) + # print(f"pcc {pcc}") + return [pcc, e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/multigammaln.py b/tests/sweep_framework/sweeps/eltwise/unary/multigammaln.py new file mode 100644 index 00000000000..89744611949 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/multigammaln.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 16) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 16) + + gen_shapes([1, 1], [256, 256], [1, 1], 16), + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, + "xfail": { + "input_shape": gen_shapes([1, 1, 1, 1], [6, 12, 256, 256], [1, 1, 1, 1], 4) + + gen_shapes([1, 1, 1], [12, 256, 256], [1, 1, 1], 4) + + gen_shapes([1, 1], [256, 256], [1, 1], 4), + "input_a_dtype": [ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=1.6, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + torch_output_tensor = torch.special.multigammaln(torch_input_tensor_a, 4) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.multigammaln(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + pcc = check_with_pcc(torch_output_tensor, output_tensor, 0.999) + # print(f"pcc {pcc}") + return [pcc, e2e_perf] From 805e076bfdbab64a30922ba765b13c6c57f522a7 Mon Sep 17 00:00:00 2001 From: Hoang Ngo Date: Thu, 5 Sep 2024 10:02:23 +0000 Subject: [PATCH 29/58] #12256: Migrate moreh_cumsum operation from tt_eager to ttnn --- .../operations}/test_moreh_cumsum.py | 116 ++++++++++++- ttnn/CMakeLists.txt | 4 + .../tt_dnn/op_library/CMakeLists.txt | 4 +- .../moreh_cumsum/moreh_cumsum_op.cpp | 82 ---------- .../moreh_cumsum/moreh_cumsum_op.hpp | 43 ----- .../tt_lib/csrc/operations/primary/module.hpp | 18 --- .../device}/kernels/moreh_cumsum_nc.cpp | 2 +- .../kernels/reader_moreh_cumsum_nc.cpp | 2 +- .../moreh_cumsum/device}/kernels/utils.hpp | 0 .../kernels/writer_moreh_cumsum_nc.cpp | 2 +- .../device/moreh_cumsum_device_operation.cpp | 92 +++++++++++ .../device/moreh_cumsum_device_operation.hpp | 70 ++++++++ .../device/moreh_cumsum_program_factory.cpp} | 153 ++++++++++-------- .../moreh/moreh_cumsum/moreh_cumsum.cpp | 24 +++ .../moreh/moreh_cumsum/moreh_cumsum.hpp | 37 +++++ .../moreh_cumsum/moreh_cumsum_pybind.cpp | 44 +++++ .../moreh_cumsum/moreh_cumsum_pybind.hpp | 14 ++ .../ttnn/operations/moreh/moreh_pybind.cpp | 3 + ttnn/ttnn/operations/moreh.py | 2 + 19 files changed, 489 insertions(+), 223 deletions(-) rename tests/{tt_eager/python_api_testing/unit_testing/misc => ttnn/unit_tests/operations}/test_moreh_cumsum.py (54%) delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp rename ttnn/cpp/ttnn/{deprecated/tt_dnn/op_library/moreh_cumsum => operations/moreh/moreh_cumsum/device}/kernels/moreh_cumsum_nc.cpp (96%) rename ttnn/cpp/ttnn/{deprecated/tt_dnn/op_library/moreh_cumsum => operations/moreh/moreh_cumsum/device}/kernels/reader_moreh_cumsum_nc.cpp (96%) rename ttnn/cpp/ttnn/{deprecated/tt_dnn/op_library/moreh_cumsum => operations/moreh/moreh_cumsum/device}/kernels/utils.hpp (100%) rename ttnn/cpp/ttnn/{deprecated/tt_dnn/op_library/moreh_cumsum => operations/moreh/moreh_cumsum/device}/kernels/writer_moreh_cumsum_nc.cpp (96%) create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.cpp create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.hpp rename ttnn/cpp/ttnn/{deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp => operations/moreh/moreh_cumsum/device/moreh_cumsum_program_factory.cpp} (54%) create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.cpp create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.hpp create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.cpp create mode 100644 ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.hpp diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_cumsum.py b/tests/ttnn/unit_tests/operations/test_moreh_cumsum.py similarity index 54% rename from tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_cumsum.py rename to tests/ttnn/unit_tests/operations/test_moreh_cumsum.py index 206bbcf207e..34048c0ec00 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_cumsum.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_cumsum.py @@ -4,12 +4,17 @@ import pytest import torch +import ttnn + from loguru import logger -import ttnn from models.utility_functions import comp_allclose_and_pcc -from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import TILE_HEIGHT, TILE_WIDTH +from tests.ttnn.unit_tests.operations.test_utils import TILE_HEIGHT, TILE_WIDTH + + +def create_tt_tensor(tensor: torch.Tensor, dtype, device, layout): + return ttnn.from_torch(tensor, dtype=dtype, layout=layout, device=device) def get_tensors(input_shape, output_shape, device): @@ -21,8 +26,8 @@ def get_tensors(input_shape, output_shape, device): torch_input = torch.randint(-2, 3, input_shape, dtype=cpu_dtype, requires_grad=True) torch_output = torch.randint(-2, 3, output_shape, dtype=cpu_dtype) - tt_input = ttnn.Tensor(torch_input, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) - tt_output = ttnn.Tensor(torch_output, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) + tt_input = create_tt_tensor(torch_input, npu_dtype, device, npu_layout) + tt_output = create_tt_tensor(torch_output, npu_dtype, device, npu_layout) return tt_input, tt_output, torch_input @@ -74,7 +79,7 @@ def test_moreh_cumsum_dim(input_shape, dim, device): cpu_layout = ttnn.ROW_MAJOR_LAYOUT tt_output_cpu = ( - ttnn.experimental.operations.primary.moreh_cumsum(tt_input, tt_output, dim=dim) + ttnn.operations.moreh.cumsum(tt_input, dim, output=tt_output) .cpu() .to(cpu_layout) .unpad_from_tile(output_shape) @@ -114,7 +119,7 @@ def test_moreh_cumsum_dim(input_shape, dim, device): ), ids=["0", "1"], ) -def test_moreh_cumsumsum_backward(input_shape, dim, device): +def test_moreh_cumsum_backward(input_shape, dim, device): output_shape = input_shape.copy() (_, _, torch_input) = get_tensors(input_shape, output_shape, device) @@ -125,7 +130,7 @@ def test_moreh_cumsumsum_backward(input_shape, dim, device): cpu_layout = ttnn.ROW_MAJOR_LAYOUT tt_input_grad_cpu = ( - ttnn.experimental.operations.primary.moreh_cumsum_backward(tt_output_grad, tt_input_grad, dim=dim) + ttnn.operations.moreh.cumsum_backward(tt_output_grad, dim, input_grad=tt_input_grad) .cpu() .to(cpu_layout) .unpad_from_tile(input_shape) @@ -140,3 +145,100 @@ def test_moreh_cumsumsum_backward(input_shape, dim, device): logger.debug(f"Output pcc={output_pcc}") assert passing + + +@pytest.mark.parametrize( + "input_shape", + ( + ([1, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1]), + ([4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 30 - 1]), + ), + ids=[ + "1, 1, TILE_HEIGHT-1,TILE_WIDTH - 1", + "4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 30 - 1", + ], +) +@pytest.mark.parametrize( + "dim", + ( + 0, + 1, + ), + ids=["0", "1"], +) +def test_moreh_cumsum_callback(input_shape, dim, device, use_program_cache): + output_shape = input_shape.copy() + + (tt_input, tt_output, torch_input) = get_tensors(input_shape, output_shape, device) + + torch_output = torch.cumsum(torch_input, dim) + + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + + # test for equivalance + rtol = atol = 0.1 + + for i in range(2): + tt_output_cpu = ( + ttnn.operations.moreh.cumsum(tt_input, dim).cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch() + ) + + passing, output_pcc = comp_allclose_and_pcc(torch_output, tt_output_cpu, pcc=0.999, rtol=rtol, atol=atol) + + logger.debug(f"Out passing={passing}") + logger.debug(f"Output pcc={output_pcc}") + + assert passing + assert device.num_program_cache_entries() == 1 + + +@pytest.mark.parametrize( + "input_shape", + ( + ([1, 1, TILE_HEIGHT - 1, TILE_WIDTH - 1]), + ([4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 30 - 1]), + ), + ids=[ + "1, 1, TILE_HEIGHT-1,TILE_WIDTH - 1", + "4, 4, TILE_HEIGHT * 12 - 1, TILE_WIDTH * 30 - 1", + ], +) +@pytest.mark.parametrize( + "dim", + ( + 0, + 1, + ), + ids=["0", "1"], +) +def test_moreh_cumsum_backward_callback(input_shape, dim, device, use_program_cache): + output_shape = input_shape.copy() + + (_, _, torch_input) = get_tensors(input_shape, output_shape, device) + (tt_output_grad, tt_input_grad, torch_output_grad) = get_backward_tensors(output_shape, input_shape, device) + + torch_output = torch.cumsum(torch_input, dim) + torch_output.backward(torch_output_grad) + + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + # test for equivalance + rtol = atol = 0.1 + + for i in range(2): + tt_input_grad_cpu = ( + ttnn.operations.moreh.cumsum_backward(tt_output_grad, dim) + .cpu() + .to(cpu_layout) + .unpad_from_tile(input_shape) + .to_torch() + ) + + passing, output_pcc = comp_allclose_and_pcc( + torch_input.grad, tt_input_grad_cpu, pcc=0.999, rtol=rtol, atol=atol + ) + + logger.debug(f"Out passing={passing}") + logger.debug(f"Output pcc={output_pcc}") + + assert passing + assert device.num_program_cache_entries() == 1 diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index e9591d11070..2848e52cf2a 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -373,6 +373,10 @@ set(ALL_TTNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_bmm_backward/moreh_bmm_backward.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_bmm/moreh_bmm_pybind.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_bmm/moreh_bmm.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_program_factory.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_program_factory.cpp ${CMAKE_CURRENT_SOURCE_DIR}/cpp/ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward_pybind.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt index cf4e4a8a8cd..86ec5e6a663 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt @@ -42,8 +42,6 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/moreh_layernorm_backward/moreh_layernorm_backward_op.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_layernorm_backward/input_grad/moreh_layernorm_backward_input_grad.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_layernorm_backward/gamma_beta_grad/moreh_layernorm_backward_gamma_beta_grad.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_cumsum/moreh_cumsum_op.cpp CACHE INTERNAL "tt_dnn sources to reuse in ttnn build" -) +) \ No newline at end of file diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.cpp deleted file mode 100644 index 11737ca926f..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.cpp +++ /dev/null @@ -1,82 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/host_api.hpp" - -namespace tt { -using namespace constants; -namespace operations { -namespace primary { - -//////////////////////////////////////////////////////////////////////////// -// MorehCumSum -//////////////////////////////////////////////////////////////////////////// -void MorehCumSum::validate(const std::vector& inputs) const { - TT_ASSERT((dim >= 0 && dim <= 3), "dim should be 0 - 3"); - const auto& input = inputs.at(0); - const auto& output = inputs.at(1); - - auto input_shape = input.get_legacy_shape(); - const auto& output_shape = output.get_legacy_shape(); - auto input_shape_wo_padding = input.get_legacy_shape().without_padding(); - const auto& output_shape_wo_padding = output.get_legacy_shape().without_padding(); - - for (int i = 0; i < input_shape.rank(); ++i) { - TT_ASSERT(input_shape[i] == output_shape[i]); - TT_ASSERT(input_shape_wo_padding[i] == output_shape_wo_padding[i]); - } -} - -std::vector MorehCumSum::create_output_tensors(const std::vector& inputs) const { - // Inplace - return {}; -} - -std::vector MorehCumSum::compute_output_shapes(const std::vector& inputs) const { - // Inplace - return {}; -} - -operation::ProgramWithCallbacks MorehCumSum::create_program( - const std::vector& inputs, std::vector& outputs) const { - TT_ASSERT((dim >= 0 && dim <= 3), "dim should be 0 - 3"); - auto& input = inputs.at(0); - auto& output = inputs.at(1); - - if (dim == 2 || dim == 3) { - TT_ASSERT(false, "currenty only support moreh_cumsum op for dim 0, 1"); - } - - return moreh_cumsum_nc(input, output, dim, flip); -} - -Tensor moreh_cumsum_(const Tensor& input, const Tensor& output, const int64_t& dim, const bool flip = false) { - std::vector dummy_output_tensors = {Tensor(operation::get_workers_for_op_output({input, output}))}; - - operation::launch_op( - [dim, flip]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) mutable -> std::vector { - return operation::run( - MorehCumSum{.dim = dim, .flip = flip}, input_tensors, optional_input_tensors, optional_output_tensors); - }, - {input, output}, - dummy_output_tensors); - return output; -} - -Tensor moreh_cumsum_backward(const Tensor& output_grad, const Tensor& input_grad, const int64_t& dim) { - return moreh_cumsum_(output_grad, input_grad, dim, true); -} - -Tensor moreh_cumsum(const Tensor& input, const Tensor& output, const int64_t& dim) { - return moreh_cumsum_(input, output, dim); -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp deleted file mode 100644 index 0669a801667..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp +++ /dev/null @@ -1,43 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include -#include -#include -#include - -#include "ttnn/run_operation.hpp" -#include "ttnn/tensor/tensor.hpp" - -namespace tt { - -namespace operations { - -namespace primary { - -using namespace tt_metal; - -struct MorehCumSum { - int64_t dim; - bool flip; - void validate(const std::vector &inputs) const; - std::vector compute_output_shapes(const std::vector &inputs) const; - std::vector create_output_tensors(const std::vector &inputs) const; - operation::ProgramWithCallbacks create_program( - const std::vector &inputs, std::vector &outputs) const; -}; - -operation::ProgramWithCallbacks moreh_cumsum_nc(const Tensor &input, const Tensor &output, const int64_t &dim, const bool &flip); - -Tensor moreh_cumsum_backward(const Tensor &output_grad, const Tensor &input_grad, const int64_t &dim); - -Tensor moreh_cumsum(const Tensor &input, const Tensor &output, const int64_t &dim); - -} // namespace primary - -} // namespace operations - -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp index 3f45e611dc9..e003ec42341 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp @@ -10,7 +10,6 @@ #include "ttnn/deprecated/tt_dnn/op_library/moreh_bmm/moreh_bmm_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_bmm_backward/moreh_bmm_backward_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp" @@ -243,23 +242,6 @@ void py_module(py::module& m_primary) { py::arg("input_grad_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("compute_kernel_config").noconvert() = std::nullopt, "Performs sum backward operation. Returns an input_grad tensor."); - - m_primary.def( - "moreh_cumsum", - &moreh_cumsum, - py::arg("input").noconvert(), - py::arg("output").noconvert(), - py::kw_only(), - py::arg("dim").noconvert(), - "Performs cumsum operation. Returns an output tensor."); - m_primary.def( - "moreh_cumsum_backward", - &moreh_cumsum_backward, - py::arg("output_grad").noconvert(), - py::arg("input_grad").noconvert(), - py::kw_only(), - py::arg("dim").noconvert(), - "Performs cumsum backward operation. Returns an input_grad tensor."); } } // namespace diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/moreh_cumsum_nc.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/moreh_cumsum_nc.cpp similarity index 96% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/moreh_cumsum_nc.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/moreh_cumsum_nc.cpp index dec0b8034f9..e86d92137d8 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/moreh_cumsum_nc.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/moreh_cumsum_nc.cpp @@ -30,7 +30,7 @@ void MAIN { bool enable_reload = false; for (uint32_t j = 0; j < num_tiles_to_cumsum; ++j) { ACQ(); - uint32_t cb_add = (enable_reload) ? (cb_intermed0) : (cb_in1); + uint32_t cb_add = (enable_reload) ? (cb_intermed0) : (cb_in1); cb_wait_front(cb_in0, onetile); add_tiles_init(); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/reader_moreh_cumsum_nc.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/reader_moreh_cumsum_nc.cpp similarity index 96% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/reader_moreh_cumsum_nc.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/reader_moreh_cumsum_nc.cpp index 1ca14448a4d..3b906141a6e 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/reader_moreh_cumsum_nc.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/reader_moreh_cumsum_nc.cpp @@ -5,7 +5,7 @@ #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" inline uint32_t get_read_tile_id(uint32_t tile_id, uint32_t dim, uint32_t CHtWt, uint32_t HtWt) { - return (dim == 0 ) ? (tile_id) : (tile_id / HtWt * CHtWt) + (tile_id % HtWt); + return (dim == 0) ? (tile_id) : (tile_id / HtWt * CHtWt) + (tile_id % HtWt); } void kernel_main() { diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/utils.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/utils.hpp similarity index 100% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/utils.hpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/utils.hpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/writer_moreh_cumsum_nc.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/writer_moreh_cumsum_nc.cpp similarity index 96% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/writer_moreh_cumsum_nc.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/writer_moreh_cumsum_nc.cpp index fa8ef42f292..79c4f3eda1e 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/writer_moreh_cumsum_nc.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/writer_moreh_cumsum_nc.cpp @@ -7,7 +7,7 @@ #include "dataflow_api.h" inline uint32_t get_write_tile_id(uint32_t tile_id, uint32_t dim, uint32_t CHtWt, uint32_t HtWt) { - return (dim == 0 ) ? (tile_id) : (tile_id / HtWt * CHtWt) + (tile_id % HtWt); + return (dim == 0) ? (tile_id) : (tile_id / HtWt * CHtWt) + (tile_id % HtWt); } void kernel_main() { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.cpp new file mode 100644 index 00000000000..08390f56752 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.cpp @@ -0,0 +1,92 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_cumsum_device_operation.hpp" + +#include "ttnn/tensor/tensor.hpp" + +namespace ttnn::operations::moreh::moreh_cumsum { +void MorehCumsumDeviceOperation::validate_inputs( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto dim = operation_attributes.dim; + TT_FATAL((dim >= 0 && dim <= 3), "dim should be 0 - 3, but got: {}", dim); + const auto& input = tensor_args.input; + const auto& output = tensor_args.output; + + if (!output.has_value()) { + return; + } + + const auto input_shape = input.get_shape(); + const auto output_shape = output.value().get_shape(); + const auto input_shape_wo_padding = input_shape.value.without_padding(); + const auto output_shape_wo_padding = output_shape.value.without_padding(); + + for (int i = 0; i < input_shape.rank(); ++i) { + TT_FATAL( + input_shape[i] == output_shape[i], + "Input shape must match output shape. Received input_shape = {} and output_shape = {}.", + input_shape[i], + output_shape[i]); + TT_FATAL( + input_shape_wo_padding[i] == output_shape_wo_padding[i], + "Input and output shapes (excluding padding) must be equal. Received input_shape_wo_padding = {} and " + "output_shape_wo_padding = {}.", + input_shape_wo_padding[i], + output_shape_wo_padding[i]); + } +} + +MorehCumsumDeviceOperation::program_factory_t MorehCumsumDeviceOperation::select_program_factory( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + return ProgramFactory{}; +} + +void MorehCumsumDeviceOperation::validate_on_program_cache_miss( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_inputs(operation_attributes, tensor_args); +}; + +void MorehCumsumDeviceOperation::validate_on_program_cache_hit( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + validate_inputs(operation_attributes, tensor_args); +}; + +MorehCumsumDeviceOperation::shape_return_value_t MorehCumsumDeviceOperation::compute_output_shapes( + const operation_attributes_t&, const tensor_args_t& tensor_args) { + const auto& input = tensor_args.input; + auto output_shape = input.get_shape(); + return output_shape; +} + +MorehCumsumDeviceOperation::tensor_return_value_t MorehCumsumDeviceOperation::create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& input = tensor_args.input; + const auto& output = tensor_args.output; + if (output.has_value()) { + return output.value(); + } + + auto output_shape = compute_output_shapes(operation_attributes, tensor_args); + return create_device_tensor( + output_shape, input.tensor_attributes->dtype, input.tensor_attributes->layout, input.device()); +} + +std::tuple +MorehCumsumDeviceOperation::invoke( + const Tensor& input, + const int64_t dim, + const std::optional& output, + const bool flip, + const std::optional& memory_config, + const std::optional& compute_kernel_config) { + return { + operation_attributes_t{ + dim, + flip, + memory_config.value_or(input.memory_config()), + init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4)}, + tensor_args_t{input, output}}; +} +} // namespace ttnn::operations::moreh::moreh_cumsum diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.hpp new file mode 100644 index 00000000000..04d3f356647 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.hpp @@ -0,0 +1,70 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttnn/decorators.hpp" +#include "ttnn/device_operation.hpp" +#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" +#include "ttnn/tensor/types.hpp" + +namespace ttnn::operations::moreh::moreh_cumsum { +struct MorehCumsumDeviceOperation { + struct operation_attributes_t { + int64_t dim; + bool flip; + const MemoryConfig memory_config; + const DeviceComputeKernelConfig compute_kernel_config; + }; + + struct tensor_args_t { + const Tensor& input; + const std::optional& output; + }; + + using shape_return_value_t = Shape; + using tensor_return_value_t = Tensor; + + struct ProgramFactory { + struct shared_variables_t { + KernelHandle unary_reader_kernel_id; + KernelHandle unary_writer_kernel_id; + std::size_t num_cores; + std::size_t num_cores_y; + }; + + using cached_program_t = ttnn::device_operation::CachedProgram; + + static cached_program_t create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output); + + static void override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output); + }; + + using program_factory_t = std::variant; + + static void validate_inputs(const operation_attributes_t&, const tensor_args_t&); + static program_factory_t select_program_factory(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); + static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); + static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); + static std::tuple invoke( + const Tensor& input, + const int64_t dim, + const std::optional& output, + const bool flip, + const std::optional& memory_config, + const std::optional& compute_kernel_config); +}; +} // namespace ttnn::operations::moreh::moreh_cumsum + +namespace ttnn::prim { +constexpr auto moreh_cumsum = ttnn:: + register_operation<"ttnn::prim::moreh_cumsum", ttnn::operations::moreh::moreh_cumsum::MorehCumsumDeviceOperation>(); +} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_program_factory.cpp similarity index 54% rename from ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp rename to ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_program_factory.cpp index f3f2ea0de3f..c333d8e1f5c 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_nc/moreh_cumsum_nc.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_program_factory.cpp @@ -1,43 +1,44 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. // // SPDX-License-Identifier: Apache-2.0 -#include "ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/moreh_cumsum_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" +#include "moreh_cumsum_device_operation.hpp" #include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" +#include "tt_metal/common/work_split.hpp" #include "tt_metal/host_api.hpp" +#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -namespace tt { -using namespace constants; -namespace operations { +namespace ttnn::operations::moreh::moreh_cumsum { -namespace primary { +MorehCumsumDeviceOperation::ProgramFactory::cached_program_t MorehCumsumDeviceOperation::ProgramFactory::create( + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output) { + auto& input = tensor_args.input; + + auto dim = operation_attributes.dim; + auto flip = operation_attributes.flip; -operation::ProgramWithCallbacks moreh_cumsum_nc( - const Tensor &input, const Tensor &output, const int64_t &dim, const bool &flip) { TT_ASSERT(dim == 0 || dim == 1); //////////////////////////////////////////////////////////////////////////// // Device Setup //////////////////////////////////////////////////////////////////////////// - auto *device = input.device(); + auto* device = input.device(); auto program = Program(); //////////////////////////////////////////////////////////////////////////// // Parameters Setup //////////////////////////////////////////////////////////////////////////// const auto cb_data_format = datatype_to_dataformat_converter(output.get_dtype()); - const auto single_tile_size = detail::TileSize(cb_data_format); - const auto &input_shape = input.get_legacy_shape(); - const auto &input_shape_without_padding = input_shape.without_padding(); + const auto& input_shape = input.get_legacy_shape(); + const auto& input_shape_without_padding = input_shape.without_padding(); const auto N = input_shape[0]; const auto C = input_shape[1]; - const auto Ht = input_shape[2] / TILE_HEIGHT; - const auto Wt = input_shape[3] / TILE_WIDTH; + const auto Ht = input_shape[2] / tt::constants::TILE_HEIGHT; + const auto Wt = input_shape[3] / tt::constants::TILE_WIDTH; const auto HtWt = Ht * Wt; const auto CHtWt = C * HtWt; const auto NHtWt = N * HtWt; @@ -45,9 +46,9 @@ operation::ProgramWithCallbacks moreh_cumsum_nc( const auto input_tile_offset = (dim == 0) ? (CHtWt) : (HtWt); const auto num_tiles_per_chip = (dim == 0) ? (CHtWt) : (NHtWt); - log_debug(LogOp, "N {} C {} Ht {} Wt {}", N, C, Ht, Wt); + log_debug(tt::LogOp, "N {} C {} Ht {} Wt {}", N, C, Ht, Wt); log_debug( - LogOp, + tt::LogOp, "dim {} num_cumsum_tiles {} input_tile_offset {} num_tiles_per_chip {}", dim, num_cumsum_tiles, @@ -64,26 +65,27 @@ operation::ProgramWithCallbacks moreh_cumsum_nc( const uint32_t in1_t = 1; // zero const uint32_t intermed0_t = 1; // accumulated sum const uint32_t out0_t = 2; // output + + auto arch = input.device()->arch(); + auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = + get_compute_kernel_config_args(arch, operation_attributes.compute_kernel_config); + const auto - [num_cores_to_be_used, - all_cores, - core_group_1, - core_group_2, - num_cols_per_core_group_1, - num_cols_per_core_group_2] = tt_metal::split_work_to_cores(grid, num_tiles_per_chip); + [num_cores, all_cores, core_group_1, core_group_2, num_cols_per_core_group_1, num_cols_per_core_group_2] = + tt::tt_metal::split_work_to_cores(grid, num_tiles_per_chip); //////////////////////////////////////////////////////////////////////////// // CircularBuffer Setup //////////////////////////////////////////////////////////////////////////// - CreateCircularBuffer( + tt::operations::primary::CreateCircularBuffer( program, all_cores, cb_data_format, { - {CB::c_in0, in0_t}, // input - {CB::c_in1, in1_t}, // zero - {CB::c_intermed0, intermed0_t}, // accumulated sum - {CB::c_out0, out0_t}, // output + {tt::CB::c_in0, in0_t}, // input + {tt::CB::c_in1, in1_t}, // zero + {tt::CB::c_intermed0, intermed0_t}, // accumulated sum + {tt::CB::c_out0, out0_t}, // output }); //////////////////////////////////////////////////////////////////////////// @@ -91,34 +93,47 @@ operation::ProgramWithCallbacks moreh_cumsum_nc( //////////////////////////////////////////////////////////////////////////// std::vector reader_compile_time_args; std::vector writer_compile_time_args; - const auto reader_kernel_file = "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/reader_moreh_cumsum_nc.cpp"; - const auto writer_kernel_file = "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/writer_moreh_cumsum_nc.cpp"; - const auto reader_kernel_id = CreateReadKernel(program, reader_kernel_file, all_cores, reader_compile_time_args); - const auto writer_kernel_id = CreateWriteKernel(program, writer_kernel_file, all_cores, writer_compile_time_args); + const auto reader_kernel_file = + "ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/reader_moreh_cumsum_nc.cpp"; + const auto writer_kernel_file = + "ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/writer_moreh_cumsum_nc.cpp"; + const auto reader_kernel_id = + tt::operations::primary::CreateReadKernel(program, reader_kernel_file, all_cores, reader_compile_time_args); + const auto writer_kernel_id = + tt::operations::primary::CreateWriteKernel(program, writer_kernel_file, all_cores, writer_compile_time_args); //////////////////////////////////////////////////////////////////////////// // ComputeKernel SetUp //////////////////////////////////////////////////////////////////////////// const std::vector compute_args_group_1{num_cols_per_core_group_1}; std::map compute_defines; - const auto compute_kernel_file = "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_cumsum/kernels/moreh_cumsum_nc.cpp"; - const auto compute_kernel_1_id = CreateComputeKernel( - program, compute_kernel_file, {core_group_1, num_cols_per_core_group_1, compute_args_group_1}, compute_defines); + const auto compute_kernel_file = "ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/device/kernels/moreh_cumsum_nc.cpp"; + const auto compute_kernel_1_id = tt::operations::primary::CreateComputeKernel( + program, + compute_kernel_file, + {core_group_1, num_cols_per_core_group_1, compute_args_group_1}, + compute_defines, + math_fidelity, + fp32_dest_acc_en, + math_approx_mode); std::optional compute_kernel_2_id = std::nullopt; if (!core_group_2.ranges().empty()) { const std::vector compute_args_group_2{num_cols_per_core_group_2}; - compute_kernel_2_id = CreateComputeKernel( + compute_kernel_2_id = tt::operations::primary::CreateComputeKernel( program, compute_kernel_file, {core_group_2, num_cols_per_core_group_2, compute_args_group_2}, - compute_defines); + compute_defines, + math_fidelity, + fp32_dest_acc_en, + math_approx_mode); } //////////////////////////////////////////////////////////////////////////// // RuntimeArgs SetUp //////////////////////////////////////////////////////////////////////////// - for (uint32_t i = 0, tile_offset = 0; i < num_cores_to_be_used; ++i) { + for (uint32_t i = 0, tile_offset = 0; i < num_cores; ++i) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; uint32_t num_tiles_per_core; @@ -139,7 +154,7 @@ operation::ProgramWithCallbacks moreh_cumsum_nc( num_tiles_per_core, input_tile_offset, tile_offset, - static_cast(is_dram(input)), + static_cast(tt::operations::primary::is_dram(input)), HtWt, CHtWt, static_cast(dim), @@ -154,7 +169,7 @@ operation::ProgramWithCallbacks moreh_cumsum_nc( num_tiles_per_core, input_tile_offset, tile_offset, - static_cast(is_dram(output)), + static_cast(tt::operations::primary::is_dram(output)), HtWt, CHtWt, static_cast(dim), @@ -171,31 +186,35 @@ operation::ProgramWithCallbacks moreh_cumsum_nc( tile_offset += num_tiles_per_core; } - auto override_runtime_arguments_callback = [reader_kernel_id, writer_kernel_id, num_cores_to_be_used, num_cores_y]( - const void *operation, - const Program &program, - const std::vector &input_tensors, - const std::vector> &, - const std::vector &output_tensors) { - const auto *input_buffer = input_tensors.at(0).buffer(); - const auto *output_buffer = input_tensors.at(1).buffer(); - for (uint32_t i = 0; i < num_cores_to_be_used; ++i) { - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = input_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = output_buffer->address(); - } + return {std::move(program), {reader_kernel_id, writer_kernel_id, num_cores, num_cores_y}}; +} + +void MorehCumsumDeviceOperation::ProgramFactory::override_runtime_arguments( + cached_program_t& cached_program, + const operation_attributes_t& operation_attributes, + const tensor_args_t& tensor_args, + tensor_return_value_t& output) { + auto& program = cached_program.program; + auto& unary_reader_kernel_id = cached_program.shared_variables.unary_reader_kernel_id; + auto& unary_writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; + auto& num_cores = cached_program.shared_variables.num_cores; + auto& num_cores_y = cached_program.shared_variables.num_cores_y; + + const auto& input = tensor_args.input; + + auto input_address = input.buffer()->address(); + auto output_address = output.buffer()->address(); + for (uint32_t i = 0; i < num_cores; ++i) { + CoreCoord core = {i / num_cores_y, i % num_cores_y}; + { + auto& runtime_args = GetRuntimeArgs(program, unary_reader_kernel_id, core); + runtime_args[0] = input_address; } - }; - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_arguments_callback}; + { + auto& runtime_args = GetRuntimeArgs(program, unary_writer_kernel_id, core); + runtime_args[0] = output_address; + } + } } - -} // namespace primary -} // namespace operations -} // namespace tt +} // namespace ttnn::operations::moreh::moreh_cumsum diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.cpp new file mode 100644 index 00000000000..583f95facbe --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.cpp @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include "moreh_cumsum.hpp" +namespace ttnn::operations::moreh::moreh_cumsum { +Tensor MorehCumsum::invoke( + const Tensor& input, + const int64_t dim, + const std::optional& output, + const std::optional& memory_config, + const std::optional& compute_kernel_config) { + return ttnn::prim::moreh_cumsum(input, dim, output, false, memory_config, compute_kernel_config); +} + +Tensor MorehCumsumBackward::invoke( + const Tensor& output_grad, + const int64_t dim, + const std::optional& input_grad, + const std::optional& memory_config, + const std::optional& compute_kernel_config) { + return ttnn::prim::moreh_cumsum(output_grad, dim, input_grad, true, memory_config, compute_kernel_config); +} +} // namespace ttnn::operations::moreh::moreh_cumsum diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.hpp new file mode 100644 index 00000000000..6e69e9af73e --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.hpp @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "ttnn/decorators.hpp" +#include "ttnn/operations/moreh/moreh_cumsum/device/moreh_cumsum_device_operation.hpp" + +namespace ttnn::operations::moreh::moreh_cumsum { +struct MorehCumsum { + static Tensor invoke( + const Tensor& input, + const int64_t dim, + const std::optional& output, + const std::optional& memory_config, + const std::optional& compute_kernel_config); +}; + +struct MorehCumsumBackward { + static Tensor invoke( + const Tensor& output_grad, + const int64_t dim, + const std::optional& input_grad, + const std::optional& memory_config, + const std::optional& compute_kernel_config); +}; +} // namespace ttnn::operations::moreh::moreh_cumsum + +namespace ttnn { +constexpr auto moreh_cumsum = ttnn:: + register_operation_with_auto_launch_op<"ttnn::moreh_cumsum", ttnn::operations::moreh::moreh_cumsum::MorehCumsum>(); + +constexpr auto moreh_cumsum_backward = ttnn::register_operation_with_auto_launch_op< + "ttnn::moreh_cumsum_backward", + ttnn::operations::moreh::moreh_cumsum::MorehCumsumBackward>(); +} // namespace ttnn diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.cpp new file mode 100644 index 00000000000..0bd8ee380de --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.cpp @@ -0,0 +1,44 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "pybind11/decorators.hpp" +#include "ttnn/operations/moreh/moreh_cumsum/moreh_cumsum.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::moreh::moreh_cumsum { + +void bind_moreh_cumsum_operation(py::module& module) { + bind_registered_operation( + module, + ttnn::moreh_cumsum, + "ttnn::moreh_cumsum", + ttnn::pybind_arguments_t{ + py::arg("input"), + py::arg("dim"), + py::kw_only(), + py::arg("output") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("compute_kernel_config") = std::nullopt, + }); +} + +void bind_moreh_cumsum_backward_operation(py::module& module) { + bind_registered_operation( + module, + ttnn::moreh_cumsum_backward, + "ttnn::moreh_cumsum_backward", + ttnn::pybind_arguments_t{ + py::arg("output_grad"), + py::arg("dim"), + py::kw_only(), + py::arg("input_grad") = std::nullopt, + py::arg("memory_config") = std::nullopt, + py::arg("compute_kernel_config") = std::nullopt, + }); +} + +} // namespace ttnn::operations::moreh::moreh_cumsum diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.hpp new file mode 100644 index 00000000000..1202ff39dbf --- /dev/null +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.hpp @@ -0,0 +1,14 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "pybind11/pybind_fwd.hpp" + +namespace py = pybind11; + +namespace ttnn::operations::moreh::moreh_cumsum { +void bind_moreh_cumsum_operation(py::module& module); +void bind_moreh_cumsum_backward_operation(py::module& module); +} // namespace ttnn::operations::moreh::moreh_cumsum diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp index 9be06cdfc80..54764433b30 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_pybind.cpp @@ -10,6 +10,7 @@ #include "ttnn/operations/moreh/moreh_bmm/moreh_bmm_pybind.hpp" #include "ttnn/operations/moreh/moreh_bmm_backward/moreh_bmm_backward_pybind.hpp" #include "ttnn/operations/moreh/moreh_dot/moreh_dot_pybind.hpp" +#include "ttnn/operations/moreh/moreh_cumsum/moreh_cumsum_pybind.hpp" #include "ttnn/operations/moreh/moreh_dot_op_backward/moreh_dot_backward_pybind.hpp" #include "ttnn/operations/moreh/moreh_getitem/moreh_getitem_pybind.hpp" #include "ttnn/operations/moreh/moreh_group_norm/moreh_group_norm_pybind.hpp" @@ -40,6 +41,8 @@ void bind_moreh_operations(py::module &module) { moreh_arange::bind_moreh_arange_operation(module); moreh_bmm_backward::bind_moreh_bmm_backward_operation(module); moreh_bmm::bind_moreh_bmm_operation(module); + moreh_cumsum::bind_moreh_cumsum_backward_operation(module); + moreh_cumsum::bind_moreh_cumsum_operation(module); moreh_dot_backward::bind_moreh_dot_backward_operation(module); moreh_dot::bind_moreh_dot_operation(module); moreh_getitem::bind_moreh_getitem_operation(module); diff --git a/ttnn/ttnn/operations/moreh.py b/ttnn/ttnn/operations/moreh.py index b4eff014932..066f892fc2c 100644 --- a/ttnn/ttnn/operations/moreh.py +++ b/ttnn/ttnn/operations/moreh.py @@ -9,6 +9,8 @@ arange = ttnn._ttnn.operations.moreh.moreh_arange bmm = ttnn._ttnn.operations.moreh.moreh_bmm bmm_backward = ttnn._ttnn.operations.moreh.moreh_bmm_backward +cumsum = ttnn._ttnn.operations.moreh.moreh_cumsum +cumsum_backward = ttnn._ttnn.operations.moreh.moreh_cumsum_backward dot = ttnn._ttnn.operations.moreh.moreh_dot dot_backward = ttnn._ttnn.operations.moreh.moreh_dot_backward getitem = ttnn._ttnn.operations.moreh.moreh_getitem From a5b1cff34b872d0c6fda4ada16867fa416450e68 Mon Sep 17 00:00:00 2001 From: Johanna Rock <129077594+johanna-rock-tt@users.noreply.github.com> Date: Wed, 9 Oct 2024 10:34:16 +0200 Subject: [PATCH 30/58] #0: Update falcon40 and llama70 demo outputs on T3K to fix CI issue (#13580) Co-authored-by: Johanna Rock --- .../falcon40b/demo/expected_output_data.json | 2 +- .../demo/data/llama3_ground_truth.json | 36 +++++++++---------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/models/demos/t3000/falcon40b/demo/expected_output_data.json b/models/demos/t3000/falcon40b/demo/expected_output_data.json index 5821c3c7b42..380fb5eeff1 100644 --- a/models/demos/t3000/falcon40b/demo/expected_output_data.json +++ b/models/demos/t3000/falcon40b/demo/expected_output_data.json @@ -1 +1 @@ -["List the first 5 prime numbers \nThe first 5 prime numbers are 2, 3, 5, 7, and 11. ", "Give a brief history of the internet \nThe internet was invented in the late 1960s by a group of researchers at the University of California, Los Angeles (UCLA). The first message was sent between two computers in 1969, and the first email was sent in 1971. The internet grew rapidly in the 1990s, and by the end of the decade, it had become a global phenomenon. Today, the internet is used for everything from shopping to social media to streaming movies and TV shows. ", "Describe to me some good coding practices \nSome good coding practices include: \n\n1. Properly commenting code to make it easier to understand and maintain. \n2. Using consistent naming conventions for variables and functions. \n3. Keeping code organized and modularized. \n4. Testing code thoroughly before deploying it. \n5. Using version control to track changes and revert mistakes. \n6. Avoiding unnecessary complexity and over-engineering. \n7. Writing clean and readable code. \n8. Using appropriate data types and avoiding unnecessary conversions. \n9. Minimizing the use of global variables", "write a short poem about Paris in English\nParis is a city of love and romance,\nWhere the streets are filled with art and culture,\nThe Eiffel Tower stands tall and proud,\nAnd the Seine River flows through the heart of the city,\nParis is a city of beauty and charm,\nWhere the people are friendly and welcoming,\nThe cafes and restaurants are filled with delicious food,\nAnd the museums and galleries are filled with treasures,\nParis is a city of history and tradition,\nWhere the monuments and landmarks are breathtaking,\nThe architecture is stunning and unique,\nAnd the city is full of life and", "Who is the inventor of the telephone?\nAlexander Graham Bell is credited with inventing the telephone in 1876. ", "write a short poem about Istanbul in English\nIstanbul is a city of contrasts,\nWhere East meets West,\nWhere ancient meets modern,\nWhere old meets new,\nWhere past meets present,\nWhere history meets future,\nWhere tradition meets innovation,\nWhere culture meets commerce,\nWhere religion meets secularism,\nWhere art meets architecture,\nWhere beauty meets chaos,\nWhere diversity meets unity,\nWhere the old city meets the new city,\nWhere the past meets the future,\nWhere the East meets the West,\nWhere the old meets the new,\nWhere the ancient meets the modern,\nWhere the", "What are the tourist attractions in Paris?\nParis is home to many famous tourist attractions such as the Eiffel Tower, Notre-Dame Cathedral, the Louvre Museum, the Champs-\u00c9lys\u00e9es, the Palace of Versailles, and the Seine River. Other popular attractions include the Arc de Triomphe, Montmartre, and the Parisian parks such as Jardin des Tuileries and Parc de la Villette. ", "How many countries are in Africa? \nThere are 54 countries in Africa. ", "what is the capital of USA? \nThe capital of USA is Washington D.C. ", "what is the capital of Canada? \nThe capital of Canada is Ottawa. ", "what is the capital of UK? \nThe capital of UK is London. ", "what is the capital of Germany? \nThe capital of Germany is Berlin. ", "what is the capital of France? \nThe capital of France is Paris. ", "what is the capital of Japan? \nThe capital of Japan is Tokyo. ", "what is the capital of India? \nThe capital of India is New Delhi. ", "what is the capital of China? \nThe capital of China is Beijing. ", "what is the currency of Cuba? \nThe currency of Cuba is the Cuban peso (CUP). ", "what is the currency of Lebanon? \nThe currency of Lebanon is the Lebanese pound (LBP). ", "what is the currency of Brazil? \nThe currency of Brazil is the Brazilian Real (BRL). ", "what is the currency of Australia? \nThe currency of Australia is the Australian dollar (AUD). ", "what is the currency of Jamaica? \nThe currency of Jamaica is the Jamaican dollar (JMD). ", "what is the currency of Egypt? \nThe currency of Egypt is the Egyptian pound (EGP). ", "what is the currency of Uzbekistan? \nThe currency of Uzbekistan is the Uzbekistani som (UZS). ", "what is the currency of Argentina? \nThe currency of Argentina is the Argentine peso. ", "describe the geographic location of London in UK\nLondon is located in the southeast of England, on the River Thames. It is the capital city of the United Kingdom and the largest city in Europe. ", "describe the geographic location of Toronto in Canada\nToronto is located in the southern part of Ontario, Canada. It is situated on the northwestern shore of Lake Ontario, and is the largest city in Canada. ", "describe the geographic location of Madrid in Spain\nMadrid is located in the center of Spain, in the heart of the Iberian Peninsula. It is the capital city of Spain and the largest city in the country. Madrid is situated in a valley surrounded by mountains, which gives it a unique climate and geography. ", "describe the geographic location of Paris in France\nParis is located in the north-central part of France, on the Seine River. It is the capital city of France and the largest city in the country. ", "describe the geographic location of Rome in Italy\nRome is located in central Italy, on the Tiber River. It is the capital city of Italy and the largest city in the country. ", "describe the geographic location of Istanbul in Turkey\nIstanbul is located in the northwest corner of Turkey, on the Bosphorus Strait, which connects the Black Sea to the Sea of Marmara. It is the largest city in Turkey and the fifth largest city in the world. ", "describe the geographic location of Shanghai in China\nShanghai is located in eastern China, on the Yangtze River Delta. It is the largest city in China and one of the largest cities in the world. ", "describe the geographic location of Lagos in Nigeria\nLagos is located in the southwestern part of Nigeria, on the Gulf of Guinea. It is the largest city in Nigeria and the fifth largest city in Africa. Lagos is also the economic and cultural hub of Nigeria, with a population of over 20 million people. "] +["List the first 5 prime numbers \nThe first 5 prime numbers are 2, 3, 5, 7, and 11. ", "Give a brief history of the internet \nThe internet was invented in the late 1960s by computer scientists at the University of California, Los Angeles (UCLA). It was originally called ARPANET and was designed to allow scientists to share information and resources across different computer networks. In the 1990s, the internet became more widely available to the public and began to transform the way people communicate and access information. Today, the internet is a ubiquitous part of modern life, with billions of people using it daily for everything from shopping to social media to streaming entertainment. ", "Describe to me some good coding practices \nSome good coding practices include: \n\n1. Properly commenting code to make it easier to understand and maintain. \n2. Using consistent naming conventions for variables and functions. \n3. Writing clean and readable code that is easy to debug. \n4. Avoiding unnecessary complexity and keeping code simple and concise. \n5. Using version control to track changes and revert mistakes. \n6. Testing code thoroughly before deploying it. \n7. Keeping up-to-date with industry standards and best practices. \n8. Collaborating with other developers to improve code", "write a short poem about Paris in English\nParis is a city of love and romance,\nWhere the streets are filled with art and culture,\nThe Eiffel Tower stands tall and proud,\nAnd the Seine River flows through the heart of the city,\nParis is a city of dreams and possibilities,\nWhere the people are friendly and welcoming,\nThe cafes and restaurants are filled with delicious food,\nAnd the museums and galleries are filled with treasures,\nParis is a city of beauty and charm,\nWhere the architecture is stunning and the parks are lush,\nThe city is alive with energy and excitement,\nAnd the people", "Who is the inventor of the telephone?\nAlexander Graham Bell is credited with inventing the telephone in 1876. ", "write a short poem about Istanbul in English\nIstanbul is a city of contrasts,\nWhere East meets West,\nWhere ancient meets modern,\nWhere old meets new,\nWhere past meets present,\nWhere history meets future,\nWhere tradition meets innovation,\nWhere culture meets commerce,\nWhere religion meets secularism,\nWhere art meets architecture,\nWhere beauty meets chaos,\nWhere diversity meets unity,\nWhere the old city meets the new city,\nWhere the past meets the future,\nWhere the East meets the West,\nWhere the old meets the new,\nWhere the ancient meets the modern,\nWhere the", "What are the tourist attractions in Paris?\nParis is home to many famous landmarks and attractions such as the Eiffel Tower, Notre-Dame Cathedral, the Louvre Museum, the Champs-\u00c9lys\u00e9es, the Palace of Versailles, and the Seine River. Other popular attractions include the Montmartre district, the Arc de Triomphe, and the Parisian parks such as Jardin des Tuileries and Parc de la Villette. ", "How many countries are in Africa? \nThere are 54 countries in Africa. ", "what is the capital of USA? \nThe capital of USA is Washington D.C. ", "what is the capital of Canada? \nThe capital of Canada is Ottawa. ", "what is the capital of UK? \nThe capital of UK is London. ", "what is the capital of Germany? \nThe capital of Germany is Berlin. ", "what is the capital of France? \nThe capital of France is Paris. ", "what is the capital of Japan? \nThe capital of Japan is Tokyo. ", "what is the capital of India? \nThe capital of India is New Delhi. ", "what is the capital of China? \nThe capital of China is Beijing. ", "what is the currency of Cuba? \nThe currency of Cuba is the Cuban peso (CUP). ", "what is the currency of Lebanon? \nThe currency of Lebanon is the Lebanese pound (LBP). ", "what is the currency of Brazil? \nThe currency of Brazil is the Brazilian Real (BRL). ", "what is the currency of Australia? \nThe currency of Australia is the Australian dollar (AUD). ", "what is the currency of Jamaica? \nThe currency of Jamaica is the Jamaican dollar. ", "what is the currency of Egypt? \nThe currency of Egypt is the Egyptian pound (EGP). ", "what is the currency of Uzbekistan? \nThe currency of Uzbekistan is the Uzbekistani som (UZS). ", "what is the currency of Argentina? \nThe currency of Argentina is the Argentine peso. ", "describe the geographic location of London in UK\nLondon is located in the southeast of England, on the River Thames. It is the capital city of the United Kingdom and the largest city in Europe. ", "describe the geographic location of Toronto in Canada\nToronto is located in the province of Ontario, Canada. It is situated on the northwestern shore of Lake Ontario, and is the largest city in Canada. Toronto is also the fourth largest city in North America, with a population of over 2.8 million people. ", "describe the geographic location of Madrid in Spain\nMadrid is located in the center of Spain, in the region of Madrid. It is the capital city of Spain and the largest city in the country. Madrid is situated on a plateau at an elevation of 2,180 feet (660 meters) above sea level. ", "describe the geographic location of Paris in France\nParis is located in the north-central part of France, on the Seine River. It is the capital city of France and the largest city in the country. ", "describe the geographic location of Rome in Italy\nRome is located in central Italy, on the Tiber River. It is the capital city of Italy and the largest city in the country. ", "describe the geographic location of Istanbul in Turkey\nIstanbul is located in Turkey, on the Bosphorus Strait, which connects the Black Sea to the Sea of Marmara. It is the largest city in Turkey and the fifth largest city in the world. ", "describe the geographic location of Shanghai in China\nShanghai is located in eastern China, on the Yangtze River Delta. It is the largest city in China and one of the largest cities in the world. ", "describe the geographic location of Lagos in Nigeria\nLagos is located in the southwestern part of Nigeria, on the Gulf of Guinea. It is the largest city in Nigeria and the second largest city in Africa. Lagos is also the economic and cultural center of Nigeria, with a population of over 20 million people. "] diff --git a/models/demos/t3000/llama2_70b/demo/data/llama3_ground_truth.json b/models/demos/t3000/llama2_70b/demo/data/llama3_ground_truth.json index e59003fae4c..0d6f8a194d7 100644 --- a/models/demos/t3000/llama2_70b/demo/data/llama3_ground_truth.json +++ b/models/demos/t3000/llama2_70b/demo/data/llama3_ground_truth.json @@ -1,34 +1,34 @@ [ "<|begin_of_text|>Tenstorrent is an AI startup whose RISC-V hardware aims to define a new spatial computing platform for the next century. In this interview with CEO and legendary chip architect Jim Keller we learn of their plans to revolutionize the industry. Tenstorrent is an AI startup whose RISC-V hardware aims to define a new spatial computing platform for the next century. In this interview with CEO and legendary chip architect Jim Keller we learn of their plans to revolutionize the industry.\nTenstorrent is an AI startup whose RISC-V hardware aims to define a new spatial computing platform for the next century. In this interview with CEO and legendary chip architect Jim Keller we learn of their plans to revolutionize the industry. Tenstorrent is an AI startup whose RISC-V hardware aims to define a new spatial computing platform for the next century. In this interview with CEO and legendary chip architect", "<|begin_of_text|>It was the best of times, it was the worst of times, it was the age of wisdom, it was the age of foolishness, it was the epoch of belief, it was the epoch of incredulity, it was the season of Light, it was the season of Darkness, it was the spring of hope, it was the winter of despair, we had everything before us, we had nothing before us, we were all going direct to Heaven, we were all going direct the other way \u2013 in short, the period was so far like the present period, that some of its noisiest authorities insisted on its being received, for good or for evil, in the superlative degree", - "<|begin_of_text|>I like to think (and the sooner the better!) of a cybernetic meadow where mammals and computers live together in mutually programming harmony like pure water touching clear sky. I like to think (right now, please!) of a cybernetic forest filled with pines and electronics where deer stroll peacefully past wise computers that their hooves and horns have created. I like to think (it has to be!) of a cybernetic ecology where we are responsible for the programming of our computers, of a humane meadow of technical immortality.\nI like to think (it must be!) of a cybernetic forest filled with pines and electronics and peo\u00adple strolling about under them, their brains wave fronts", - "<|begin_of_text|>We the People of the United States, in Order to form a more perfect Union, establish Justice, insure domestic Tranquility, provide for the common defence, promote the general Welfare, and secure the Blessings of Liberty to ourselves and our Posterity, do ordain and establish this Constitution for the United States of America.\nArticle. I.\nAll legislative Powers herein granted shall be vested in a Congress of the United States, which shall consist of a Senate and House of Representatives.\nThe House of Representatives shall be composed of Members chosen every second Year by the People of the several States, and the Electors in each State shall have the Qualifications requisite for Electors of the most", - "<|begin_of_text|>Katherine Johnson (August 26, 1918 - February 24, 2020) was an African-American mathematician whose calculations of orbital mechanics as a NASA employee were critical to the success of the first and subsequent U.S. crewed spaceflights. During her 35-year career at NASA and its predecessor, she earned a reputation for mastering complex manual calculations and helped pioneer the use of computers to perform the tasks. The space agency noted her \"historical role as one of the first African-American women to work as a NASA scientist\". Johnson's work included calculating trajectories, launch windows, and emergency return paths for Project Mercury spaceflights. In 2015, President Barack Obama awarded Johnson the Presidential Medal of Freedom. In 2016", - "<|begin_of_text|>Knock, knock. Who's there? The police. The police who? The police who are here to arrest you for not having a sense of humor.", + "<|begin_of_text|>I like to think (and the sooner the better!) of a cybernetic meadow where mammals and computers live together in mutually programming harmony like pure water touching clear sky. I like to think (right now, please!) of a cybernetic forest filled with pines and electronics where deer stroll peacefully past wise computers that their hooves and horns have grown, where owls and umber hawks assess their fame with accurate iridescence and where, in a friendly way, the wrens and curlews sing to microphones. I like to think (it has to be!) of a cybernetic ecology where we are responsible for the programming of our lives. People who will not acknowledge this are living and", + "<|begin_of_text|>We the People of the United States, in Order to form a more perfect Union, establish Justice, insure domestic Tranquility, provide for the common defence, promote the general Welfare, and secure the Blessings of Liberty to ourselves and our Posterity, do ordain and establish this Constitution for the United States of America.\nSection 1. All legislative Powers herein granted shall be vested in a Congress of the United States, which shall consist of a Senate and House of Representatives.\nSection 2. The House of Representatives shall be composed of Members chosen every second Year by the People of the several States, and the Electors in each State shall have the Qualifications requisite for Elect", + "<|begin_of_text|>Katherine Johnson (August 26, 1918 - February 24, 2020) was an African-American mathematician whose calculations of orbital mechanics as a NASA employee were critical to the success of the first and subsequent U.S. crewed spaceflights. During her 35-year career at NASA and its predecessor, she earned a reputation for mastering complex manual calculations and helped pioneer the use of computers to perform the tasks. The space agency noted her \"historical role as one of the first African-American women to work as a NASA scientist\". Johnson's work included calculating trajectories, launch windows, and emergency return paths for Project Mercury spaceflights, including those for astronauts Alan Shepard and John Glenn, and the 1969 Apollo 11 flight to", + "<|begin_of_text|>Knock, knock. Who's there? It's the police. The police who? The police who are here to arrest you for not having a sense of humor.", "<|begin_of_text|>Count to a hundred: 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77", "<|begin_of_text|>Not like the brazen giant of Greek fame, With conquering limbs astride from land to land; Here at our sea-washed, sunset gates shall stand A mighty woman with a torch, whose flame Is the imprisoned lightning, and her name Mother of Exiles. From her beacon-hand Glows world-wide welcome; her mild eyes command The air-bridged harbor that twin cities frame. \"Keep, ancient lands, your storied pomp!\" cries she With silent lips. \"Give me your tired, your poor, Your huddled masses yearning to breathe free, The wretched refuse of your teeming shore. Send these, the homeless, tempest-tost to me,", - "<|begin_of_text|>Roses are red, violets are blue, and this Valentine\u2019s Day, we\u2019re celebrating the love we have for our favorite things. From the people who make our lives better to the things that make us smile, we\u2019re sharing the love this February 14th. So whether you\u2019re single or taken, we hope you\u2019ll join us in celebrating all the things we love!\nThis article is all about helping you get to know the Roses are red, violets are blue for PC better and install it on your PC. Here are the technical specifications you want to know about beforehand:\nHow To Install & Download Roses are red, violets are blue For PC Windows 10", + "<|begin_of_text|>Roses are red, violets are blue, and this Valentine\u2019s Day, we\u2019re celebrating the love we have for our favorite things. From the people who make our lives better to the things that make us smile, we\u2019re sharing the love this February 14th. So, whether you\u2019re single or taken, we hope you\u2019ll join us in celebrating all the things we love!\nThis article is all about helping you get to know the Roses are red, violets are blue for PC better and install it on your PC. Here are the technical specifications you want to know about beforehand:\nHow To Install & Download Roses are red, violets are blue For PC Windows ", "<|begin_of_text|>Albert Einstein (14 March 1879 - 18 April 1955) was a German-born theoretical physicist. He developed the general theory of relativity, one of the two pillars of modern physics (alongside quantum mechanics). Einstein's work is also known for its influence on the philosophy of science. Einstein is best known in popular culture for his mass\u2013energy equivalence formula E = mc2 (which has been dubbed \"the world's most famous equation\"). He received the 1921 Nobel Prize in Physics for his \"services to theoretical physics\", in particular his discovery of the law of the photoelectric effect, a pivotal step in the evolution of quantum theory.", - "<|begin_of_text|>The journey of a thousand miles begins with a single step. \u2013 Lao Tzu\nI have been thinking about this quote a lot lately. I have been thinking about it in terms of my own life and in terms of the lives of my students. I have been thinking about it in terms of the journey of a thousand miles and in terms of the journey of a single step.\nI have been thinking about this quote in terms of the journey of a thousand miles because I have been thinking about the journey of a thousand miles in terms of the journey of a single step. I have been thinking about the journey of a thousand miles in terms of the journey of a single", - "<|begin_of_text|>When I find myself in times of trouble, Mother Mary comes to me, speaking words of wisdom, let it be. And in my hour of darkness she is standing right in front of me, speaking words of wisdom, let it be. Let it be, let it be, let it be, let it be. Whisper words of wisdom, let it be. And when the broken hearted people living in the world agree, there will be an answer, let it be. For though they may be parted there is still a chance that they will see, there will be an answer. Let it be. Let it be, let it be, let it be, let it", + "<|begin_of_text|>The journey of a thousand miles begins with a single step. \u2013 Lao Tzu\nI have been thinking about this quote a lot lately. I have been thinking about it in terms of my own life and in terms of the lives of my students. I have been thinking about it in terms of the journey of a thousand miles and in terms of the journey of a single step. I have been thinking about it in terms of the journey of a thousand miles and in terms of the journey of a single step.\nI have been thinking about it in terms of the journey of a thousand miles and in terms of the journey of a single step. I have been thinking about", + "<|begin_of_text|>When I find myself in times of trouble, Mother Mary comes to me, speaking words of wisdom, let it be. And in my hour of darkness she is standing right in front of me, speaking words of wisdom, let it be. Let it be, let it be, let it be, let it be. Whisper words of wisdom, let it be. And when the broken hearted people living in the world agree, there will be an answer, let it be. For though they may be parted there is still a chance that they will see, there will be an answer. let it be. Let it be, let it be, ..... yeah ..... let it be", "<|begin_of_text|>Shall I compare thee to a summer's day? Thou art more lovely and more temperate: Rough winds do shake the darling buds of May, And summer's lease hath all too short a date: Sometime too hot the eye of heaven shines, And often is his gold complexion dimm'd; And every fair from fair sometime declines, By chance or nature's changing course untrimm'd; But thy eternal summer shall not fade Nor lose possession of that fair thou owest; Nor shall Death brag thou wander'st in his shade, When in eternal lines to time thou growest: So long as men can breathe or eyes can see, So long lives this and this gives life to thee.\n", "<|begin_of_text|>Rachel Carson (May 27, 1907 - April 14, 1964) was an American marine biologist and nature writer whose writings are credited with advancing the global environmental movement. Carson started her career as a biologist in the U.S. Bureau of Fisheries, and became a full-time nature writer in the 1950s. Her widely praised 1951 bestseller The Sea Around Us won her a U.S. National Book Award, recognition as a gifted writer, and financial security. Her next book, The Edge of the Sea, and the reissued version of her first book, Under the Sea Wind, were also bestsellers. This sea trilogy explores the whole of ocean life from the shores to the depths. Carson's writing career", - "<|begin_of_text|>Two roads diverged in a yellow wood, and I took the one less traveled by, and that has made all the difference. -Robert Frost\nI have always been a fan of Robert Frost. I love his poetry and his ability to capture the essence of life in a few short lines. This particular poem, \u201cThe Road Not Taken,\u201d has always been one of my favorites. It speaks to the choices we make in life and how they can shape our future.\nThe poem begins with the speaker standing at a fork in the road. He has to choose which path to take. He knows that both paths are equally worn, so he cannot base his decision on that. He decides", - "<|begin_of_text|>Save tonight and fight the break of dawn / come tomorrow, tomorrow I'll be gone\nI'm not sure if I've ever mentioned this before, but I'm a huge fan of the Swedish pop group Eagle-Eye Cherry. I've been a fan since I first heard his song \"Save Tonight\" on the radio in 1997. I was in high school at the time, and I remember thinking that it was the most beautiful song I had ever heard. I still think it's a beautiful song, and I still love it. I also love his other songs, such as \"Falling in Love Again,\" \"Are You Still Having Fun?,\" and \"Long Way Around", + "<|begin_of_text|>Two roads diverged in a yellow wood, and I took the one less traveled by, and that has made all the difference. -Robert Frost\nI have always been a fan of Robert Frost\u2019s poetry. I love the way he uses nature to express his thoughts and feelings. This particular poem is one of my favorites because it speaks to me on a personal level.\nI have always been a bit of a rebel. I never liked following the crowd or doing what was expected of me. I always wanted to forge my own path and do things my own way. This poem speaks to that part of me.\nThe road less traveled is often the more difficult road to take. It is", + "<|begin_of_text|>Save tonight and fight the break of dawn / come tomorrow, tomorrow I'll be gone\nI'm not sure if I'm going to be able to sleep tonight. I'm not sure if I want to. I'm not sure if I want to be awake, either. I'm not sure if I want to be alive. I'm not sure if I want to be dead. I'm not sure if I want to be anything. I'm not sure if I want to be nothing. I'm not sure if I want to be here. I'm not sure if I want to be there. I'm not sure if I want to be anywhere. I'm not sure if I", "<|begin_of_text|>The first thousand digits of PI: 3.14159265358979323846 26433832795028841971 69399375105820974944 59230781640628620899 86280348253421170679 82148086513282306647 09384460955058223172 53594081284811174502 84102701938521105559 64462294895493038196 44288109756659334461 28475648233786783165 27120190914564856692 34603486104543266482 13393607260249141273 72458700660631558817 48815209209628292540", - "<|begin_of_text|>Thirty days hath September, April, June, and November. All the rest have thirty-one, except for February, which has twenty-eight. And if it\u2019s a leap year, then it has twenty-nine. This is a mnemonic device that helps us remember how many days are in each month. But what if we wanted to know how many days are in a year? We could use a similar mnemonic device: \u201cThirty days hath September, April, June, and November. All the rest have thirty-one, except for February, which has twenty-eight. And if it\u2019s a leap year, then it has twenty-nine. So there are 365 days in a year, except for", + "<|begin_of_text|>Thirty days hath September, April, June, and November. All the rest have thirty-one, except for February, which has twenty-eight. And if it\u2019s a leap year, then it has twenty-nine. This is the rhyme that I learned as a child to remember how many days are in each month. I don\u2019t know if it\u2019s still taught in schools, but it\u2019s a good way to remember. I\u2019m not sure why it\u2019s important to remember, but it\u2019s a good way to remember.\nThe rhyme is a mnemonic device that helps people remember the number of days in each month. It is a simple way to remember the number of days in each month, and", "<|begin_of_text|>If you want to live a happy life, tie it to a goal, not to people or things. - Albert Einstein\nIf you want to live a happy life, tie it to a goal, not to people or things.\nI have a lot of things to prove to myself. One is that I can live my life fearlessly. - Oprah Winfrey\nI have a lot of things to prove to myself. One is that I can live my life fearlessly.\nI have a lot of things to prove to myself. One is that I can live my life fearlessly. - Oprah Winfrey\nI have a lot of things to prove to myself. One is that", - "<|begin_of_text|>Ada Lovelace (10 December 1815 - 27 November 1852) was an English mathematician and writer, chiefly known for her work on Charles Babbage's early mechanical general-purpose computer, the Analytical Engine. Her notes on the engine include what is recognised as the first algorithm intended to be processed by a machine. Because of this, she is often regarded as the first computer programmer.\nAda Lovelace was born Augusta Ada Byron, the only legitimate child of the poet Lord Byron. She was the child of the short-lived marriage between the Romantic poet and Anne Isabella \"Annabella\" Milbanke. Byron separated from his wife a month after Ada was born and left England forever four months later, eventually dying of", + "<|begin_of_text|>Ada Lovelace (10 December 1815 - 27 November 1852) was an English mathematician and writer, chiefly known for her work on Charles Babbage's early mechanical general-purpose computer, the Analytical Engine. Her notes on the engine include what is recognised as the first algorithm intended to be processed by a machine. Because of this, she is often regarded as the first computer programmer.\nAda Lovelace was born Augusta Ada Byron, the only legitimate child of the poet Lord Byron and his wife Anne Isabella Byron. She was named after Byron's half-sister, Augusta Leigh, and was called \"Ada\" by Byron himself. On 16 January 1816, at Lord Byron's own insistence, Annabella", "<|begin_of_text|>Call me Ishmael. Some years ago\u2014never mind how long precisely\u2014having little or no money in my purse, and nothing particular to interest me on shore, I thought I would sail about a little and see the watery part of the world. It is a way I have of driving off the spleen and regulating the circulation. Whenever I find myself growing grim about the mouth; whenever it is a damp, drizzly November in my soul; whenever I find myself involuntarily pausing before coffin warehouses, and bringing up the rear of every funeral I meet; and especially whenever my hypos get such an upper hand of me, that it requires a strong moral principle", - "<|begin_of_text|>The true sign of intelligence is not knowledge but imagination. - Albert Einstein\n\"The true sign of intelligence is not knowledge but imagination.\"Albert Einstein\nThe True , The True Knowledge , The True Meaning , The True You , The True Light , The True Self , The True Church , The True God , The True Israel , The True Vine , The True Way , The True Word , The True Worship , The True Worshipers , The True , The True Believer , The True Church of God , The True God of Israel , The True Israel of God , The True Light of God , The True Meaning of Christmas , The True Meaning of Easter , The True Meaning of Life , The", + "<|begin_of_text|>The true sign of intelligence is not knowledge but imagination. - Albert Einstein\n\"The true sign of intelligence is not knowledge but imagination.\"Albert Einstein\nThe True , The True Knowledge , The True Meaning , The True You , The True Light , The True Self , The True Church , The True Gospel , The True Israel , The True Vine , The True Way , The True Word , The True Worship , The True Worshipers , The True , The True Believer , The True Church of God , The True God , The True Gospel of Jesus Christ , The True Israel of God , The True Light of God , The True Meaning of Christmas , The True Meaning of Easter , The True", "<|begin_of_text|>Consider the sequence of prime numbers: 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, ", "<|begin_of_text|>Fibonacci sequence unfurls like a mathematical nautilus: 0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610, 987, 1597, 2584, 4181, 6765, 10946, 17711, 28657, 46368, 75025, 121393, 196418, 317811, 514229, 832040, 1346269, 2178309, 3524578, 5702887, 9227465, 14930352, 24157817,", - "<|begin_of_text|>Once upon a time, in a land full of dragons, there lived a young girl named Lily. Lily was a brave and adventurous girl who loved to explore the world around her. She was always curious about the dragons that lived in the land and wanted to learn more about them.\nOne day, Lily decided to go on a journey to find out more about the dragons. She packed her bag and set off on her adventure. Along the way, she met many interesting people and creatures, but none of them could tell her anything about the dragons.\nFinally, Lily reached a village where she heard stories of a wise old man who knew everything about the dragons. She went to see him and asked him", - "<|begin_of_text|>A duck walks into a store and asks the clerk, \"Do you have any grapes?\"\nThe clerk says, \"No, we don't have any grapes.\"\nThe next day, the duck walks into the store and asks the clerk, \"Do you have any grapes?\"\nThe clerk says, \"No, we don't have any grapes. If you come in here again and ask for grapes, I'm going to nail your feet to the floor.\"\nThe next day, the duck walks into the store and asks the clerk, \"Do you have any nails?\"\nThe clerk says, \"No, we don't have any nails.\"\nThe duck says, \"Do you have any grapes?\"", - "<|begin_of_text|>I heard there was a secret chord\nThat David played, and it pleased the Lord\nBut you don't really care for music, do you?\nIt goes like this, the fourth, the fifth\nThe minor fall, the major lift\nYour faith was strong but you needed proof\nYou saw her bathing on the roof\nHer beauty and the moonlight overthrew you\nShe tied you to a kitchen chair\nShe broke your throne, and she cut your hair\nAnd from your lips she drew the Hallelujah\nI've seen this room and I've walked this floor\nI used to live alone before I knew you\nI've seen your flag", + "<|begin_of_text|>Once upon a time, in a land full of dragons, there was a young girl named Lily. Lily was a brave and adventurous girl who loved to explore the world around her. She was always curious about the dragons that lived in the land and wanted to learn more about them.\nOne day, Lily decided to go on a journey to find out more about the dragons. She packed her bag and set off on her adventure. Along the way, she met many interesting people and animals, but none of them could tell her anything about the dragons.\nFinally, Lily reached a village where she met an old man who told her stories about the dragons. He told her that the dragons were powerful creatures that", + "<|begin_of_text|>A duck walks into a store and asks the clerk, \"Do you have any grapes?\"\nThe clerk says, \"No, we don't have any grapes.\"\nThe next day, the duck walks into the store and asks the clerk, \"Do you have any grapes?\"\nThe clerk says, \"No, we don't have any grapes.\"\nThe next day, the duck walks into the store and asks the clerk, \"Do you have any grapes?\"\nThe clerk says, \"No, we don't have any grapes.\"\nThe next day, the duck walks into the store and asks the clerk, \"Do you have any grapes?\"\nThe clerk says, \"No, we don't have", + "<|begin_of_text|>I heard there was a secret chord\nThat David played, and it pleased the Lord\nBut you don't really care for music, do you?\nIt goes like this, the fourth, the fifth\nThe minor fall, the major lift\nYour faith was strong but you needed proof\nHer beauty in the moonlight overthrew you\nShe tied you to a kitchen chair\nShe broke your throne, and she cut your hair\nAnd from your lips she drew the Hallelujah\nI've seen this room and I've walked this floor\nI used to live alone before I knew you\nI've seen your flag on the marble arch\nLove is not", "<|begin_of_text|>It is a truth universally acknowledged, that a single man in possession of a good fortune, must be in want of a wife. However little known the feelings or views of such a man may be on his first entering a neighbourhood, this truth is so well fixed in the minds of the surrounding families, that he is considered the rightful property of some one or other of their daughters. \"My dear Mr. Bennet,\" said his lady to him one day, \"have you heard that Netherfield Park is let at last?\" Mr. Bennet replied that he had not. \"But it is,\" returned she; \"for Mrs. Long has just been here, and she told me all", - "<|begin_of_text|>Shakespeare, William (bapt. 26 April 1564 - 23 April 1616) was an English poet and playwright, widely regarded as the greatest writer in the English language and the world's pre-eminent dramatist. He is often called England's national poet and the \"Bard of Avon\". His surviving works, including some collaborations, consist of 38 plays, 154 sonnets, two long narrative poems, and several other poems. His plays have been translated into every major living language and are performed more often than those of any other playwright.\nShakespeare was born and brought up in Stratford-upon-Avon. At the age of 18, he married Anne Hathaway, who bore him three children: Sus", - "<|begin_of_text|>The quality of mercy is not strained. It droppeth as the gentle rain from heaven upon the place beneath. It is twice blest: It blesseth him that gives and him that takes. 'Tis mightiest in the mightiest; it becomes the throned monarch better than his crown. His scepter shows the force of temporal power, the attribute to awe and majesty wherein doth sit the dread and fear of kings; but mercy is above this sceptered sway. It is enthroned in the hearts of kings; it is an attribute to God himself; and earthly power doth then show likest God's when mercy seasons justice", - "<|begin_of_text|>The Deliverator belongs to an elite order, a hallow subcategory. He's got esprit up to here. Right now he is preparing to deliver a pizza within twenty-three minutes. He's going to open the door of his car, hump a box into the kitchen of the address, and scram. Here's how long that ought to take: the door of his car has a streetlight shining on it, so it won't be a problem to see the latch. Opening it and swinging it open will take a second or two. Same thing with swinging it closed after he gets out, always assuming that it doesn't stick on his sweats. The only wild card is the car", + "<|begin_of_text|>Shakespeare, William (bapt. 26 April 1564 - 23 April 1616) was an English poet and playwright, widely regarded as the greatest writer in the English language and the world's pre-eminent dramatist. He is often called England's national poet and the \"Bard of Avon\" (or simply \"The Bard\"). His surviving works consist of 38 plays, 154 sonnets, two long narrative poems, and several other poems. His plays have been translated into every major living language, and are performed more often than those of any other playwright. Shakespeare was born and raised in Stratford-upon-Avon. At the age of 18 he married Anne Hathaway, who bore him three children: Sus", + "<|begin_of_text|>The quality of mercy is not strained. It droppeth as the gentle rain from heaven upon the place beneath. It is twice blest. It blesseth him that gives and him that takes. 'Tis mightiest in the mightiest. It becomes the throned monarch better than his crown. His scepter shows the force of temporal power, the attribute to awe and majesty wherein doth sit the dread and fear of kings. But mercy is above this sceptered sway. It is enthroned in the hearts of kings. It is an attribute to God himself. And earthly power doth then show likest God's when mercy seasons justice", + "<|begin_of_text|>The Deliverator belongs to an elite order, a hallow subcategory. He's got esprit up to here. Right now he is preparing to deliver a pizza within twenty-three minutes. He's going to open the door, bend over, and pick the pizza up. He knows that when he does, he is going to see a pair of young, playful, and extraordinarily beautiful Asian-fusion women wearing expensive, revealing black underwear. He knows this because they will be photographed in this condition, and this photograph will be tacked up in the office, at CosaNostra Pizza #3569, for the rest of his life, right next to the other eleven thousand photographs of beautiful young", "<|begin_of_text|>Counting in binary: 0000, 0001, 0010, 0011, 0100, 0101, 0110, 0111, 1000, 1001, 1010, 1011, 1100, 1101, 1110, 1111, 10000, 10001, 10010, 10011, 10100, 10101, 10110, 10111, 11000, 11001, 11010, 11011, 11100, 11101, 11110, 11111, 100000, 100001, 100010, 100011," ] From 399c7440fb1c25b1b7b013b92c0c1e54aa6eee44 Mon Sep 17 00:00:00 2001 From: Mouliraj Elamurugan Date: Wed, 9 Oct 2024 15:27:34 +0530 Subject: [PATCH 31/58] #13373: PyTorch Tracing Sweeps - Eltwise set 3 (#13435) * #13373: Add sweeps for relu_pytorch2 * #13373: Add sweep test for sin_pytorch2 * #13373: Add sweep test for maximum_pytorch2 * #13373: Add sweep test for minimum_pytorch2 * #13373: Add sweep test for tanh_pytorch2 * #13373: Add sweep test for silu_pytorch2 * #13373: Add sweep test for rsub_pytorch2 * #13373: Add sweep test for sigmoid_pytorch2 * #13373: Add sweep test for rsqrt_pytorch2 * #13373: Add sweep test for where_pytorch2 * #13373: Add sweep test for tril_pytorch2 * #13373: Update files * #13373: Add sweep for pow_scalar_pytorch2, pow_tensor_pytorch2 * #13373: Update golden functions --- .github/workflows/ttnn-run-sweeps.yaml | 14 + .../binary/maximum/maximum_pytorch2.py | 96 ++++ .../binary/minimum/minimum_pytorch2.py | 98 ++++ .../composite/binary/pow/pow_pytorch2.py | 104 ++++ .../binary/pow/pow_scalar_pytorch2.py | 82 +++ .../binary/pow/pow_tensor_pytorch2.py | 92 ++++ .../eltwise/ternary/where/where_pytorch2.py | 122 +++++ .../eltwise/unary/relu/relu_pytorch2.py | 504 ++++++++++++++++++ .../eltwise/unary/rsqrt/rsqrt_pytorch2.py | 82 +++ .../eltwise/unary/rsub/rsub_pytorch2.py | 127 +++++ .../eltwise/unary/sigmoid/sigmoid_pytorch2.py | 131 +++++ .../eltwise/unary/silu/silu_pytorch2.py | 101 ++++ .../sweeps/eltwise/unary/sin/sin_pytorch2.py | 76 +++ .../eltwise/unary/tanh/tanh_pytorch2.py | 92 ++++ .../eltwise/unary/tril/tril_pytorch2.py | 75 +++ 15 files changed, 1796 insertions(+) create mode 100644 tests/sweep_framework/sweeps/eltwise/composite/binary/maximum/maximum_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/composite/binary/minimum/minimum_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_scalar_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_tensor_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/ternary/where/where_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/relu/relu_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/rsqrt/rsqrt_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/rsub/rsub_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/sigmoid/sigmoid_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/silu/silu_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/sin/sin_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/tanh/tanh_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/tril/tril_pytorch2.py diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml index 3c40c3a2518..e15b7880d44 100644 --- a/.github/workflows/ttnn-run-sweeps.yaml +++ b/.github/workflows/ttnn-run-sweeps.yaml @@ -14,13 +14,18 @@ on: - ccl.line_all_gather - ccl.all_gather_n300 - eltwise.unary.relu.relu + - eltwise.unary.relu.relu_pytorch2 - eltwise.unary.gelu.gelu - eltwise.unary.cos.cos - eltwise.unary.sin.sin + - eltwise.unary.sin.sin_pytorch2 + - eltwise.unary.tril.tril_pytorch2 - eltwise.unary.clamp.clamp - eltwise.unary.clip.clip - eltwise.unary.cbrt.cbrt - eltwise.unary.rsub.rsub + - eltwise.unary.rsub.rsub_pytorch2 + - eltwise.unary.rsqrt.rsqrt_pytorch2 - eltwise.unary.rdiv.rdiv - eltwise.unary.frac.frac - eltwise.unary.ceil.ceil @@ -33,6 +38,7 @@ on: - eltwise.unary.exp2.exp2 - eltwise.unary.expm1.expm1 - eltwise.unary.tanh.tanh + - eltwise.unary.tanh.tanh_pytorch2 - eltwise.unary.sign.sign - eltwise.unary.rad2deg.rad2deg - eltwise.unary.deg2rad.deg2rad @@ -55,8 +61,10 @@ on: - eltwise.unary.erfinv.erfinv - eltwise.unary.i0.i0 - eltwise.unary.silu.silu + - eltwise.unary.silu.silu_pytorch2 - eltwise.unary.glu.glu - eltwise.unary.sigmoid.sigmoid + - eltwise.unary.sigmoid.sigmoid_pytorch2 - eltwise.unary.sigmoid_accurate.sigmoid_accurate - eltwise.unary.tril.tril - eltwise.unary.triu.triu @@ -117,12 +125,18 @@ on: - eltwise.composite.binary.addalpha.addalpha - eltwise.composite.binary.subalpha.subalpha - eltwise.composite.binary.minimum.minimum + - eltwise.composite.binary.minimum.minimum_pytorch2 - eltwise.composite.binary.maximum.maximum + - eltwise.composite.binary.maximum.maximum_pytorch2 + - eltwise.composite.binary.pow.pow_pytorch2 + - eltwise.composite.binary.pow.pow_scalar_pytorch2 + - eltwise.composite.binary.pow.pow_tensor_pytorch2 - eltwise.ternary.addcmul.addcmul - eltwise.ternary.addcdiv.addcdiv - eltwise.ternary.mac.mac - eltwise.ternary.lerp - eltwise.ternary.where.where + - eltwise.ternary.where.where_pytorch2 - matmul.full.matmul_default_block_sharded - matmul.full.matmul_default_height_sharded - matmul.full.matmul_default_interleaved diff --git a/tests/sweep_framework/sweeps/eltwise/composite/binary/maximum/maximum_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/composite/binary/maximum/maximum_pytorch2.py new file mode 100644 index 00000000000..55401a3e5ff --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/composite/binary/maximum/maximum_pytorch2.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + {"shape1": [1, 16, 1, 60], "shape2": []}, + # {"shape1": [1,16,s10+1], "shape2": []}, + {"shape1": [1, 16, 19, 19], "shape2": []}, + {"shape1": [1, 16, 59, 59], "shape2": []}, + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_b_dtype, + input_a_layout, + input_b_layout, + input_a_memory_config, + input_b_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape["shape1"]) + + torch_input_tensor_b = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype + )(input_shape["shape2"]) + + golden_function = ttnn.get_golden_function(ttnn.maximum) + torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_b_dtype, + layout=input_b_layout, + device=device, + memory_config=input_b_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.maximum(input_tensor_a, input_tensor_b, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/composite/binary/minimum/minimum_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/composite/binary/minimum/minimum_pytorch2.py new file mode 100644 index 00000000000..22b185e425f --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/composite/binary/minimum/minimum_pytorch2.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1], + [10, 10], + [15, 15], + [17, 17], + # [s0+1, s0+1], + [2, 2], + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_b_dtype, + input_a_layout, + input_b_layout, + input_a_memory_config, + input_b_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + torch_input_tensor_b = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.minimum) + torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_b_dtype, + layout=input_b_layout, + device=device, + memory_config=input_b_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.minimum(input_tensor_a, input_tensor_b, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_pytorch2.py new file mode 100644 index 00000000000..3f25722d19f --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_pytorch2.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + {"shape": [1, 1, 1024], "exponent": 2}, + {"shape": [1, 1, 1024], "exponent": 3.0}, + {"shape": [1, 1, 3072], "exponent": 3.0}, + {"shape": [1, 1, 4096], "exponent": 3.0}, + {"shape": [1, 1, 512], "exponent": 2}, + {"shape": [1, 1, 768], "exponent": 2}, + {"shape": [1, 10, 1024], "exponent": 2}, + {"shape": [1, 10, 512], "exponent": 2}, + {"shape": [1, 10, 768], "exponent": 2}, + {"shape": [1, 12, 3072], "exponent": 3.0}, + {"shape": [1, 14, 3072], "exponent": 3.0}, + {"shape": [1, 15, 1024], "exponent": 3.0}, + {"shape": [1, 15, 512], "exponent": 2}, + {"shape": [1, 3, 16, 16, 2], "exponent": 2}, + {"shape": [1, 3, 32, 32, 2], "exponent": 2}, + {"shape": [1, 3, 64, 64, 2], "exponent": 2}, + {"shape": [1, 45, 3072], "exponent": 3.0}, + {"shape": [1, 5, 4096], "exponent": 3.0}, + {"shape": [1, 7, 3072], "exponent": 3.0}, + {"shape": [1, 9, 128], "exponent": 3.0}, + {"shape": [1, 9, 16384], "exponent": 3.0}, + {"shape": [1, 9, 3072], "exponent": 3.0}, + {"shape": [1, 9, 4096], "exponent": 3.0}, + {"shape": [1, 9, 8192], "exponent": 3.0}, + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_b_dtype, + input_a_layout, + input_b_layout, + input_a_memory_config, + input_b_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape["shape"]) + + value = input_shape["exponent"] + torch_output_tensor = torch.pow(torch_input_tensor_a, value) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.pow(input_tensor_a, exponent=value, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_scalar_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_scalar_pytorch2.py new file mode 100644 index 00000000000..8dfb1bd1dc8 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_scalar_pytorch2.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + {"value": 10000, "shape": [128]}, + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_b_dtype, + input_a_layout, + input_b_layout, + input_a_memory_config, + input_b_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape["shape"]) + + value = input_shape["value"] + golden_function = ttnn.get_golden_function(ttnn.pow) + torch_output_tensor = golden_function(torch_input_tensor_a, value) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.pow(value, exponent=input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_tensor_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_tensor_pytorch2.py new file mode 100644 index 00000000000..911d33e7df5 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/composite/binary/pow/pow_tensor_pytorch2.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + {"shape1": [], "shape2": [16]}, + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_b_dtype, + input_a_layout, + input_b_layout, + input_a_memory_config, + input_b_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape["shape1"]) + + torch_input_tensor_b = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype + )(input_shape["shape2"]) + + torch_output_tensor = torch.pow(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_b_dtype, + layout=input_b_layout, + device=device, + memory_config=input_b_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.pow(input_tensor_a, input_tensor_b, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/ternary/where/where_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/ternary/where/where_pytorch2.py new file mode 100644 index 00000000000..720b4e1d33d --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/ternary/where/where_pytorch2.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt, gen_bin + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "where": { + "input_shape": [ + {"shape1": [1, 1, 1, 46], "shape2": [1, 12, 1, 46], "shape3": []}, + {"shape1": [1, 1, 1, 6], "shape2": [1, 16, 1, 6], "shape3": []}, + # {"shape1": [1, 1, 1, "s10 + 1"], "shape2": [1, 12, 1, "s10 + 1"], "shape3": []}, + # {"shape1": [1, 1, 1, "s10 + 1"], "shape2": [1, 16, 1, "s10 + 1"], "shape3": []}, + {"shape1": [1, 1, 256], "shape2": [1, 1, 256], "shape3": []}, + {"shape1": [1, 1, 45, 45], "shape2": [1, 12, 45, 45], "shape3": []}, + {"shape1": [1, 1, 5, 5], "shape2": [1, 16, 5, 5], "shape3": []}, + {"shape1": [1, 1, 7, 7], "shape2": [1, 12, 7, 7], "shape3": []}, + {"shape1": [1, 1], "shape2": [1, 1], "shape3": [1, 1]}, + # {"shape1": [1, "s0", 256], "shape2": [1, "s0", 256], "shape3": []}, + {"shape1": [10, 10], "shape2": [10, 10], "shape3": [10, 10]}, + {"shape1": [15, 15], "shape2": [15, 15], "shape3": [15, 15]}, + {"shape1": [17, 17], "shape2": [17, 17], "shape3": [17, 17]}, + {"shape1": [2, 2], "shape2": [2, 2], "shape3": [2, 2]}, + # {"shape1": ["s0 + 1", "s0 + 1"], "shape2": ["s0 + 1", "s0 + 1"], "shape3": []}, + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_b_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_c_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_c_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_c_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_b_dtype, + input_c_dtype, + input_a_layout, + input_b_layout, + input_c_layout, + input_a_memory_config, + input_b_memory_config, + input_c_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt(gen_bin, input_a_dtype)(input_shape["shape1"]) + torch_input_tensor_b = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype + )(input_shape["shape2"]) + torch_input_tensor_c = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_c_dtype + )(input_shape["shape3"]) + + golden_function = ttnn.get_golden_function(ttnn.where) + torch_output_tensor = golden_function(torch_input_tensor_a > 0, torch_input_tensor_b, torch_input_tensor_c) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_b_dtype, + layout=input_b_layout, + device=device, + memory_config=input_b_memory_config, + ) + + input_tensor_c = ttnn.from_torch( + torch_input_tensor_c, + dtype=input_c_dtype, + layout=input_c_layout, + device=device, + memory_config=input_b_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.where(input_tensor_a, input_tensor_b, input_tensor_c, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/relu/relu_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/relu/relu_pytorch2.py new file mode 100644 index 00000000000..42c81c6886c --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/relu/relu_pytorch2.py @@ -0,0 +1,504 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1, 2048], + [1, 1, 256], + [1, 1, 3072], + [1, 1, 4096], + [1, 1, 768], + [1, 10, 2048], + [1, 10, 3072], + [1, 10, 4096], + [1, 100, 14, 14], + [1, 100, 192], + [1, 1008, 14, 14], + [1, 1008, 7, 7], + [1, 1024, 10, 10], + [1, 1024, 14, 14], + [1, 1024, 19, 19], + [1, 1024, 28, 28], + [1, 1024, 45, 80], + [1, 1024, 50, 68], + [1, 1024, 7, 7], + [1, 104, 28, 28], + [1, 104, 56, 56], + [1, 1056, 14, 14], + [1, 1056, 48, 48], + [1, 1056, 7, 7], + [1, 1056, 96, 96], + [1, 1088, 14, 14], + [1, 1088, 7, 7], + [1, 110, 1, 1], + [1, 1104, 14, 14], + [1, 1104, 7, 7], + [1, 112, 1, 1], + [1, 112, 14, 14], + [1, 1120, 14, 14], + [1, 1120, 7, 7], + [1, 1152, 14, 14], + [1, 1152, 7, 7], + [1, 1184, 14, 14], + [1, 1184, 7, 7], + [1, 12, 1, 1], + [1, 120, 1, 1], + [1, 120, 28, 28], + [1, 120, 40, 40], + [1, 120, 56, 56], + [1, 1200, 14, 14], + [1, 1200, 7, 7], + [1, 1216, 14, 14], + [1, 1216, 7, 7], + [1, 1232, 14, 14], + [1, 1232, 28, 28], + [1, 1248, 14, 14], + [1, 1248, 7, 7], + [1, 128, 10, 10], + [1, 128, 100, 136], + [1, 128, 112, 112], + [1, 128, 14, 14], + [1, 128, 150, 150], + [1, 128, 17, 17], + [1, 128, 180, 320], + [1, 128, 200, 272], + [1, 128, 28, 28], + [1, 128, 3, 3], + [1, 128, 5, 5], + [1, 128, 56, 56], + [1, 128, 64, 64], + [1, 128, 7, 7], + [1, 128, 75, 75], + [1, 128, 90, 160], + [1, 1280, 1, 1], + [1, 1280, 14, 14], + [1, 1280, 7, 7], + [1, 128], + [1, 1296, 14, 14], + [1, 1296, 7, 7], + [1, 12], + [1, 1312, 14, 14], + [1, 1312, 7, 7], + [1, 132, 1, 1], + [1, 1344, 14, 14], + [1, 1344, 28, 28], + [1, 1344, 7, 7], + [1, 1376, 14, 14], + [1, 1376, 7, 7], + [1, 1392, 14, 14], + [1, 1392, 28, 28], + [1, 1392, 7, 7], + [1, 1408, 14, 14], + [1, 1408, 7, 7], + [1, 144, 1, 1], + [1, 144, 14, 14], + [1, 144, 28, 28], + [1, 144, 56, 56], + [1, 144, 7, 7], + [1, 1440, 14, 14], + [1, 1440, 7, 7], + [1, 1472, 14, 14], + [1, 1472, 7, 7], + [1, 1488, 14, 14], + [1, 1488, 7, 7], + [1, 15, 15, 512], + [1, 1504, 14, 14], + [1, 1504, 7, 7], + [1, 1512, 14, 14], + [1, 1512, 7, 7], + [1, 1536, 10, 10], + [1, 1536, 14, 14], + [1, 1536, 7, 7], + [1, 1568, 14, 14], + [1, 1568, 7, 7], + [1, 1584, 14, 14], + [1, 1584, 7, 7], + [1, 16, 1, 1], + [1, 16, 112, 112], + [1, 16, 14, 14], + [1, 16, 160, 160], + [1, 16, 224, 224], + [1, 16, 28, 28], + [1, 16, 56, 56], + [1, 160, 14, 14], + [1, 160, 28, 28], + [1, 160, 56, 56], + [1, 160, 7, 7], + [1, 1600, 14, 14], + [1, 1600, 7, 7], + [1, 1632, 14, 14], + [1, 1632, 7, 7], + [1, 1664, 14, 14], + [1, 1664, 7, 7], + [1, 168, 1, 1], + [1, 168, 28, 28], + [1, 168, 56, 56], + [1, 1680, 14, 14], + [1, 1680, 7, 7], + [1, 1696, 14, 14], + [1, 1696, 7, 7], + [1, 1728, 14, 14], + [1, 1728, 7, 7], + [1, 174, 1, 1], + [1, 1760, 14, 14], + [1, 1760, 7, 7], + [1, 1776, 14, 14], + [1, 1776, 7, 7], + [1, 1792, 14, 14], + [1, 1792, 7, 7], + [1, 18, 1, 1], + [1, 18, 14, 14], + [1, 18, 28, 28], + [1, 18, 56, 56], + [1, 1824, 14, 14], + [1, 1824, 7, 7], + [1, 1856, 7, 7], + [1, 1872, 14, 14], + [1, 1872, 7, 7], + [1, 1888, 7, 7], + [1, 192, 14, 14], + [1, 192, 17, 17], + [1, 192, 28, 28], + [1, 192, 35, 35], + [1, 192, 56, 56], + [1, 192, 7, 7], + [1, 192, 8, 8], + [1, 1920, 14, 14], + [1, 1920, 7, 7], + [1, 196, 1, 1], + [1, 1968, 14, 14], + [1, 1968, 7, 7], + [1, 20, 1, 1], + [1, 2016, 14, 14], + [1, 2016, 7, 7], + [1, 2048, 10, 10], + [1, 2048, 14, 14], + [1, 2048, 23, 40], + [1, 2048, 25, 34], + [1, 2048, 7, 7], + [1, 2064, 14, 14], + [1, 2064, 7, 7], + [1, 208, 14, 14], + [1, 208, 28, 28], + [1, 2112, 14, 14], + [1, 2112, 7, 7], + [1, 216, 28, 28], + [1, 216, 56, 56], + [1, 2160, 7, 7], + [1, 2208, 7, 7], + [1, 222, 1, 1], + [1, 224, 1, 1], + [1, 224, 112, 112], + [1, 224, 14, 14], + [1, 224, 17, 17], + [1, 224, 28, 28], + [1, 224, 35, 35], + [1, 224, 56, 56], + [1, 224, 7, 7], + [1, 232, 112, 112], + [1, 232, 56, 56], + [1, 24, 1, 1], + [1, 24, 112, 112], + [1, 24, 14, 14], + [1, 240, 1, 1], + [1, 240, 14, 14], + [1, 240, 28, 28], + [1, 240, 56, 56], + [1, 2520, 14, 14], + [1, 2520, 7, 7], + [1, 256, 1, 1], + [1, 256, 100, 136], + [1, 256, 112, 112], + [1, 256, 128, 128], + [1, 256, 13, 17], + [1, 256, 14, 14], + [1, 256, 17, 17], + [1, 256, 180, 320], + [1, 256, 19, 19], + [1, 256, 200, 272], + [1, 256, 25, 34], + [1, 256, 28, 28], + [1, 256, 3, 3], + [1, 256, 32, 32], + [1, 256, 38, 38], + [1, 256, 45, 80], + [1, 256, 5, 5], + [1, 256, 50, 68], + [1, 256, 56, 56], + [1, 256, 7, 7], + [1, 256, 7, 9], + [1, 256, 75, 75], + [1, 256, 8, 8], + [1, 256, 90, 160], + [1, 26, 1, 1], + [1, 264, 1, 1], + [1, 288, 14, 14], + [1, 288, 28, 28], + [1, 288, 56, 56], + [1, 2904, 24, 24], + [1, 2904, 48, 48], + [1, 30, 1, 1], + [1, 3024, 14, 14], + [1, 3024, 7, 7], + [1, 308, 1, 1], + [1, 32, 1, 1], + [1, 32, 112, 112], + [1, 32, 120, 160], + [1, 32, 14, 14], + [1, 32, 147, 147], + [1, 32, 149, 149], + [1, 32, 150, 150], + [1, 32, 192, 192], + [1, 32, 256, 256], + [1, 32, 26, 26], + [1, 32, 28, 28], + [1, 32, 30, 40], + [1, 32, 56, 56], + [1, 32, 60, 80], + [1, 32, 7, 7], + [1, 320, 14, 14], + [1, 320, 17, 17], + [1, 320, 28, 28], + [1, 320, 7, 7], + [1, 320, 8, 8], + [1, 336, 112, 112], + [1, 336, 14, 14], + [1, 336, 28, 28], + [1, 336, 56, 56], + [1, 348, 1, 1], + [1, 352, 14, 14], + [1, 352, 28, 28], + [1, 36, 1, 1], + [1, 36, 14, 14], + [1, 36, 28, 28], + [1, 36, 56, 56], + [1, 3712, 14, 14], + [1, 3712, 7, 7], + [1, 384, 14, 14], + [1, 384, 17, 17], + [1, 384, 28, 28], + [1, 384, 56, 56], + [1, 384, 7, 7], + [1, 384, 8, 8], + [1, 4, 14, 14], + [1, 40, 1, 1], + [1, 400, 14, 14], + [1, 400, 7, 7], + [1, 408, 14, 14], + [1, 408, 28, 28], + [1, 4096], + [1, 416, 14, 14], + [1, 416, 28, 28], + [1, 432, 14, 14], + [1, 432, 28, 28], + [1, 440, 14, 14], + [1, 440, 7, 7], + [1, 448, 14, 14], + [1, 448, 28, 28], + [1, 448, 56, 56], + [1, 448, 8, 8], + [1, 48, 112, 112], + [1, 48, 14, 14], + [1, 48, 56, 56], + [1, 48, 7, 7], + [1, 480, 14, 14], + [1, 480, 28, 28], + [1, 480, 7, 7], + [1, 512, 10, 10], + [1, 512, 100, 136], + [1, 512, 14, 14], + [1, 512, 16, 16], + [1, 512, 19, 19], + [1, 512, 23, 40], + [1, 512, 25, 34], + [1, 512, 28, 28], + [1, 512, 38, 38], + [1, 512, 45, 80], + [1, 512, 50, 68], + [1, 512, 56, 56], + [1, 512, 7, 7], + [1, 512, 8, 8], + [1, 512, 90, 160], + [1, 52, 1, 1], + [1, 528, 14, 14], + [1, 528, 192, 192], + [1, 528, 28, 28], + [1, 528, 96, 96], + [1, 54, 1, 1], + [1, 544, 14, 14], + [1, 544, 7, 7], + [1, 56, 1, 1], + [1, 576, 14, 14], + [1, 576, 28, 28], + [1, 576, 7, 7], + [1, 58, 1, 1], + [1, 60, 28, 28], + [1, 608, 14, 14], + [1, 608, 7, 7], + [1, 624, 14, 14], + [1, 624, 28, 28], + [1, 64, 1, 1], + [1, 64, 112, 112], + [1, 64, 120, 160], + [1, 64, 128, 128], + [1, 64, 14, 14], + [1, 64, 147, 147], + [1, 64, 150, 150], + [1, 64, 160, 160], + [1, 64, 180, 320], + [1, 64, 200, 272], + [1, 64, 224, 224], + [1, 64, 24, 24], + [1, 64, 28, 28], + [1, 64, 30, 40], + [1, 64, 300, 300], + [1, 64, 35, 35], + [1, 64, 360, 640], + [1, 64, 400, 544], + [1, 64, 480, 640], + [1, 64, 56, 56], + [1, 64, 60, 80], + [1, 64, 73, 73], + [1, 64, 80, 80], + [1, 640, 14, 14], + [1, 640, 7, 7], + [1, 64], + [1, 672, 14, 14], + [1, 672, 28, 28], + [1, 672, 56, 56], + [1, 672, 7, 7], + [1, 696, 28, 28], + [1, 696, 56, 56], + [1, 704, 14, 14], + [1, 704, 7, 7], + [1, 72, 1, 1], + [1, 72, 112, 112], + [1, 72, 14, 14], + [1, 72, 28, 28], + [1, 72, 40, 40], + [1, 72, 56, 56], + [1, 72, 80, 80], + [1, 720, 14, 14], + [1, 720, 28, 28], + [1, 726, 1, 1], + [1, 728, 19, 19], + [1, 728, 38, 38], + [1, 736, 14, 14], + [1, 736, 7, 7], + [1, 7392, 12, 12], + [1, 7392, 24, 24], + [1, 768, 14, 14], + [1, 768, 28, 28], + [1, 768, 7, 7], + [1, 784, 14, 14], + [1, 784, 7, 7], + [1, 8, 1, 1], + [1, 8, 112, 112], + [1, 80, 1, 1], + [1, 80, 112, 112], + [1, 80, 56, 56], + [1, 800, 14, 14], + [1, 800, 7, 7], + [1, 816, 14, 14], + [1, 832, 14, 14], + [1, 832, 7, 7], + [1, 84, 1, 1], + [1, 864, 14, 14], + [1, 864, 7, 7], + [1, 88, 28, 28], + [1, 888, 14, 14], + [1, 888, 7, 7], + [1, 896, 14, 14], + [1, 896, 28, 28], + [1, 896, 7, 7], + [1, 912, 14, 14], + [1, 912, 7, 7], + [1, 92, 14, 14], + [1, 928, 14, 14], + [1, 928, 7, 7], + [1, 96, 112, 112], + [1, 96, 14, 14], + [1, 96, 28, 28], + [1, 96, 35, 35], + [1, 96, 56, 56], + [1, 96, 71, 71], + [1, 96, 73, 73], + [1, 960, 14, 14], + [1, 960, 7, 7], + [1, 992, 14, 14], + [1, 992, 7, 7], + # [1, "s0", 256], + # [1, "s0", 768], + [100, 1, 2048], + [59, 4096], + [6, 1, 100, 256], + [920, 1, 2048], + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.relu) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.relu(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/rsqrt/rsqrt_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/rsqrt/rsqrt_pytorch2.py new file mode 100644 index 00000000000..ee0d03fba83 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/rsqrt/rsqrt_pytorch2.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1, 1], + [1, 10, 1], + [1, 1024, 1, 1], + [1, 128, 1, 1], + [1, 15, 1], + [1, 2048, 1, 1], + [1, 256, 1, 1], + [1, 512, 1, 1], + [1, 64, 1, 1], + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + golden_function = ttnn.get_golden_function(ttnn.rsqrt) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.rsqrt(input_tensor_a, fast_and_approximate_mode=True, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/rsub/rsub_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/rsub/rsub_pytorch2.py new file mode 100644 index 00000000000..0d5fd280dd8 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/rsub/rsub_pytorch2.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1, 1, 10], + [1, 1, 1, 12], + [1, 1, 1, 14], + [1, 1, 1, 15], + [1, 1, 1, 17], + [1, 1, 1, 1], + [1, 1, 1, 201], + [1, 1, 1, 2048], + [1, 1, 1, 24], + [1, 1, 1, 256], + [1, 1, 1, 25], + [1, 1, 1, 2], + [1, 1, 1, 42], + [1, 1, 1, 42], + [1, 1, 1, 46], + [1, 1, 1, 5], + [1, 1, 1, 60], + [1, 1, 1, 6], + [1, 1, 1, 7], + [1, 1, 1, 8], + [1, 1, 1, 9], + # [1, 1, 1, s0 + 1], + # [1, 1, 1, s0], + # [1, 1, 1, s10 + 1], + [1, 1, 19, 19], + [1, 1, 24, 24], + [1, 1, 32, 1], + [1, 1, 32, 1], + [1, 1, 32, 32], + [1, 1, 45, 45], + [1, 1, 59, 59], + [1, 192], + [1066], + [120, 1], + [128, 1], + [128], + [160], + [2, 1, 7, 7], + [240, 1], + [30, 1], + [300, 1], + [300], + [320, 1], + [320], + [40], + [480, 1], + [60, 1], + [640], + [800, 1], + [80], + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + factor = 1.0 + + golden_function = ttnn.get_golden_function(ttnn.rsub) + torch_output_tensor = golden_function(torch_input_tensor_a, factor) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.rsub(input_tensor_a, value=factor, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/sigmoid/sigmoid_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/sigmoid/sigmoid_pytorch2.py new file mode 100644 index 00000000000..45b085082a0 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/sigmoid/sigmoid_pytorch2.py @@ -0,0 +1,131 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1, 256, 256], + [1, 1, 480, 640], + [1, 100, 4], + [1, 104, 1, 1], + [1, 1056, 1, 1], + [1, 12, 64, 64], + [1, 120, 1, 1], + [1, 120, 14, 14], + [1, 1232, 1, 1], + [1, 1392, 1, 1], + [1, 144, 1, 1], + [1, 1512, 1, 1], + [1, 16, 64, 64], + [1, 184, 7, 7], + [1, 2, 120, 160], + [1, 2, 30, 40], + [1, 2, 60, 80], + [1, 200, 7, 7], + [1, 2016, 1, 1], + [1, 208, 1, 1], + [1, 216, 1, 1], + [1, 224, 1, 1], + [1, 232, 1, 1], + [1, 24, 64, 64], + [1, 240, 14, 14], + [1, 2904, 1, 1], + [1, 3, 16, 16, 85], + [1, 3, 32, 32, 85], + [1, 3, 64, 64, 85], + [1, 3, 64, 64], + [1, 3024, 1, 1], + [1, 32, 64, 64], + [1, 320, 1, 1], + [1, 336, 1, 1], + [1, 3712, 1, 1], + [1, 4, 64, 64], + [1, 440, 1, 1], + [1, 448, 1, 1], + [1, 48, 1, 1], + [1, 480, 7, 7], + [1, 50, 3072], + [1, 528, 1, 1], + [1, 576, 1, 1], + [1, 6, 64, 64], + [1, 64, 1, 1], + [1, 672, 7, 7], + [1, 696, 1, 1], + [1, 72, 1, 1], + [1, 72, 28, 28], + [1, 7392, 1, 1], + [1, 784, 1, 1], + [1, 8, 64, 64], + [1, 888, 1, 1], + [1, 896, 1, 1], + [1, 960, 3, 3], + [2, 7, 2048], + [6, 1, 100, 4], + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.sigmoid) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.sigmoid(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/silu/silu_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/silu/silu_pytorch2.py new file mode 100644 index 00000000000..3a6fe10f458 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/silu/silu_pytorch2.py @@ -0,0 +1,101 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + # [1, 128, ((s1 - 1)//2) + 1, ((s2 - 1)//2) + 1], + # [1, 128, s1, s2], + [1, 1280, 16, 16], + [1, 1280, 32, 32], + [1, 1280, 8, 8], + [1, 1280], + [1, 1920, 16, 16], + [1, 1920, 32, 32], + # [1, 256, ((s1 - 1)//2) + 1, ((s2 - 1)//2) + 1], + # [1, 256, s0, s1], + # [1, 256, s1, s2], + [1, 2560, 16, 16], + [1, 2560, 8, 8], + [1, 32, 256, 256], + # [1, 32, s0, s1], + [1, 320, 32, 32], + [1, 320, 64, 64], + # [1, 512, ((s1 - 1)//2) + 1, ((s2 - 1)//2) + 1], + # [1, 512, s1, s2], + # [1, 64, ((s1 - 1)//2) + 1, ((s2 - 1)//2) + 1], + # [1, 64, s0, s1], + # [1, 64, s1, s2], + [1, 640, 16, 16], + [1, 640, 32, 32], + [1, 640, 64, 64], + [1, 960, 32, 32], + [1, 960, 64, 64], + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float16), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.silu) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.silu(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/sin/sin_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/sin/sin_pytorch2.py new file mode 100644 index 00000000000..f09fc207b12 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/sin/sin_pytorch2.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 160], + [1, 23, 40, 64], + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=0, high=6.283185307179586, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.sin) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.sin(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/tanh/tanh_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/tanh/tanh_pytorch2.py new file mode 100644 index 00000000000..8d5c2746dc0 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/tanh/tanh_pytorch2.py @@ -0,0 +1,92 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1, 1024], + [1, 1, 24576], + [1, 1, 3072], + [1, 1, 4096], + [1, 12, 3072], + [1, 14, 3072], + [1, 15, 1024], + [1, 256, 96], + [1, 32, 6144], + [1, 45, 3072], + [1, 5, 4096], + [1, 7, 3072], + [1, 768], + [1, 9, 128], + [1, 9, 16384], + [1, 9, 3072], + [1, 9, 4096], + [1, 9, 8192], + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.tanh) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.tanh(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/tril/tril_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/tril/tril_pytorch2.py new file mode 100644 index 00000000000..82a90241b30 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/tril/tril_pytorch2.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [7, 7], + ], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=0, high=6.283185307179586, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.tril) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.tril(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] From e9cf0741da42548755083c09cae6843ecbfc3d87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=AA=20Tr=C6=B0=E1=BB=9Dng=20Giang?= <76864037+o2buzzle@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:59:23 +0700 Subject: [PATCH 32/58] #13317: revise moreh_adam (#13318) --- .../unit_tests/operations/test_moreh_adam.py | 23 ++++++++++- .../device/moreh_adam_device_operation.cpp | 3 +- .../device/moreh_adam_device_operation.hpp | 2 +- .../device/moreh_adam_program_factory.cpp | 3 +- .../moreh/moreh_adam/moreh_adam.cpp | 41 +++++++++++++++++++ .../moreh/moreh_adam/moreh_adam.hpp | 25 ++++++++++- 6 files changed, 91 insertions(+), 6 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_moreh_adam.py b/tests/ttnn/unit_tests/operations/test_moreh_adam.py index a065103d50c..a1a49aa1b1e 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_adam.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_adam.py @@ -8,7 +8,9 @@ import ttnn import pytest -from models.utility_functions import is_wormhole_b0, comp_allclose_and_pcc, comp_pcc, is_wormhole_b0 +from models.utility_functions import ( + comp_allclose_and_pcc, +) from loguru import logger from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import ( get_compute_kernel_options, @@ -141,3 +143,22 @@ def forward(self, x): logger.debug(f"Out passing (max_exp_avg_sq)={passing}") logger.debug(f"Output pcc={out}") assert passing + + +@pytest.mark.parametrize( + "params", + ( + # shape, lr, betas, eps, weight_decay, amsgrad, fp32_dest_acc_en + ([32, 32], 0.0, (0.9, 0.999), 1e-06, 0.0, True, True), + ([2, 2, 2, 2, 2, 2, 64, 64], 0.0, (0.9, 0.999), 1e-06, 0.0, False, False), + ), +) +def test_moreh_adam_enable_cache(params, device, use_program_cache): + for i in range(4): + shape, lr, betas, eps, weight_decay, amsgrad, fp32_dest_acc_en = params + if i % 2 == 1: + amsgrad = not amsgrad + + test_moreh_adam(shape, lr, betas, eps, weight_decay, amsgrad, fp32_dest_acc_en, device) + + assert device.num_program_cache_entries() == 2 diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp index 5691351bb67..47b2ad51983 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.cpp @@ -146,7 +146,8 @@ std::tuplearch(), compute_kernel_config, MathFidelity::HiFi4), + }, tensor_args_t{ param_in, grad, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.hpp index e3e79d49898..cdbc40dad91 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.hpp @@ -22,7 +22,7 @@ struct MorehAdamOperation { bool amsgrad = false; const MemoryConfig output_mem_config; - const std::optional compute_kernel_config; + const DeviceComputeKernelConfig compute_kernel_config; }; struct tensor_args_t { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp index d05423c6e49..abe130ec805 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/device/moreh_adam_program_factory.cpp @@ -36,8 +36,7 @@ MorehAdamOperation::ProgramFactory::cached_program_t MorehAdamOperation::Program auto step = operation_attributes.step; auto amsgrad = operation_attributes.amsgrad; - auto compute_kernel_config = - init_device_compute_kernel_config(param_in.device()->arch(), operation_attributes.compute_kernel_config); + auto compute_kernel_config = operation_attributes.compute_kernel_config; uint32_t num_tiles = param_in.volume() / tt::constants::TILE_HW; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp index 0fced419e5c..f6779d926c0 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.cpp @@ -5,6 +5,7 @@ #include "moreh_adam.hpp" #include "ttnn/operations/moreh/moreh_adam/device/moreh_adam_device_operation.hpp" +#include "ttnn/run_operation.hpp" namespace ttnn::operations::moreh::moreh_adam { std::vector> MorehAdam::invoke( @@ -46,4 +47,44 @@ std::vector> MorehAdam::invoke( memory_config, compute_kernel_config); } + +std::vector MorehAdam::create_async_output_tensors( + const std::vector& input_tensors, const std::vector>& optional_inputs) { + const auto& param_in = input_tensors.at(0); + const auto& grad = input_tensors.at(1); + const auto& exp_avg_in = input_tensors.at(2); + const auto& exp_avg_sq_in = input_tensors.at(3); + + const auto& max_exp_avg_sq_in = optional_inputs.at(0); + + return { + Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + Tensor(operation::get_workers_for_op_output({param_in, grad, exp_avg_in, exp_avg_sq_in}, {max_exp_avg_sq_in})), + }; +} + +std::vector MorehAdam::create_async_return_flag( + const Tensor& param_in, + const Tensor& grad, + const Tensor& exp_avg_in, + const Tensor& exp_avg_sq_in, + const std::optional lr, + const std::optional beta1, + const std::optional beta2, + const std::optional eps, + const std::optional weight_decay, + const std::optional step, + const std::optional amsgrad, + const std::optional max_exp_avg_sq_in, + const std::optional param_out, + const std::optional exp_avg_out, + const std::optional exp_avg_sq_out, + const std::optional max_exp_avg_sq_out, + const std::optional& memory_config, + const std::optional& compute_kernel_config) { + // First three are always true, last one depends on amsgrad + return {true, true, true, amsgrad.value_or(false)}; +} } // namespace ttnn::operations::moreh::moreh_adam diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.hpp index fc6ad1e9110..ed36c62ba7b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_adam/moreh_adam.hpp @@ -28,10 +28,33 @@ struct MorehAdam { const std::optional max_exp_avg_sq_out, const std::optional& memory_config, const std::optional& compute_kernel_config); + + static std::vector create_async_output_tensors( + const std::vector& input_tensors, const std::vector>& optional_inputs); + + static std::vector create_async_return_flag( + const Tensor& param_in, + const Tensor& grad, + const Tensor& exp_avg_in, + const Tensor& exp_avg_sq_in, + const std::optional lr, + const std::optional beta1, + const std::optional beta2, + const std::optional eps, + const std::optional weight_decay, + const std::optional step, + const std::optional amsgrad, + const std::optional max_exp_avg_sq_in, + const std::optional param_out, + const std::optional exp_avg_out, + const std::optional exp_avg_sq_out, + const std::optional max_exp_avg_sq_out, + const std::optional& memory_config, + const std::optional& compute_kernel_config); }; } // namespace ttnn::operations::moreh::moreh_adam namespace ttnn { constexpr auto moreh_adam = - ttnn::register_operation<"ttnn::moreh_adam", ttnn::operations::moreh::moreh_adam::MorehAdam>(); + ttnn::register_operation_with_auto_launch_op<"ttnn::moreh_adam", ttnn::operations::moreh::moreh_adam::MorehAdam>(); } From 01392c5b1be41cde0369b45d53cec94809b069bc Mon Sep 17 00:00:00 2001 From: o2buzzle <76864037+o2buzzle@users.noreply.github.com> Date: Tue, 24 Sep 2024 04:15:27 +0000 Subject: [PATCH 33/58] #13035: Revise moreh_sum --- docs/source/ttnn/ttnn/api.rst | 11 +++++++++++ tests/ttnn/profiling/ops_for_profiling.py | 4 ++++ .../device/moreh_int_sum_h_program_factory.cpp | 4 +--- .../device/moreh_int_sum_nc_program_factory.cpp | 4 +--- .../device/moreh_int_sum_w_program_factory.cpp | 4 +--- .../moreh_sum/device/moreh_sum_device_operation.cpp | 10 +++++++++- .../moreh_sum/device/moreh_sum_device_operation.hpp | 2 +- .../moreh_sum/device/moreh_sum_h_program_factory.cpp | 4 +--- .../moreh_sum/device/moreh_sum_nc_program_factory.cpp | 4 +--- .../moreh_sum/device/moreh_sum_w_program_factory.cpp | 4 +--- 10 files changed, 31 insertions(+), 20 deletions(-) diff --git a/docs/source/ttnn/ttnn/api.rst b/docs/source/ttnn/ttnn/api.rst index 12292183f9d..9acf162cc83 100644 --- a/docs/source/ttnn/ttnn/api.rst +++ b/docs/source/ttnn/ttnn/api.rst @@ -440,6 +440,17 @@ Normalization ttnn.layer_norm ttnn.rms_norm + +Moreh Operations +================ + +.. autosummary:: + :toctree: api + :nosignatures: + :template: function.rst + + ttnn.moreh_sum + Transformer =========== diff --git a/tests/ttnn/profiling/ops_for_profiling.py b/tests/ttnn/profiling/ops_for_profiling.py index 78cc71fc9f9..499df79b9cb 100644 --- a/tests/ttnn/profiling/ops_for_profiling.py +++ b/tests/ttnn/profiling/ops_for_profiling.py @@ -339,6 +339,10 @@ def primary_moreh_mean_backward(x, y): ttnn.operations.moreh.mean_backward(x, dim=[0], keepdim=True, input_grad=y) +def primary_moreh_sum(x): + ttnn.operations.moreh.sum(x, dim=[0]) + + def celu_bw(x, y): ttnn.celu_bw(x, y, alpha=1) diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_h_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_h_program_factory.cpp index 98905db038e..26d35744403 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_h_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_h_program_factory.cpp @@ -18,9 +18,7 @@ MorehSumOperation::MorehSumHIntFactory::cached_program_t MorehSumOperation::More auto output = output_tensor; auto output_mem_config = operation_attributes.output_mem_config; - const DeviceComputeKernelConfig& compute_kernel_config = init_device_compute_kernel_config( - input.device()->arch(), operation_attributes.compute_kernel_config, MathFidelity::HiFi4); - ; + const DeviceComputeKernelConfig& compute_kernel_config = operation_attributes.compute_kernel_config; tt::tt_metal::Device* device{input.device()}; tt::tt_metal::Program program{tt::tt_metal::CreateProgram()}; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_nc_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_nc_program_factory.cpp index f5e8d21c80d..905f84fb0fd 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_nc_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_nc_program_factory.cpp @@ -20,9 +20,7 @@ MorehSumOperation::MorehSumNCIntFactory::cached_program_t MorehSumOperation::Mor auto dim = operation_attributes.dim; auto output_mem_config = operation_attributes.output_mem_config; - const DeviceComputeKernelConfig &compute_kernel_config = init_device_compute_kernel_config( - input.device()->arch(), operation_attributes.compute_kernel_config, MathFidelity::HiFi4); - ; + const DeviceComputeKernelConfig &compute_kernel_config = operation_attributes.compute_kernel_config; tt::tt_metal::Device *device{input.device()}; tt::tt_metal::Program program{tt::tt_metal::CreateProgram()}; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_w_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_w_program_factory.cpp index 8e9bf232b7e..98323db7a63 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_w_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_int_sum_w_program_factory.cpp @@ -18,9 +18,7 @@ MorehSumOperation::MorehSumWIntFactory::cached_program_t MorehSumOperation::More auto output = output_tensor; auto output_mem_config = operation_attributes.output_mem_config; - const DeviceComputeKernelConfig& compute_kernel_config = init_device_compute_kernel_config( - input.device()->arch(), operation_attributes.compute_kernel_config, MathFidelity::HiFi4); - ; + const DeviceComputeKernelConfig& compute_kernel_config = operation_attributes.compute_kernel_config; tt::tt_metal::Device* device{input.device()}; tt::tt_metal::Program program{tt::tt_metal::CreateProgram()}; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.cpp index 15937ca2945..fe84ffc3d3c 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.cpp @@ -6,6 +6,7 @@ #include +#include "common/base_types.hpp" #include "tt_dnn/op_library/moreh_helper_functions.hpp" #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/types.hpp" @@ -140,6 +141,13 @@ std::tuple& output, const std::optional& output_mem_config, const std::optional& compute_kernel_config) { - return {{dim, keepdim, output_mem_config.value_or(input.memory_config()), compute_kernel_config}, {input, output}}; + return { + { + dim, + keepdim, + output_mem_config.value_or(input.memory_config()), + init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4), + }, + {input, output}}; } } // namespace ttnn::operations::moreh::moreh_sum diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.hpp index 06dbc6fe253..be1b8b82e3f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_device_operation.hpp @@ -38,7 +38,7 @@ struct MorehSumOperation { const bool keepdim; const MemoryConfig output_mem_config; - const std::optional compute_kernel_config; + const DeviceComputeKernelConfig compute_kernel_config; }; struct tensor_args_t { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp index 22bb26e37b8..b8e13b6a297 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp @@ -20,9 +20,7 @@ MorehSumOperation::MorehSumHFactory::cached_program_t MorehSumOperation::MorehSu auto output = output_tensor; auto output_mem_config = operation_attributes.output_mem_config; - const DeviceComputeKernelConfig &compute_kernel_config = init_device_compute_kernel_config( - input.device()->arch(), operation_attributes.compute_kernel_config, MathFidelity::HiFi4); - ; + const DeviceComputeKernelConfig &compute_kernel_config = operation_attributes.compute_kernel_config; tt::tt_metal::ReduceOpMath reduce_op = tt::tt_metal::ReduceOpMath::SUM; tt::tt_metal::ReduceOpDim reduce_dim = tt::tt_metal::ReduceOpDim::H; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_nc_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_nc_program_factory.cpp index f79ca7cf769..8be521821fd 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_nc_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_nc_program_factory.cpp @@ -19,9 +19,7 @@ MorehSumOperation::MorehSumNCFactory::cached_program_t MorehSumOperation::MorehS auto dim = operation_attributes.dim; auto output_mem_config = operation_attributes.output_mem_config; - const DeviceComputeKernelConfig &compute_kernel_config = init_device_compute_kernel_config( - input.device()->arch(), operation_attributes.compute_kernel_config, MathFidelity::HiFi4); - ; + const DeviceComputeKernelConfig &compute_kernel_config = operation_attributes.compute_kernel_config; auto* device = input.device(); auto program = Program(); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp index 2c3eb23c611..f5d5dd37adb 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp @@ -19,9 +19,7 @@ MorehSumOperation::MorehSumWFactory::cached_program_t MorehSumOperation::MorehSu auto output = output_tensor; auto output_mem_config = operation_attributes.output_mem_config; - const DeviceComputeKernelConfig &compute_kernel_config = init_device_compute_kernel_config( - input.device()->arch(), operation_attributes.compute_kernel_config, MathFidelity::HiFi4); - ; + const DeviceComputeKernelConfig &compute_kernel_config = operation_attributes.compute_kernel_config; tt::tt_metal::ReduceOpMath reduce_op = tt::tt_metal::ReduceOpMath::SUM; tt::tt_metal::ReduceOpDim reduce_dim = tt::tt_metal::ReduceOpDim::W; From 70e8f1747c9b076bc2c2c6752008baecf98a813c Mon Sep 17 00:00:00 2001 From: Kalaivani Baskar <156762498+KalaivaniMCW@users.noreply.github.com> Date: Wed, 9 Oct 2024 17:16:10 +0530 Subject: [PATCH 34/58] #7404: softplus operation passes on WH (#13429) --- .../eltwise/test_eltwise_softplus_inf.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 tests/ttnn/unit_tests/operations/eltwise/test_eltwise_softplus_inf.py diff --git a/tests/ttnn/unit_tests/operations/eltwise/test_eltwise_softplus_inf.py b/tests/ttnn/unit_tests/operations/eltwise/test_eltwise_softplus_inf.py new file mode 100644 index 00000000000..bff785ebee2 --- /dev/null +++ b/tests/ttnn/unit_tests/operations/eltwise/test_eltwise_softplus_inf.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from loguru import logger +import random +import pytest +import torch +import ttnn +import traceback + +from tests.ttnn.python_api_testing.sweep_tests import ttnn_ops +from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc +from models.utility_functions import skip_for_grayskull + + +def run_eltwise_softplus_tests( + input_shape, + dtype, + dlayout, + in_mem_config, + output_mem_config, + beta, + threshold, + data_seed, + device, +): + torch.manual_seed(data_seed) + x = torch.Tensor(size=input_shape[0]).uniform_(-100, 100) + + try: + # get ref result + ref_value = torch.nn.functional.softplus(x, beta=beta, threshold=threshold) + + x = ttnn_ops.setup_ttnn_tensor(x, device, dlayout[0], in_mem_config[0], dtype[0]) + tt_result = ttnn.softplus(x, beta=beta, threshold=threshold, memory_config=output_mem_config) + + tt_result = ttnn_ops.ttnn_tensor_to_torch(tt_result, output_mem_config) + + except Exception as e: + logger.warning(f"Test execution crashed: {e}") + print(traceback.format_exc()) + raise e + + assert len(tt_result.shape) == len(ref_value.shape) + assert tt_result.shape == ref_value.shape + + # compare tt and golden outputs + success, pcc_value = comp_pcc(ref_value, tt_result) + logger.debug(pcc_value) + logger.debug(success) + + assert success + + +test_sweep_args = [ + ( + [(6, 6, 192, 224)], + [ttnn.bfloat16], + [ttnn.TILE_LAYOUT], + [ttnn.DRAM_MEMORY_CONFIG], + ttnn.L1_MEMORY_CONFIG, + 0.0, + 28.125, + 19042500, + ), +] + + +@skip_for_grayskull("Softplus is not available in Grayskull") +@pytest.mark.parametrize( + "input_shape, dtype, dlayout, in_mem_config, out_mem_config, beta, threshold, data_seed", + (test_sweep_args), +) +def test_eltwise_softplus( + input_shape, dtype, dlayout, in_mem_config, out_mem_config, beta, threshold, data_seed, device +): + run_eltwise_softplus_tests( + input_shape, dtype, dlayout, in_mem_config, out_mem_config, beta, threshold, data_seed, device + ) From facce1fce8438e21e5d5cace998ef43d417ae810 Mon Sep 17 00:00:00 2001 From: Mo Date: Wed, 2 Oct 2024 21:25:33 +0000 Subject: [PATCH 35/58] #13386: Support generating perf report for C++ binaries and better artifact folder management Generate perf reports for C++ binaries with `python -m tracy -r -v {path to bin}` - If this is a ttnn run, the same ops report csv will be generated - If this is tt-metal run, device data will be gathered in a device only csv Collect all profiler artifacts in any location regardless of where `TT_METAL_HOME` is set to with `python -m tracy -o {out dir} -r -v {path to bin}` --- scripts/tools_setup_common.sh | 3 ++ .../ops_device_perf/run_op_profiling.py | 14 ++--- tt_metal/tools/profiler/common.hpp | 19 ++++++- tt_metal/tools/profiler/common.py | 38 ++++++-------- tt_metal/tools/profiler/process_model_log.py | 15 ++++-- tt_metal/tools/profiler/process_ops_logs.py | 50 ++++++++++-------- tt_metal/tools/profiler/profiler.cpp | 2 +- tt_metal/tools/profiler/tt_metal_profiler.cpp | 2 +- ttnn/tracy/__init__.py | 31 +++++------ ttnn/tracy/__main__.py | 52 +++++++++++++------ 10 files changed, 135 insertions(+), 91 deletions(-) diff --git a/scripts/tools_setup_common.sh b/scripts/tools_setup_common.sh index 2f04ab3af9a..2d5ed561b83 100644 --- a/scripts/tools_setup_common.sh +++ b/scripts/tools_setup_common.sh @@ -10,6 +10,9 @@ fi PROFILER_SCRIPTS_ROOT=$TT_METAL_HOME/tt_metal/tools/profiler PROFILER_TEST_SCRIPTS_ROOT=$TT_METAL_HOME/tests/tt_metal/tools/profiler PROFILER_ARTIFACTS_DIR=$TT_METAL_HOME/generated/profiler +if [[ "$TT_METAL_PROFILER_DIR" ]]; then + PROFILER_ARTIFACTS_DIR=$TT_METAL_PROFILER_DIR +fi PROFILER_OUTPUT_DIR=$PROFILER_ARTIFACTS_DIR/reports remove_default_log_locations(){ diff --git a/tests/tt_eager/ops_device_perf/run_op_profiling.py b/tests/tt_eager/ops_device_perf/run_op_profiling.py index b4e12f50c02..a8241c42d7a 100644 --- a/tests/tt_eager/ops_device_perf/run_op_profiling.py +++ b/tests/tt_eager/ops_device_perf/run_op_profiling.py @@ -7,14 +7,13 @@ from loguru import logger import numpy as np import pytest +from pathlib import Path -from tt_metal.tools.profiler.process_model_log import run_device_profiler, post_process_ops_log +from tt_metal.tools.profiler.process_model_log import run_device_profiler, post_process_ops_log, get_profiler_folder from models.utility_functions import is_wormhole_b0, is_blackhole -from tt_metal.tools.profiler.common import PROFILER_LOGS_DIR, PROFILER_DEVICE_SIDE_LOG - -profiler_log_path = PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG +from tt_metal.tools.profiler.common import PROFILER_LOGS_DIR, PROFILER_DEVICE_SIDE_LOG, generate_logs_folder from tt_metal.tools.profiler.process_device_log import import_log_run_stats import tt_metal.tools.profiler.device_post_proc_config as device_post_proc_config @@ -75,15 +74,16 @@ def run_op_test(): op_name = "tt::operations::primary::Matmul" duration_cols = ["DEVICE KERNEL DURATION [ns]"] + profiler_out_dir = "op_profiler_results" run_device_profiler( - "pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_1d_2d.py", "op_profiler_results" + "pytest tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_1d_2d.py", profiler_out_dir ) - results = post_process_ops_log("op_profiler_results", duration_cols, False, op_name) + results = post_process_ops_log(profiler_out_dir, duration_cols, False, op_name) kernel_durations_ns = results[duration_cols[0]] setup = device_post_proc_config.default_setup() - setup.deviceInputLog = profiler_log_path + setup.deviceInputLog = generate_logs_folder(get_profiler_folder(profiler_out_dir)) / PROFILER_DEVICE_SIDE_LOG deviceData = import_log_run_stats(setup) freq = deviceData["deviceInfo"]["freq"] freq_to_cycle_ratio = freq / 1000.0 diff --git a/tt_metal/tools/profiler/common.hpp b/tt_metal/tools/profiler/common.hpp index 23174c62855..57aa722708d 100644 --- a/tt_metal/tools/profiler/common.hpp +++ b/tt_metal/tools/profiler/common.hpp @@ -14,7 +14,24 @@ namespace tt_metal { constexpr std::string_view PROFILER_RUNTIME_ROOT_DIR = "generated/profiler/"; constexpr std::string_view PROFILER_LOGS_DIR_NAME = ".logs/"; -inline std::string PROFILER_ZONE_SRC_LOCATIONS_LOG = string(PROFILER_RUNTIME_ROOT_DIR) + string(PROFILER_LOGS_DIR_NAME) + "zone_src_locations.log"; + +inline std::string get_profiler_artifacts_dir() +{ + std::string artifactDir = string(PROFILER_RUNTIME_ROOT_DIR); + const auto PROFILER_ARTIFACTS_DIR = std::getenv("TT_METAL_PROFILER_DIR"); + if (PROFILER_ARTIFACTS_DIR != nullptr) + { + artifactDir = string(PROFILER_ARTIFACTS_DIR) + "/"; + } + return artifactDir; +} + +inline std::string get_profiler_logs_dir() +{ + return get_profiler_artifacts_dir() + string(PROFILER_LOGS_DIR_NAME); +} + +inline std::string PROFILER_ZONE_SRC_LOCATIONS_LOG = get_profiler_logs_dir() + "zone_src_locations.log"; } // namespace tt_metal } // namespace tt diff --git a/tt_metal/tools/profiler/common.py b/tt_metal/tools/profiler/common.py index 284193b9f09..f5309e1b4eb 100644 --- a/tt_metal/tools/profiler/common.py +++ b/tt_metal/tools/profiler/common.py @@ -20,12 +20,11 @@ PROFILER_SCRIPTS_ROOT = TT_METAL_HOME / "tt_metal/tools/profiler" PROFILER_ARTIFACTS_DIR = TT_METAL_HOME / "generated/profiler" +if "TT_METAL_PROFILER_DIR" in ENVS.keys(): + PROFILER_ARTIFACTS_DIR = Path(ENVS["TT_METAL_PROFILER_DIR"]) + PROFILER_BIN_DIR = TT_METAL_HOME / "build/tools/profiler/bin" -PROFILER_LOGS_DIR = PROFILER_ARTIFACTS_DIR / ".logs" -PROFILER_OUTPUT_DIR = PROFILER_ARTIFACTS_DIR / "reports" -PROFILER_OPS_LOGS_DIR = PROFILER_LOGS_DIR / "ops" -PROFILER_LOG_LOCATIONS_RECORD = PROFILER_LOGS_DIR / ".locations.log" TRACY_OPS_TIMES_FILE_NAME = "tracy_ops_times.csv" TRACY_OPS_DATA_FILE_NAME = "tracy_ops_data.csv" @@ -36,6 +35,18 @@ TRACY_CSVEXPROT_TOOL = "csvexport-release" +def generate_logs_folder(outFolder): + return outFolder / ".logs" + + +def generate_reports_folder(outFolder): + return outFolder / "reports" + + +PROFILER_LOGS_DIR = generate_logs_folder(PROFILER_ARTIFACTS_DIR) +PROFILER_OUTPUT_DIR = generate_reports_folder(PROFILER_ARTIFACTS_DIR) + + def rm(path): if not os.path.exists(path): return @@ -47,22 +58,3 @@ def rm(path): def clear_profiler_runtime_artifacts(): rm(PROFILER_ARTIFACTS_DIR) - - -def get_log_locations(): - logLocations = set() - deviceLogLocations = set() - if os.path.isfile(PROFILER_LOG_LOCATIONS_RECORD): - with open(PROFILER_LOG_LOCATIONS_RECORD, "r") as recordFile: - for line in recordFile.readlines(): - logLocation = line.strip() - if os.path.isdir(f"{logLocation}") or os.path.isdir(f"{TT_METAL_HOME}/{logLocation}"): - logLocations.add(logLocation) - tmpSplit = logLocation.rsplit("_", 1) - if tmpSplit[-1] == "device": - deviceLogLocations.add(tmpSplit[0]) - for logLocation in deviceLogLocations: - if logLocation in logLocations: - logLocations.remove(f"{logLocation}_device") - - return list(logLocations) diff --git a/tt_metal/tools/profiler/process_model_log.py b/tt_metal/tools/profiler/process_model_log.py index 4539126aca8..2097fade1f6 100755 --- a/tt_metal/tools/profiler/process_model_log.py +++ b/tt_metal/tools/profiler/process_model_log.py @@ -7,12 +7,17 @@ from pathlib import Path import pandas as pd -from tt_metal.tools.profiler.common import PROFILER_OUTPUT_DIR, PROFILER_SCRIPTS_ROOT +from tt_metal.tools.profiler.common import PROFILER_ARTIFACTS_DIR, PROFILER_SCRIPTS_ROOT, generate_reports_folder + + +def get_profiler_folder(output_logs_subdir): + return PROFILER_ARTIFACTS_DIR / output_logs_subdir def get_latest_ops_log_filename(output_logs_subdir): - runDate = sorted(os.listdir(PROFILER_OUTPUT_DIR / output_logs_subdir))[-1] - filename = PROFILER_OUTPUT_DIR / output_logs_subdir / runDate / f"ops_perf_results_{runDate}.csv" + output_report_dir = generate_reports_folder(get_profiler_folder(output_logs_subdir)) + runDate = sorted(os.listdir(output_report_dir))[-1] + filename = output_report_dir / runDate / f"ops_perf_results_{runDate}.csv" return filename @@ -40,8 +45,8 @@ def post_process_ops_log(output_logs_subdir, columns, sum_vals=True, op_name="", def run_device_profiler(command, output_logs_subdir): - output_logs_dir = PROFILER_OUTPUT_DIR / output_logs_subdir - profiler_cmd = f"python -m tracy -p -r -o {output_logs_dir} -t 5000 -m {command}" + output_profiler_dir = get_profiler_folder(output_logs_subdir) + profiler_cmd = f"python -m tracy -p -r -o {output_profiler_dir} -t 5000 -m {command}" subprocess.run([profiler_cmd], shell=True, check=True) diff --git a/tt_metal/tools/profiler/process_ops_logs.py b/tt_metal/tools/profiler/process_ops_logs.py index d34f96b6b27..37f1f0036be 100755 --- a/tt_metal/tools/profiler/process_ops_logs.py +++ b/tt_metal/tools/profiler/process_ops_logs.py @@ -20,14 +20,15 @@ from tt_metal.tools.profiler.process_device_log import import_log_run_stats import tt_metal.tools.profiler.device_post_proc_config as device_post_proc_config from tt_metal.tools.profiler.common import ( - PROFILER_LOGS_DIR, - PROFILER_OPS_LOGS_DIR, PROFILER_DEVICE_SIDE_LOG, PROFILER_HOST_SIDE_LOG, + PROFILER_ARTIFACTS_DIR, PROFILER_OUTPUT_DIR, TRACY_FILE_NAME, TRACY_OPS_TIMES_FILE_NAME, TRACY_OPS_DATA_FILE_NAME, + generate_logs_folder, + generate_reports_folder, ) yaml.SafeDumper.ignore_aliases = lambda *args: True @@ -77,15 +78,15 @@ def csv_header_format(header): return header.replace("_", " ").upper() -def import_tracy_op_logs(): +def import_tracy_op_logs(logFolder): logger.info(f"Importing ops logs") ops = {} signposts = {} signpostsCount = 0 cached_ops = {} - tracyOpTimesLog = os.path.join(PROFILER_LOGS_DIR, TRACY_OPS_TIMES_FILE_NAME) - tracyOpDataLog = os.path.join(PROFILER_LOGS_DIR, TRACY_OPS_DATA_FILE_NAME) + tracyOpTimesLog = os.path.join(logFolder, TRACY_OPS_TIMES_FILE_NAME) + tracyOpDataLog = os.path.join(logFolder, TRACY_OPS_DATA_FILE_NAME) if not os.path.isfile(tracyOpTimesLog) or not os.path.isfile(tracyOpDataLog): return ops, signposts @@ -190,10 +191,10 @@ def device_log_ops_compare(op): # Append device data to device ops and return the list of mapped device op ref list -def append_device_data(ops, deviceLogFolder): +def append_device_data(ops, logFolder): deviceOps = get_device_op_data(ops) logger.info(f"Appending device data") - deviceTimesLog = os.path.join(deviceLogFolder, PROFILER_DEVICE_SIDE_LOG) + deviceTimesLog = os.path.join(logFolder, PROFILER_DEVICE_SIDE_LOG) if os.path.isfile(deviceTimesLog): setup = device_post_proc_config.default_setup() setup.deviceInputLog = deviceTimesLog @@ -231,7 +232,7 @@ def append_device_data(ops, deviceLogFolder): if "run_host_id" in timeID.keys(): assert ( timeID["run_host_id"] == deviceOp["global_call_count"] - ), f"op id {timeID['run_host_id']} reproted by device is not matching assigned op id {deviceOp['global_call_count']}" + ), f"op id {timeID['run_host_id']} reproted by device {device} is not matching assigned op id {deviceOp['global_call_count']}" if core not in cores: cores.add(core) deviceOp["core_usage"] = {"count": len(cores), "cores": [str(core) for core in cores]} @@ -245,9 +246,9 @@ def append_device_data(ops, deviceLogFolder): def get_device_data_generate_report( - deviceLogFolder, outputFolder, date, nameAppend, export_csv=True, cleanup_device_log=False + logFolder, outputFolder, date, nameAppend, export_csv=True, cleanup_device_log=False ): - deviceTimesLog = os.path.join(deviceLogFolder, PROFILER_DEVICE_SIDE_LOG) + deviceTimesLog = os.path.join(logFolder, PROFILER_DEVICE_SIDE_LOG) devicePreOpTime = {} deviceOps = {} i = 0 @@ -273,8 +274,8 @@ def get_device_data_generate_report( allOpsCSVPath = os.path.join(outFolder, f"{name}.csv") logger.info(f"Copying runtime artifacts") os.system(f"rm -rf {outFolder}; mkdir -p {outFolder}") - if os.path.isfile(f"{PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG}"): - os.system(f"cp {PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG} {outFolder}") + if os.path.isfile(f"{logFolder / PROFILER_DEVICE_SIDE_LOG}"): + os.system(f"cp {logFolder / PROFILER_DEVICE_SIDE_LOG} {outFolder}") if os.path.isfile(deviceTimesLog): logger.info(f"Getting device only ops data") @@ -346,7 +347,7 @@ def get_device_data_generate_report( return rowDicts -def generate_reports(ops, deviceOps, signposts, outputFolder, date, nameAppend): +def generate_reports(ops, deviceOps, signposts, logFolder, outputFolder, date, nameAppend): logger.info(f"OPs' perf analysis is finished! Generating reports ...") outFolder = PROFILER_OUTPUT_DIR if outputFolder: @@ -367,10 +368,10 @@ def generate_reports(ops, deviceOps, signposts, outputFolder, date, nameAppend): logger.info(f"Copying runtime artifacts") os.system(f"rm -rf {outFolder}; mkdir -p {outFolder}") - if os.path.isfile(f"{PROFILER_LOGS_DIR / TRACY_FILE_NAME}"): - os.system(f"cp {PROFILER_LOGS_DIR / TRACY_FILE_NAME} {outFolder}") - if os.path.isfile(f"{PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG}"): - os.system(f"cp {PROFILER_LOGS_DIR / PROFILER_DEVICE_SIDE_LOG} {outFolder}") + if os.path.isfile(f"{logFolder / TRACY_FILE_NAME}"): + os.system(f"cp {logFolder / TRACY_FILE_NAME} {outFolder}") + if os.path.isfile(f"{logFolder / PROFILER_DEVICE_SIDE_LOG}"): + os.system(f"cp {logFolder / PROFILER_DEVICE_SIDE_LOG} {outFolder}") # logger.info(f"Generating OPs yaml") # allOpsYAMLPath = os.path.join(outFolder, f"{name}_all_ops.yaml") @@ -566,14 +567,19 @@ def row_compare(row): def process_ops(output_folder, name_append, date): - ops, signposts = import_tracy_op_logs() + if not output_folder: + output_folder = PROFILER_ARTIFACTS_DIR + logFolder = generate_logs_folder(output_folder) + reportFolder = generate_reports_folder(output_folder) + + ops, signposts = import_tracy_op_logs(logFolder) if ops: - deviceOps = append_device_data(ops, PROFILER_LOGS_DIR) - generate_reports(ops, deviceOps, signposts, output_folder, date, name_append) + deviceOps = append_device_data(ops, logFolder) + generate_reports(ops, deviceOps, signposts, logFolder, reportFolder, date, name_append) else: - deviceOps = get_device_data_generate_report(PROFILER_LOGS_DIR, output_folder, date, name_append) + deviceOps = get_device_data_generate_report(logFolder, reportFolder, date, name_append) @click.command() @@ -581,6 +587,8 @@ def process_ops(output_folder, name_append, date): @click.option("-n", "--name-append", type=str, help="Name to be appended to default csv name") @click.option("--date", default=False, is_flag=True, help="Append date to output files") def main(output_folder, name_append, date): + if output_folder: + output_folder = Path(output_folder) process_ops(output_folder, name_append, date) diff --git a/tt_metal/tools/profiler/profiler.cpp b/tt_metal/tools/profiler/profiler.cpp index f55611b1a3a..ff4c7078723 100644 --- a/tt_metal/tools/profiler/profiler.cpp +++ b/tt_metal/tools/profiler/profiler.cpp @@ -287,7 +287,7 @@ DeviceProfiler::DeviceProfiler(const bool new_logs) { #if defined(TRACY_ENABLE) ZoneScopedC(tracy::Color::Green); - output_dir = std::filesystem::path(string(PROFILER_RUNTIME_ROOT_DIR) + string(PROFILER_LOGS_DIR_NAME)); + output_dir = std::filesystem::path(get_profiler_logs_dir()); std::filesystem::create_directories(output_dir); std::filesystem::path log_path = output_dir / DEVICE_SIDE_LOG; diff --git a/tt_metal/tools/profiler/tt_metal_profiler.cpp b/tt_metal/tools/profiler/tt_metal_profiler.cpp index 40d38a13c62..94328003032 100644 --- a/tt_metal/tools/profiler/tt_metal_profiler.cpp +++ b/tt_metal/tools/profiler/tt_metal_profiler.cpp @@ -124,7 +124,7 @@ void syncDeviceHost(Device *device, CoreCoord logical_core, std::shared_ptr Date: Wed, 9 Oct 2024 09:45:20 -0400 Subject: [PATCH 36/58] #13556: Unvendor magic_enum and upgrade to latest release (v0.9.6) (#13557) * #13556: Unvendor magic_enum and upgrade to latest release (v0.9.6) * #13556: Use <> to designate it's third-party --- .github/workflows/build-artifact.yaml | 1 + cmake/dependencies.cmake | 6 + tests/CMakeLists.txt | 3 +- tests/tt_eager/ops/test_bcast_op.cpp | 2 +- tests/tt_eager/ops/test_sfpu.cpp | 2 +- .../tt_metal/unit_tests_common/CMakeLists.txt | 2 +- tt_metal/CMakeLists.txt | 3 +- tt_metal/common/CMakeLists.txt | 2 +- tt_metal/common/tt_backend_api_types.cpp | 2 +- tt_metal/impl/allocator/allocator.cpp | 2 +- .../impl/dispatch/command_queue_interface.hpp | 2 +- tt_metal/impl/dispatch/data_collection.cpp | 2 +- .../eltwise_binary/eltwise_binary.cpp | 2 +- .../matmul_common/bmm_op.hpp | 2 +- tt_metal/third_party/magic_enum/LICENSE | 21 - tt_metal/third_party/magic_enum/LICENSE.txt | 30 - .../third_party/magic_enum/magic_enum.hpp | 1438 ----------------- tt_metal/tools/profiler/op_profiler.hpp | 2 +- tt_metal/tt_stl/reflection.hpp | 2 +- ttnn/CMakeLists.txt | 1 - ttnn/cpp/pybind11/export_enum.hpp | 2 +- ttnn/cpp/ttnn/core.hpp | 2 +- .../op_library/moreh_helper_functions.cpp | 2 +- .../core/to_layout/to_layout_op.hpp | 2 +- .../sharded/reshard/device/reshard_op.cpp | 2 +- .../binary/device/binary_composite_op.cpp | 2 +- .../binary/device/binary_composite_op.hpp | 2 +- .../binary/device/binary_device_operation.hpp | 2 +- .../binary_backward/binary_backward.cpp | 2 +- .../device/complex_binary_op.hpp | 2 +- .../complex_unary/device/complex_unary_op.hpp | 2 +- .../device/complex_unary_backward_op.hpp | 2 +- .../eltwise/ternary/ternary_composite_op.hpp | 2 +- .../ternary_backward/ternary_backward.cpp | 2 +- .../unary/device/unary_composite_op.cpp | 2 +- .../unary/device/unary_composite_op.hpp | 2 +- .../unary/device/unary_device_operation.cpp | 2 +- .../eltwise/unary_backward/unary_backward.cpp | 2 +- .../device/layernorm_post_all_gather_op.cpp | 2 +- .../device/layernorm_pre_all_gather_op.cpp | 2 +- 40 files changed, 41 insertions(+), 1526 deletions(-) delete mode 100644 tt_metal/third_party/magic_enum/LICENSE delete mode 100644 tt_metal/third_party/magic_enum/LICENSE.txt delete mode 100644 tt_metal/third_party/magic_enum/magic_enum.hpp diff --git a/.github/workflows/build-artifact.yaml b/.github/workflows/build-artifact.yaml index 45f3affb431..8851052019e 100644 --- a/.github/workflows/build-artifact.yaml +++ b/.github/workflows/build-artifact.yaml @@ -62,6 +62,7 @@ jobs: build-artifact: needs: build-docker-image if: always() + timeout-minutes: 30 strategy: matrix: arch: ${{ fromJson(inputs.arch || '["grayskull", "wormhole_b0", "blackhole"]') }} diff --git a/cmake/dependencies.cmake b/cmake/dependencies.cmake index d222640ea41..c8b7384d75e 100644 --- a/cmake/dependencies.cmake +++ b/cmake/dependencies.cmake @@ -55,3 +55,9 @@ CPMAddPackage( GITHUB_REPOSITORY boost-ext/reflect GIT_TAG v1.1.1 ) + +CPMAddPackage( + NAME magic_enum + GITHUB_REPOSITORY Neargye/magic_enum + GIT_TAG v0.9.6 +) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 24088c4ba89..1738b45d300 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,7 +2,7 @@ enable_testing() include(GoogleTest) add_library(test_common_libs INTERFACE) -target_link_libraries(test_common_libs INTERFACE pthread gtest gtest_main) +target_link_libraries(test_common_libs INTERFACE pthread gtest gtest_main magic_enum) if(TT_METAL_BUILD_TESTS) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal/tt_metal) @@ -12,4 +12,3 @@ if(TTNN_BUILD_TESTS) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_eager) # this should go away and be replaced with link to ttnn add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ttnn/unit_tests/gtests) endif(TTNN_BUILD_TESTS) - diff --git a/tests/tt_eager/ops/test_bcast_op.cpp b/tests/tt_eager/ops/test_bcast_op.cpp index 9de793a85ba..0a7fe9ecd98 100644 --- a/tests/tt_eager/ops/test_bcast_op.cpp +++ b/tests/tt_eager/ops/test_bcast_op.cpp @@ -6,7 +6,7 @@ #include "ttnn/tensor/tensor.hpp" #include "ttnn/operations/data_movement/bcast/bcast.hpp" #include "common/constants.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include #include diff --git a/tests/tt_eager/ops/test_sfpu.cpp b/tests/tt_eager/ops/test_sfpu.cpp index 942ea4bfb9d..40543974d05 100644 --- a/tests/tt_eager/ops/test_sfpu.cpp +++ b/tests/tt_eager/ops/test_sfpu.cpp @@ -8,7 +8,7 @@ #include #include -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "tt_metal/host_api.hpp" #include "tt_metal/detail/tt_metal.hpp" diff --git a/tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt b/tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt index 781ac6314ba..d488cd254b1 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt @@ -29,7 +29,7 @@ set(UNIT_TESTS_COMMON_SRC ${CMAKE_CURRENT_SOURCE_DIR}/watcher/test_link_training.cpp ) add_library(unit_tests_common_o OBJECT ${UNIT_TESTS_COMMON_SRC}) -target_link_libraries(unit_tests_common_o PUBLIC compiler_flags metal_header_directories gtest gtest_main) +target_link_libraries(unit_tests_common_o PUBLIC compiler_flags metal_header_directories gtest gtest_main magic_enum) target_include_directories(unit_tests_common_o PUBLIC ${UMD_HOME} ${PROJECT_SOURCE_DIR} diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt index d5c9fa860e1..32d09bcc0e9 100644 --- a/tt_metal/CMakeLists.txt +++ b/tt_metal/CMakeLists.txt @@ -23,10 +23,9 @@ set(TT_METAL_OBJECTS add_library(tt_metal ${TT_METAL_OBJECTS}) -target_link_libraries(tt_metal PUBLIC metal_header_directories umd_device metal_common_libs) +target_link_libraries(tt_metal PUBLIC metal_header_directories umd_device metal_common_libs magic_enum) target_precompile_headers(tt_metal PRIVATE - ${CMAKE_CURRENT_SOURCE_DIR}/third_party/magic_enum/magic_enum.hpp ${CMAKE_CURRENT_SOURCE_DIR}/third_party/tracy/public/tracy/Tracy.hpp ${CMAKE_CURRENT_SOURCE_DIR}/third_party/fmt/fmt/core.h ${CMAKE_CURRENT_SOURCE_DIR}/third_party/fmt/fmt/format.h diff --git a/tt_metal/common/CMakeLists.txt b/tt_metal/common/CMakeLists.txt index 1c04822539f..d5621964729 100644 --- a/tt_metal/common/CMakeLists.txt +++ b/tt_metal/common/CMakeLists.txt @@ -8,7 +8,7 @@ set(COMMON_SRCS add_library(common OBJECT ${COMMON_SRCS}) target_link_libraries(common PRIVATE yaml-cpp::yaml-cpp) -target_link_libraries(common PUBLIC compiler_flags metal_header_directories) +target_link_libraries(common PUBLIC compiler_flags metal_header_directories magic_enum) target_include_directories(common PUBLIC ${UMD_HOME} diff --git a/tt_metal/common/tt_backend_api_types.cpp b/tt_metal/common/tt_backend_api_types.cpp index 61571220b5f..7bef74589be 100644 --- a/tt_metal/common/tt_backend_api_types.cpp +++ b/tt_metal/common/tt_backend_api_types.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "tt_backend_api_types.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include std::string tt::get_string(tt::ARCH arch) { switch (arch) { diff --git a/tt_metal/impl/allocator/allocator.cpp b/tt_metal/impl/allocator/allocator.cpp index 6c7fd50acff..6800114e4ae 100644 --- a/tt_metal/impl/allocator/allocator.cpp +++ b/tt_metal/impl/allocator/allocator.cpp @@ -4,7 +4,7 @@ #include "tt_metal/impl/allocator/allocator.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "tt_metal/common/math.hpp" #include "tt_metal/detail/util.hpp" #include "tt_metal/hostdevcommon/common_runtime_address_map.h" diff --git a/tt_metal/impl/dispatch/command_queue_interface.hpp b/tt_metal/impl/dispatch/command_queue_interface.hpp index ec6e1a964c7..e1a8f3f0b0f 100644 --- a/tt_metal/impl/dispatch/command_queue_interface.hpp +++ b/tt_metal/impl/dispatch/command_queue_interface.hpp @@ -13,7 +13,7 @@ #include "tt_metal/llrt/llrt.hpp" #include "tt_metal/llrt/hal.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include using namespace tt::tt_metal; diff --git a/tt_metal/impl/dispatch/data_collection.cpp b/tt_metal/impl/dispatch/data_collection.cpp index 0b8aa5d7c09..516f7a27912 100644 --- a/tt_metal/impl/dispatch/data_collection.cpp +++ b/tt_metal/impl/dispatch/data_collection.cpp @@ -7,7 +7,7 @@ #include "tt_metal/impl/kernels/kernel.hpp" #include "tt_metal/common/core_coord.h" -#include "third_party/magic_enum/magic_enum.hpp" +#include using namespace tt; diff --git a/tt_metal/programming_examples/eltwise_binary/eltwise_binary.cpp b/tt_metal/programming_examples/eltwise_binary/eltwise_binary.cpp index 2563a02d675..73ab8f59cc1 100644 --- a/tt_metal/programming_examples/eltwise_binary/eltwise_binary.cpp +++ b/tt_metal/programming_examples/eltwise_binary/eltwise_binary.cpp @@ -11,7 +11,7 @@ #include "common/bfloat16.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include using namespace tt; using namespace tt::tt_metal; diff --git a/tt_metal/programming_examples/matmul_common/bmm_op.hpp b/tt_metal/programming_examples/matmul_common/bmm_op.hpp index 2fbfd9a1001..0a04afe380a 100644 --- a/tt_metal/programming_examples/matmul_common/bmm_op.hpp +++ b/tt_metal/programming_examples/matmul_common/bmm_op.hpp @@ -15,7 +15,7 @@ #include "tt_metal/common/bfloat16.hpp" #include "third_party/umd/device/tt_xy_pair.h" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "tt_metal/common/work_split.hpp" diff --git a/tt_metal/third_party/magic_enum/LICENSE b/tt_metal/third_party/magic_enum/LICENSE deleted file mode 100644 index f58f710f173..00000000000 --- a/tt_metal/third_party/magic_enum/LICENSE +++ /dev/null @@ -1,21 +0,0 @@ -MIT License - -Copyright (c) 2019 - 2023 Daniil Goncharov - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/tt_metal/third_party/magic_enum/LICENSE.txt b/tt_metal/third_party/magic_enum/LICENSE.txt deleted file mode 100644 index fd909c00b93..00000000000 --- a/tt_metal/third_party/magic_enum/LICENSE.txt +++ /dev/null @@ -1,30 +0,0 @@ - __ __ _ ______ _____ -| \/ | (_) | ____| / ____|_ _ -| \ / | __ _ __ _ _ ___ | |__ _ __ _ _ _ __ ___ | | _| |_ _| |_ -| |\/| |/ _` |/ _` | |/ __| | __| | '_ \| | | | '_ ` _ \ | | |_ _|_ _| -| | | | (_| | (_| | | (__ | |____| | | | |_| | | | | | | | |____|_| |_| -|_| |_|\__,_|\__, |_|\___| |______|_| |_|\__,_|_| |_| |_| \_____| - __/ | https://github.com/Neargye/magic_enum - |___/ version 0.9.3 - -Licensed under the MIT License . -SPDX-License-Identifier: MIT -Copyright (c) 2019 - 2023 Daniil Goncharov . - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/tt_metal/third_party/magic_enum/magic_enum.hpp b/tt_metal/third_party/magic_enum/magic_enum.hpp deleted file mode 100644 index 3769be1547c..00000000000 --- a/tt_metal/third_party/magic_enum/magic_enum.hpp +++ /dev/null @@ -1,1438 +0,0 @@ -// __ __ _ ______ _____ -// | \/ | (_) | ____| / ____|_ _ -// | \ / | __ _ __ _ _ ___ | |__ _ __ _ _ _ __ ___ | | _| |_ _| |_ -// | |\/| |/ _` |/ _` | |/ __| | __| | '_ \| | | | '_ ` _ \ | | |_ _|_ _| -// | | | | (_| | (_| | | (__ | |____| | | | |_| | | | | | | | |____|_| |_| -// |_| |_|\__,_|\__, |_|\___| |______|_| |_|\__,_|_| |_| |_| \_____| -// __/ | https://github.com/Neargye/magic_enum -// |___/ version 0.9.3 -// -// Licensed under the MIT License . -// SPDX-License-Identifier: MIT -// Copyright (c) 2019 - 2023 Daniil Goncharov . -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -#ifndef NEARGYE_MAGIC_ENUM_HPP -#define NEARGYE_MAGIC_ENUM_HPP - -#define MAGIC_ENUM_VERSION_MAJOR 0 -#define MAGIC_ENUM_VERSION_MINOR 9 -#define MAGIC_ENUM_VERSION_PATCH 3 - -#include -#include -#include -#include -#include -#include -#include - -#if defined(MAGIC_ENUM_CONFIG_FILE) -# include MAGIC_ENUM_CONFIG_FILE -#endif - -#if !defined(MAGIC_ENUM_USING_ALIAS_OPTIONAL) -# include -#endif -#if !defined(MAGIC_ENUM_USING_ALIAS_STRING) -# include -#endif -#if !defined(MAGIC_ENUM_USING_ALIAS_STRING_VIEW) -# include -#endif - -#if defined(MAGIC_ENUM_NO_ASSERT) -# define MAGIC_ENUM_ASSERT(...) static_cast(0) -#else -# include -# define MAGIC_ENUM_ASSERT(...) assert((__VA_ARGS__)) -#endif - -#if defined(__clang__) -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wunknown-warning-option" -# pragma clang diagnostic ignored "-Wenum-constexpr-conversion" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wmaybe-uninitialized" // May be used uninitialized 'return {};'. -#elif defined(_MSC_VER) -# pragma warning(push) -# pragma warning(disable : 26495) // Variable 'static_str::chars_' is uninitialized. -# pragma warning(disable : 28020) // Arithmetic overflow: Using operator '-' on a 4 byte value and then casting the result to a 8 byte value. -# pragma warning(disable : 26451) // The expression '0<=_Param_(1)&&_Param_(1)<=1-1' is not true at this call. -# pragma warning(disable : 4514) // Unreferenced inline function has been removed. -#endif - -// Checks magic_enum compiler compatibility. -#if defined(__clang__) && __clang_major__ >= 5 || defined(__GNUC__) && __GNUC__ >= 9 || defined(_MSC_VER) && _MSC_VER >= 1910 || defined(__RESHARPER__) -# undef MAGIC_ENUM_SUPPORTED -# define MAGIC_ENUM_SUPPORTED 1 -#endif - -// Checks magic_enum compiler aliases compatibility. -#if defined(__clang__) && __clang_major__ >= 5 || defined(__GNUC__) && __GNUC__ >= 9 || defined(_MSC_VER) && _MSC_VER >= 1920 -# undef MAGIC_ENUM_SUPPORTED_ALIASES -# define MAGIC_ENUM_SUPPORTED_ALIASES 1 -#endif - -// Enum value must be greater or equals than MAGIC_ENUM_RANGE_MIN. By default MAGIC_ENUM_RANGE_MIN = -128. -// If need another min range for all enum types by default, redefine the macro MAGIC_ENUM_RANGE_MIN. -#if !defined(MAGIC_ENUM_RANGE_MIN) -# define MAGIC_ENUM_RANGE_MIN -128 -#endif - -// Enum value must be less or equals than MAGIC_ENUM_RANGE_MAX. By default MAGIC_ENUM_RANGE_MAX = 128. -// If need another max range for all enum types by default, redefine the macro MAGIC_ENUM_RANGE_MAX. -#if !defined(MAGIC_ENUM_RANGE_MAX) -# define MAGIC_ENUM_RANGE_MAX 127 -#endif - -// Improve ReSharper C++ intellisense performance with builtins, avoiding unnecessary template instantiations. -#if defined(__RESHARPER__) -# undef MAGIC_ENUM_GET_ENUM_NAME_BUILTIN -# undef MAGIC_ENUM_GET_TYPE_NAME_BUILTIN -# if __RESHARPER__ >= 20230100 -# define MAGIC_ENUM_GET_ENUM_NAME_BUILTIN(V) __rscpp_enumerator_name(V) -# define MAGIC_ENUM_GET_TYPE_NAME_BUILTIN(T) __rscpp_type_name() -# else -# define MAGIC_ENUM_GET_ENUM_NAME_BUILTIN(V) nullptr -# define MAGIC_ENUM_GET_TYPE_NAME_BUILTIN(T) nullptr -# endif -#endif - -namespace magic_enum { - -// If need another optional type, define the macro MAGIC_ENUM_USING_ALIAS_OPTIONAL. -#if defined(MAGIC_ENUM_USING_ALIAS_OPTIONAL) -MAGIC_ENUM_USING_ALIAS_OPTIONAL -#else -using std::optional; -#endif - -// If need another string_view type, define the macro MAGIC_ENUM_USING_ALIAS_STRING_VIEW. -#if defined(MAGIC_ENUM_USING_ALIAS_STRING_VIEW) -MAGIC_ENUM_USING_ALIAS_STRING_VIEW -#else -using std::string_view; -#endif - -// If need another string type, define the macro MAGIC_ENUM_USING_ALIAS_STRING. -#if defined(MAGIC_ENUM_USING_ALIAS_STRING) -MAGIC_ENUM_USING_ALIAS_STRING -#else -using std::string; -#endif - -using char_type = string_view::value_type; -static_assert(std::is_same_v, "magic_enum::customize requires same string_view::value_type and string::value_type"); -static_assert([] { - if constexpr (std::is_same_v) { - constexpr const char c[] = "abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789|"; - constexpr const wchar_t wc[] = L"abcdefghijklmnopqrstuvwxyz_ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789|"; - static_assert(std::size(c) == std::size(wc), "magic_enum::customize identifier characters are multichars in wchar_t."); - - for (std::size_t i = 0; i < std::size(c); ++i) { - if (c[i] != wc[i]) { - return false; - } - } - } - return true; -} (), "magic_enum::customize wchar_t is not compatible with ASCII."); - -namespace customize { - -// Enum value must be in range [MAGIC_ENUM_RANGE_MIN, MAGIC_ENUM_RANGE_MAX]. By default MAGIC_ENUM_RANGE_MIN = -128, MAGIC_ENUM_RANGE_MAX = 128. -// If need another range for all enum types by default, redefine the macro MAGIC_ENUM_RANGE_MIN and MAGIC_ENUM_RANGE_MAX. -// If need another range for specific enum type, add specialization enum_range for necessary enum type. -template -struct enum_range { - static_assert(std::is_enum_v, "magic_enum::customize::enum_range requires enum type."); - static constexpr int min = MAGIC_ENUM_RANGE_MIN; - static constexpr int max = MAGIC_ENUM_RANGE_MAX; - static_assert(max > min, "magic_enum::customize::enum_range requires max > min."); -}; - -static_assert(MAGIC_ENUM_RANGE_MAX > MAGIC_ENUM_RANGE_MIN, "MAGIC_ENUM_RANGE_MAX must be greater than MAGIC_ENUM_RANGE_MIN."); -static_assert((MAGIC_ENUM_RANGE_MAX - MAGIC_ENUM_RANGE_MIN) < (std::numeric_limits::max)(), "MAGIC_ENUM_RANGE must be less than UINT16_MAX."); - -namespace detail { - -enum class customize_tag { - default_tag, - invalid_tag, - custom_tag -}; - -} // namespace magic_enum::customize::detail - -class customize_t : public std::pair { - public: - constexpr customize_t(string_view srt) : std::pair{detail::customize_tag::custom_tag, srt} {} - constexpr customize_t(const char_type* srt) : customize_t{string_view{srt}} {} - constexpr customize_t(detail::customize_tag tag) : std::pair{tag, string_view{}} { - MAGIC_ENUM_ASSERT(tag != detail::customize_tag::custom_tag); - } -}; - -// Default customize. -inline constexpr auto default_tag = customize_t{detail::customize_tag::default_tag}; -// Invalid customize. -inline constexpr auto invalid_tag = customize_t{detail::customize_tag::invalid_tag}; - -// If need custom names for enum, add specialization enum_name for necessary enum type. -template -constexpr customize_t enum_name(E) noexcept { - return default_tag; -} - -// If need custom type name for enum, add specialization enum_type_name for necessary enum type. -template -constexpr customize_t enum_type_name() noexcept { - return default_tag; -} - -} // namespace magic_enum::customize - -namespace detail { - -template -struct supported -#if defined(MAGIC_ENUM_SUPPORTED) && MAGIC_ENUM_SUPPORTED || defined(MAGIC_ENUM_NO_CHECK_SUPPORT) - : std::true_type {}; -#else - : std::false_type {}; -#endif - -template , std::enable_if_t, int> = 0> -using enum_constant = std::integral_constant; - -template -inline constexpr bool always_false_v = false; - -template -struct has_is_flags : std::false_type {}; - -template -struct has_is_flags::is_flags)>> : std::bool_constant::is_flags)>>> {}; - -template -struct range_min : std::integral_constant {}; - -template -struct range_min::min)>> : std::integral_constant::min), customize::enum_range::min> {}; - -template -struct range_max : std::integral_constant {}; - -template -struct range_max::max)>> : std::integral_constant::max), customize::enum_range::max> {}; - -struct str_view { - const char* str_ = nullptr; - std::size_t size_ = 0; -}; - -template -class static_str { - public: - constexpr explicit static_str(str_view str) noexcept : static_str{str.str_, std::make_integer_sequence{}} { - MAGIC_ENUM_ASSERT(str.size_ == N); - } - - constexpr explicit static_str(string_view str) noexcept : static_str{str.data(), std::make_integer_sequence{}} { - MAGIC_ENUM_ASSERT(str.size() == N); - } - - constexpr const char_type* data() const noexcept { return chars_; } - - constexpr std::uint16_t size() const noexcept { return N; } - - constexpr operator string_view() const noexcept { return {data(), size()}; } - - private: - template - constexpr static_str(const char* str, std::integer_sequence) noexcept : chars_{static_cast(str[I])..., static_cast('\0')} {} - - template - constexpr static_str(string_view str, std::integer_sequence) noexcept : chars_{str[I]..., static_cast('\0')} {} - - char_type chars_[static_cast(N) + 1]; -}; - -template <> -class static_str<0> { - public: - constexpr explicit static_str() = default; - - constexpr explicit static_str(str_view) noexcept {} - - constexpr explicit static_str(string_view) noexcept {} - - constexpr const char_type* data() const noexcept { return nullptr; } - - constexpr std::uint16_t size() const noexcept { return 0; } - - constexpr operator string_view() const noexcept { return {}; } -}; - -template > -class case_insensitive { - static constexpr char_type to_lower(char_type c) noexcept { - return (c >= static_cast('A') && c <= static_cast('Z')) ? static_cast(c + (static_cast('a') - static_cast('A'))) : c; - } - - public: - template - constexpr auto operator()(L lhs,R rhs) const noexcept -> std::enable_if_t, char_type> && std::is_same_v, char_type>, bool> { - return Op{}(to_lower(lhs), to_lower(rhs)); - } -}; - -constexpr std::size_t find(string_view str, char_type c) noexcept { -#if defined(__clang__) && __clang_major__ < 9 && defined(__GLIBCXX__) || defined(_MSC_VER) && _MSC_VER < 1920 && !defined(__clang__) -// https://stackoverflow.com/questions/56484834/constexpr-stdstring-viewfind-last-of-doesnt-work-on-clang-8-with-libstdc -// https://developercommunity.visualstudio.com/content/problem/360432/vs20178-regression-c-failed-in-test.html - constexpr bool workaround = true; -#else - constexpr bool workaround = false; -#endif - - if constexpr (workaround) { - for (std::size_t i = 0; i < str.size(); ++i) { - if (str[i] == c) { - return i; - } - } - - return string_view::npos; - } else { - return str.find(c); - } -} - -template -constexpr bool is_default_predicate() noexcept { - return std::is_same_v, std::equal_to> || - std::is_same_v, std::equal_to<>>; -} - -template -constexpr bool is_nothrow_invocable() { - return is_default_predicate() || - std::is_nothrow_invocable_r_v; -} - -template -constexpr bool cmp_equal(string_view lhs, string_view rhs, [[maybe_unused]] BinaryPredicate&& p) noexcept(is_nothrow_invocable()) { -#if defined(_MSC_VER) && _MSC_VER < 1920 && !defined(__clang__) - // https://developercommunity.visualstudio.com/content/problem/360432/vs20178-regression-c-failed-in-test.html - // https://developercommunity.visualstudio.com/content/problem/232218/c-constexpr-string-view.html - constexpr bool workaround = true; -#else - constexpr bool workaround = false; -#endif - - if constexpr (!is_default_predicate() || workaround) { - if (lhs.size() != rhs.size()) { - return false; - } - - const auto size = lhs.size(); - for (std::size_t i = 0; i < size; ++i) { - if (!p(lhs[i], rhs[i])) { - return false; - } - } - - return true; - } else { - return lhs == rhs; - } -} - -template -constexpr bool cmp_less(L lhs, R rhs) noexcept { - static_assert(std::is_integral_v && std::is_integral_v, "magic_enum::detail::cmp_less requires integral type."); - - if constexpr (std::is_signed_v == std::is_signed_v) { - // If same signedness (both signed or both unsigned). - return lhs < rhs; - } else if constexpr (std::is_same_v) { // bool special case - return static_cast(lhs) < rhs; - } else if constexpr (std::is_same_v) { // bool special case - return lhs < static_cast(rhs); - } else if constexpr (std::is_signed_v) { - // If 'right' is negative, then result is 'false', otherwise cast & compare. - return rhs > 0 && lhs < static_cast>(rhs); - } else { - // If 'left' is negative, then result is 'true', otherwise cast & compare. - return lhs < 0 || static_cast>(lhs) < rhs; - } -} - -template -constexpr I log2(I value) noexcept { - static_assert(std::is_integral_v, "magic_enum::detail::log2 requires integral type."); - - if constexpr (std::is_same_v) { // bool special case - return MAGIC_ENUM_ASSERT(false), value; - } else { - auto ret = I{0}; - for (; value > I{1}; value >>= I{1}, ++ret) {} - - return ret; - } -} - -#if defined(__cpp_lib_array_constexpr) && __cpp_lib_array_constexpr >= 201603L -# define MAGIC_ENUM_ARRAY_CONSTEXPR 1 -#else -template -constexpr std::array, N> to_array(T (&a)[N], std::index_sequence) noexcept { - return {{a[I]...}}; -} -#endif - -template -inline constexpr bool is_enum_v = std::is_enum_v && std::is_same_v>; - -template -constexpr auto n() noexcept { - static_assert(is_enum_v, "magic_enum::detail::n requires enum type."); - - if constexpr (supported::value) { -#if defined(MAGIC_ENUM_GET_TYPE_NAME_BUILTIN) - constexpr auto name_ptr = MAGIC_ENUM_GET_TYPE_NAME_BUILTIN(E); - constexpr auto name = name_ptr ? str_view{name_ptr, std::char_traits::length(name_ptr)} : str_view{}; -#elif defined(__clang__) - auto name = str_view{__PRETTY_FUNCTION__ + 34, sizeof(__PRETTY_FUNCTION__) - 36}; -#elif defined(__GNUC__) - auto name = str_view{__PRETTY_FUNCTION__, sizeof(__PRETTY_FUNCTION__) - 1}; - if (name.str_[name.size_ - 1] == ']') { - name.size_ -= 50; - name.str_ += 49; - } else { - name.size_ -= 40; - name.str_ += 37; - } -#elif defined(_MSC_VER) - auto name = str_view{__FUNCSIG__ + 40, sizeof(__FUNCSIG__) - 57}; -#else - auto name = str_view{}; -#endif - std::size_t p = 0; - for (std::size_t i = name.size_; i > 0; --i) { - if (name.str_[i] == ':') { - p = i + 1; - break; - } - } - if (p > 0) { - name.size_ -= p; - name.str_ += p; - } - return name; - } else { - return str_view{}; // Unsupported compiler or Invalid customize. - } -} - -template -constexpr auto type_name() noexcept { - [[maybe_unused]] constexpr auto custom = customize::enum_type_name(); - static_assert(std::is_same_v, customize::customize_t>, "magic_enum::customize requires customize_t type."); - if constexpr (custom.first == customize::detail::customize_tag::custom_tag) { - constexpr auto name = custom.second; - static_assert(!name.empty(), "magic_enum::customize requires not empty string."); - return static_str{name}; - } else if constexpr (custom.first == customize::detail::customize_tag::invalid_tag) { - return static_str<0>{}; - } else if constexpr (custom.first == customize::detail::customize_tag::default_tag) { - constexpr auto name = n(); - return static_str{name}; - } else { - static_assert(always_false_v, "magic_enum::customize invalid."); - } -} - -template -inline constexpr auto type_name_v = type_name(); - -template -constexpr auto n() noexcept { - static_assert(is_enum_v, "magic_enum::detail::n requires enum type."); - - if constexpr (supported::value) { -#if defined(MAGIC_ENUM_GET_ENUM_NAME_BUILTIN) - constexpr auto name_ptr = MAGIC_ENUM_GET_ENUM_NAME_BUILTIN(V); - auto name = name_ptr ? str_view{name_ptr, std::char_traits::length(name_ptr)} : str_view{}; -#elif defined(__clang__) - auto name = str_view{__PRETTY_FUNCTION__ + 34, sizeof(__PRETTY_FUNCTION__) - 36}; - if (name.size_ > 22 && name.str_[0] == '(' && name.str_[1] == 'a' && name.str_[10] == ' ' && name.str_[22] == ':') { - name.size_ -= 23; - name.str_ += 23; - } - if (name.str_[0] == '(' || name.str_[0] == '-' || (name.str_[0] >= '0' && name.str_[0] <= '9')) { - name = str_view{}; - } -#elif defined(__GNUC__) - auto name = str_view{__PRETTY_FUNCTION__, sizeof(__PRETTY_FUNCTION__) - 1}; - if (name.str_[name.size_ - 1] == ']') { - name.size_ -= 55; - name.str_ += 54; - } else { - name.size_ -= 40; - name.str_ += 37; - } - if (name.str_[0] == '(') { - name = str_view{}; - } -#elif defined(_MSC_VER) - str_view name; - if ((__FUNCSIG__[5] == '_' && __FUNCSIG__[35] != '(') || (__FUNCSIG__[5] == 'c' && __FUNCSIG__[41] != '(')) { - name = str_view{__FUNCSIG__ + 35, sizeof(__FUNCSIG__) - 52}; - } -#else - auto name = str_view{}; -#endif - std::size_t p = 0; - for (std::size_t i = name.size_; i > 0; --i) { - if (name.str_[i] == ':') { - p = i + 1; - break; - } - } - if (p > 0) { - name.size_ -= p; - name.str_ += p; - } - return name; - } else { - return str_view{}; // Unsupported compiler or Invalid customize. - } -} - -#if defined(_MSC_VER) && !defined(__clang__) && _MSC_VER < 1920 -# define MAGIC_ENUM_VS_2017_WORKAROUND 1 -#endif - -#if defined(MAGIC_ENUM_VS_2017_WORKAROUND) -template -constexpr auto n() noexcept { - static_assert(is_enum_v, "magic_enum::detail::n requires enum type."); - -# if defined(MAGIC_ENUM_GET_ENUM_NAME_BUILTIN) - constexpr auto name_ptr = MAGIC_ENUM_GET_ENUM_NAME_BUILTIN(V); - auto name = name_ptr ? str_view{name_ptr, std::char_traits::length(name_ptr)} : str_view{}; -# else - str_view name = str_view{__FUNCSIG__, sizeof(__FUNCSIG__) - 17}; - std::size_t p = 0; - for (std::size_t i = name.size_; i > 0; --i) { - if (name.str_[i] == ',' || name.str_[i] == ':') { - p = i + 1; - break; - } - } - if (p > 0) { - name.size_ -= p; - name.str_ += p; - } - if (name.str_[0] == '(' || name.str_[0] == '-' || (name.str_[0] >= '0' && name.str_[0] <= '9')) { - name = str_view{}; - } - return name; -# endif -} -#endif - -template -constexpr auto enum_name() noexcept { - [[maybe_unused]] constexpr auto custom = customize::enum_name(V); - static_assert(std::is_same_v, customize::customize_t>, "magic_enum::customize requires customize_t type."); - if constexpr (custom.first == customize::detail::customize_tag::custom_tag) { - constexpr auto name = custom.second; - static_assert(!name.empty(), "magic_enum::customize requires not empty string."); - return static_str{name}; - } else if constexpr (custom.first == customize::detail::customize_tag::invalid_tag) { - return static_str<0>{}; - } else if constexpr (custom.first == customize::detail::customize_tag::default_tag) { -#if defined(MAGIC_ENUM_VS_2017_WORKAROUND) - constexpr auto name = n(); -#else - constexpr auto name = n(); -#endif - return static_str{name}; - } else { - static_assert(always_false_v, "magic_enum::customize invalid."); - } -} - -template -inline constexpr auto enum_name_v = enum_name(); - -template -constexpr bool is_valid() noexcept { -#if defined(__clang__) && __clang_major__ >= 16 - // https://reviews.llvm.org/D130058, https://reviews.llvm.org/D131307 - constexpr E v = __builtin_bit_cast(E, V); -#else - constexpr E v = static_cast(V); -#endif - [[maybe_unused]] constexpr auto custom = customize::enum_name(v); - static_assert(std::is_same_v, customize::customize_t>, "magic_enum::customize requires customize_t type."); - if constexpr (custom.first == customize::detail::customize_tag::custom_tag) { - constexpr auto name = custom.second; - static_assert(!name.empty(), "magic_enum::customize requires not empty string."); - return name.size() != 0; - } else if constexpr (custom.first == customize::detail::customize_tag::default_tag) { -#if defined(MAGIC_ENUM_VS_2017_WORKAROUND) - return n().size_ != 0; -#else - return n().size_ != 0; -#endif - } else { - return false; - } -} - -enum class enum_subtype { - common, - flags -}; - -template > -constexpr U ualue(std::size_t i) noexcept { - if constexpr (std::is_same_v) { // bool special case - static_assert(O == 0, "magic_enum::detail::ualue requires valid offset."); - - return static_cast(i); - } else if constexpr (S == enum_subtype::flags) { - return static_cast(U{1} << static_cast(static_cast(i) + O)); - } else { - return static_cast(static_cast(i) + O); - } -} - -template > -constexpr E value(std::size_t i) noexcept { - return static_cast(ualue(i)); -} - -template > -constexpr int reflected_min() noexcept { - if constexpr (S == enum_subtype::flags) { - return 0; - } else { - constexpr auto lhs = range_min::value; - constexpr auto rhs = (std::numeric_limits::min)(); - - if constexpr (cmp_less(rhs, lhs)) { - return lhs; - } else { - return rhs; - } - } -} - -template > -constexpr int reflected_max() noexcept { - if constexpr (S == enum_subtype::flags) { - return std::numeric_limits::digits - 1; - } else { - constexpr auto lhs = range_max::value; - constexpr auto rhs = (std::numeric_limits::max)(); - - if constexpr (cmp_less(lhs, rhs)) { - return lhs; - } else { - return rhs; - } - } -} - -#define MAGIC_ENUM_FOR_EACH_256(T) \ - T( 0)T( 1)T( 2)T( 3)T( 4)T( 5)T( 6)T( 7)T( 8)T( 9)T( 10)T( 11)T( 12)T( 13)T( 14)T( 15)T( 16)T( 17)T( 18)T( 19)T( 20)T( 21)T( 22)T( 23)T( 24)T( 25)T( 26)T( 27)T( 28)T( 29)T( 30)T( 31) \ - T( 32)T( 33)T( 34)T( 35)T( 36)T( 37)T( 38)T( 39)T( 40)T( 41)T( 42)T( 43)T( 44)T( 45)T( 46)T( 47)T( 48)T( 49)T( 50)T( 51)T( 52)T( 53)T( 54)T( 55)T( 56)T( 57)T( 58)T( 59)T( 60)T( 61)T( 62)T( 63) \ - T( 64)T( 65)T( 66)T( 67)T( 68)T( 69)T( 70)T( 71)T( 72)T( 73)T( 74)T( 75)T( 76)T( 77)T( 78)T( 79)T( 80)T( 81)T( 82)T( 83)T( 84)T( 85)T( 86)T( 87)T( 88)T( 89)T( 90)T( 91)T( 92)T( 93)T( 94)T( 95) \ - T( 96)T( 97)T( 98)T( 99)T(100)T(101)T(102)T(103)T(104)T(105)T(106)T(107)T(108)T(109)T(110)T(111)T(112)T(113)T(114)T(115)T(116)T(117)T(118)T(119)T(120)T(121)T(122)T(123)T(124)T(125)T(126)T(127) \ - T(128)T(129)T(130)T(131)T(132)T(133)T(134)T(135)T(136)T(137)T(138)T(139)T(140)T(141)T(142)T(143)T(144)T(145)T(146)T(147)T(148)T(149)T(150)T(151)T(152)T(153)T(154)T(155)T(156)T(157)T(158)T(159) \ - T(160)T(161)T(162)T(163)T(164)T(165)T(166)T(167)T(168)T(169)T(170)T(171)T(172)T(173)T(174)T(175)T(176)T(177)T(178)T(179)T(180)T(181)T(182)T(183)T(184)T(185)T(186)T(187)T(188)T(189)T(190)T(191) \ - T(192)T(193)T(194)T(195)T(196)T(197)T(198)T(199)T(200)T(201)T(202)T(203)T(204)T(205)T(206)T(207)T(208)T(209)T(210)T(211)T(212)T(213)T(214)T(215)T(216)T(217)T(218)T(219)T(220)T(221)T(222)T(223) \ - T(224)T(225)T(226)T(227)T(228)T(229)T(230)T(231)T(232)T(233)T(234)T(235)T(236)T(237)T(238)T(239)T(240)T(241)T(242)T(243)T(244)T(245)T(246)T(247)T(248)T(249)T(250)T(251)T(252)T(253)T(254)T(255) - -template -constexpr void valid_count(bool* valid, std::size_t& count) noexcept { -#define MAGIC_ENUM_V(O) \ - if constexpr ((I + O) < Size) { \ - if constexpr (is_valid(I + O)>()) { \ - valid[I + O] = true; \ - ++count; \ - } \ - } - - MAGIC_ENUM_FOR_EACH_256(MAGIC_ENUM_V); - - if constexpr ((I + 256) < Size) { - valid_count(valid, count); - } -#undef MAGIC_ENUM_V -} - -template -struct valid_count_t { - std::size_t count = 0; - bool valid[N] = {}; -}; - -template -constexpr auto valid_count() noexcept { - valid_count_t vc; - valid_count(vc.valid, vc.count); - return vc; -} - -template -constexpr auto values() noexcept { - constexpr auto vc = valid_count(); - - if constexpr (vc.count > 0) { -#if defined(MAGIC_ENUM_ARRAY_CONSTEXPR) - std::array values = {}; -#else - E values[vc.count] = {}; -#endif - for (std::size_t i = 0, v = 0; v < vc.count; ++i) { - if (vc.valid[i]) { - values[v++] = value(i); - } - } -#if defined(MAGIC_ENUM_ARRAY_CONSTEXPR) - return values; -#else - return to_array(values, std::make_index_sequence{}); -#endif - } else { - return std::array{}; - } -} - -template > -constexpr auto values() noexcept { - constexpr auto min = reflected_min(); - constexpr auto max = reflected_max(); - constexpr auto range_size = max - min + 1; - static_assert(range_size > 0, "magic_enum::enum_range requires valid size."); - static_assert(range_size < (std::numeric_limits::max)(), "magic_enum::enum_range requires valid size."); - - return values(); -} - -template > -constexpr enum_subtype subtype(std::true_type) noexcept { - if constexpr (std::is_same_v) { // bool special case - return enum_subtype::common; - } else if constexpr (has_is_flags::value) { - return customize::enum_range::is_flags ? enum_subtype::flags : enum_subtype::common; - } else { -#if defined(MAGIC_ENUM_AUTO_IS_FLAGS) - constexpr auto flags_values = values(); - constexpr auto default_values = values(); - if (flags_values.size() == 0 || default_values.size() > flags_values.size()) { - return enum_subtype::common; - } - for (std::size_t i = 0; i < default_values.size(); ++i) { - const auto v = static_cast(default_values[i]); - if (v != 0 && (v & (v - 1)) != 0) { - return enum_subtype::common; - } - } - return enum_subtype::flags; -#else - return enum_subtype::common; -#endif - } -} - -template -constexpr enum_subtype subtype(std::false_type) noexcept { - // For non-enum type return default common subtype. - return enum_subtype::common; -} - -template > -inline constexpr auto subtype_v = subtype(std::is_enum{}); - -template -inline constexpr auto values_v = values(); - -template > -using values_t = decltype((values_v)); - -template -inline constexpr auto count_v = values_v.size(); - -template > -inline constexpr auto min_v = (count_v > 0) ? static_cast(values_v.front()) : U{0}; - -template > -inline constexpr auto max_v = (count_v > 0) ? static_cast(values_v.back()) : U{0}; - -template -constexpr auto names(std::index_sequence) noexcept { - return std::array{{enum_name_v[I]>...}}; -} - -template -inline constexpr auto names_v = names(std::make_index_sequence>{}); - -template > -using names_t = decltype((names_v)); - -template -constexpr auto entries(std::index_sequence) noexcept { - return std::array, sizeof...(I)>{{{values_v[I], enum_name_v[I]>}...}}; -} - -template -inline constexpr auto entries_v = entries(std::make_index_sequence>{}); - -template > -using entries_t = decltype((entries_v)); - -template > -constexpr bool is_sparse() noexcept { - if constexpr (count_v == 0) { - return false; - } else if constexpr (std::is_same_v) { // bool special case - return false; - } else { - constexpr auto max = (S == enum_subtype::flags) ? log2(max_v) : max_v; - constexpr auto min = (S == enum_subtype::flags) ? log2(min_v) : min_v; - constexpr auto range_size = max - min + 1; - - return range_size != count_v; - } -} - -template > -inline constexpr bool is_sparse_v = is_sparse(); - -template > -constexpr U values_ors() noexcept { - static_assert(S == enum_subtype::flags, "magic_enum::detail::values_ors requires valid subtype."); - - auto ors = U{0}; - for (std::size_t i = 0; i < count_v; ++i) { - ors |= static_cast(values_v[i]); - } - - return ors; -} - -template -struct enable_if_enum {}; - -template -struct enable_if_enum { - using type = R; - static_assert(supported::value, "magic_enum unsupported compiler (https://github.com/Neargye/magic_enum#compiler-compatibility)."); -}; - -template , typename D = std::decay_t> -using enable_if_t = typename enable_if_enum && std::is_invocable_r_v, R>::type; - -template >, int> = 0> -using enum_concept = T; - -template > -struct is_scoped_enum : std::false_type {}; - -template -struct is_scoped_enum : std::bool_constant>> {}; - -template > -struct is_unscoped_enum : std::false_type {}; - -template -struct is_unscoped_enum : std::bool_constant>> {}; - -template >> -struct underlying_type {}; - -template -struct underlying_type : std::underlying_type> {}; - -#if defined(MAGIC_ENUM_ENABLE_HASH) || defined(MAGIC_ENUM_ENABLE_HASH_SWITCH) - -template -struct constexpr_hash_t; - -template -struct constexpr_hash_t>> { - constexpr auto operator()(Value value) const noexcept { - using U = typename underlying_type::type; - if constexpr (std::is_same_v) { // bool special case - return static_cast(value); - } else { - return static_cast(value); - } - } - using secondary_hash = constexpr_hash_t; -}; - -template -struct constexpr_hash_t>> { - static constexpr std::uint32_t crc_table[256] { - 0x00000000L, 0x77073096L, 0xee0e612cL, 0x990951baL, 0x076dc419L, 0x706af48fL, 0xe963a535L, 0x9e6495a3L, - 0x0edb8832L, 0x79dcb8a4L, 0xe0d5e91eL, 0x97d2d988L, 0x09b64c2bL, 0x7eb17cbdL, 0xe7b82d07L, 0x90bf1d91L, - 0x1db71064L, 0x6ab020f2L, 0xf3b97148L, 0x84be41deL, 0x1adad47dL, 0x6ddde4ebL, 0xf4d4b551L, 0x83d385c7L, - 0x136c9856L, 0x646ba8c0L, 0xfd62f97aL, 0x8a65c9ecL, 0x14015c4fL, 0x63066cd9L, 0xfa0f3d63L, 0x8d080df5L, - 0x3b6e20c8L, 0x4c69105eL, 0xd56041e4L, 0xa2677172L, 0x3c03e4d1L, 0x4b04d447L, 0xd20d85fdL, 0xa50ab56bL, - 0x35b5a8faL, 0x42b2986cL, 0xdbbbc9d6L, 0xacbcf940L, 0x32d86ce3L, 0x45df5c75L, 0xdcd60dcfL, 0xabd13d59L, - 0x26d930acL, 0x51de003aL, 0xc8d75180L, 0xbfd06116L, 0x21b4f4b5L, 0x56b3c423L, 0xcfba9599L, 0xb8bda50fL, - 0x2802b89eL, 0x5f058808L, 0xc60cd9b2L, 0xb10be924L, 0x2f6f7c87L, 0x58684c11L, 0xc1611dabL, 0xb6662d3dL, - 0x76dc4190L, 0x01db7106L, 0x98d220bcL, 0xefd5102aL, 0x71b18589L, 0x06b6b51fL, 0x9fbfe4a5L, 0xe8b8d433L, - 0x7807c9a2L, 0x0f00f934L, 0x9609a88eL, 0xe10e9818L, 0x7f6a0dbbL, 0x086d3d2dL, 0x91646c97L, 0xe6635c01L, - 0x6b6b51f4L, 0x1c6c6162L, 0x856530d8L, 0xf262004eL, 0x6c0695edL, 0x1b01a57bL, 0x8208f4c1L, 0xf50fc457L, - 0x65b0d9c6L, 0x12b7e950L, 0x8bbeb8eaL, 0xfcb9887cL, 0x62dd1ddfL, 0x15da2d49L, 0x8cd37cf3L, 0xfbd44c65L, - 0x4db26158L, 0x3ab551ceL, 0xa3bc0074L, 0xd4bb30e2L, 0x4adfa541L, 0x3dd895d7L, 0xa4d1c46dL, 0xd3d6f4fbL, - 0x4369e96aL, 0x346ed9fcL, 0xad678846L, 0xda60b8d0L, 0x44042d73L, 0x33031de5L, 0xaa0a4c5fL, 0xdd0d7cc9L, - 0x5005713cL, 0x270241aaL, 0xbe0b1010L, 0xc90c2086L, 0x5768b525L, 0x206f85b3L, 0xb966d409L, 0xce61e49fL, - 0x5edef90eL, 0x29d9c998L, 0xb0d09822L, 0xc7d7a8b4L, 0x59b33d17L, 0x2eb40d81L, 0xb7bd5c3bL, 0xc0ba6cadL, - 0xedb88320L, 0x9abfb3b6L, 0x03b6e20cL, 0x74b1d29aL, 0xead54739L, 0x9dd277afL, 0x04db2615L, 0x73dc1683L, - 0xe3630b12L, 0x94643b84L, 0x0d6d6a3eL, 0x7a6a5aa8L, 0xe40ecf0bL, 0x9309ff9dL, 0x0a00ae27L, 0x7d079eb1L, - 0xf00f9344L, 0x8708a3d2L, 0x1e01f268L, 0x6906c2feL, 0xf762575dL, 0x806567cbL, 0x196c3671L, 0x6e6b06e7L, - 0xfed41b76L, 0x89d32be0L, 0x10da7a5aL, 0x67dd4accL, 0xf9b9df6fL, 0x8ebeeff9L, 0x17b7be43L, 0x60b08ed5L, - 0xd6d6a3e8L, 0xa1d1937eL, 0x38d8c2c4L, 0x4fdff252L, 0xd1bb67f1L, 0xa6bc5767L, 0x3fb506ddL, 0x48b2364bL, - 0xd80d2bdaL, 0xaf0a1b4cL, 0x36034af6L, 0x41047a60L, 0xdf60efc3L, 0xa867df55L, 0x316e8eefL, 0x4669be79L, - 0xcb61b38cL, 0xbc66831aL, 0x256fd2a0L, 0x5268e236L, 0xcc0c7795L, 0xbb0b4703L, 0x220216b9L, 0x5505262fL, - 0xc5ba3bbeL, 0xb2bd0b28L, 0x2bb45a92L, 0x5cb36a04L, 0xc2d7ffa7L, 0xb5d0cf31L, 0x2cd99e8bL, 0x5bdeae1dL, - 0x9b64c2b0L, 0xec63f226L, 0x756aa39cL, 0x026d930aL, 0x9c0906a9L, 0xeb0e363fL, 0x72076785L, 0x05005713L, - 0x95bf4a82L, 0xe2b87a14L, 0x7bb12baeL, 0x0cb61b38L, 0x92d28e9bL, 0xe5d5be0dL, 0x7cdcefb7L, 0x0bdbdf21L, - 0x86d3d2d4L, 0xf1d4e242L, 0x68ddb3f8L, 0x1fda836eL, 0x81be16cdL, 0xf6b9265bL, 0x6fb077e1L, 0x18b74777L, - 0x88085ae6L, 0xff0f6a70L, 0x66063bcaL, 0x11010b5cL, 0x8f659effL, 0xf862ae69L, 0x616bffd3L, 0x166ccf45L, - 0xa00ae278L, 0xd70dd2eeL, 0x4e048354L, 0x3903b3c2L, 0xa7672661L, 0xd06016f7L, 0x4969474dL, 0x3e6e77dbL, - 0xaed16a4aL, 0xd9d65adcL, 0x40df0b66L, 0x37d83bf0L, 0xa9bcae53L, 0xdebb9ec5L, 0x47b2cf7fL, 0x30b5ffe9L, - 0xbdbdf21cL, 0xcabac28aL, 0x53b39330L, 0x24b4a3a6L, 0xbad03605L, 0xcdd70693L, 0x54de5729L, 0x23d967bfL, - 0xb3667a2eL, 0xc4614ab8L, 0x5d681b02L, 0x2a6f2b94L, 0xb40bbe37L, 0xc30c8ea1L, 0x5a05df1bL, 0x2d02ef8dL - }; - constexpr std::uint32_t operator()(string_view value) const noexcept { - auto crc = static_cast(0xffffffffL); - for (const auto c : value) { - crc = (crc >> 8) ^ crc_table[(crc ^ static_cast(c)) & 0xff]; - } - return crc ^ 0xffffffffL; - } - - struct secondary_hash { - constexpr std::uint32_t operator()(string_view value) const noexcept { - auto acc = static_cast(2166136261ULL); - for (const auto c : value) { - acc = ((acc ^ static_cast(c)) * static_cast(16777619ULL)) & (std::numeric_limits::max)(); - } - return static_cast(acc); - } - }; -}; - -template -inline constexpr Hash hash_v{}; - -template -constexpr auto calculate_cases(std::size_t Page) noexcept { - constexpr std::array values = *GlobValues; - constexpr std::size_t size = values.size(); - - using switch_t = std::invoke_result_t; - static_assert(std::is_integral_v && !std::is_same_v); - const std::size_t values_to = (std::min)(static_cast(256), size - Page); - - std::array result{}; - auto fill = result.begin(); - { - auto first = values.begin() + static_cast(Page); - auto last = values.begin() + static_cast(Page + values_to); - while (first != last) { - *fill++ = hash_v(*first++); - } - } - - // dead cases, try to avoid case collisions - for (switch_t last_value = result[values_to - 1]; fill != result.end() && last_value != (std::numeric_limits::max)(); *fill++ = ++last_value) { - } - - { - auto it = result.begin(); - auto last_value = (std::numeric_limits::min)(); - for (; fill != result.end(); *fill++ = last_value++) { - while (last_value == *it) { - ++last_value, ++it; - } - } - } - - return result; -} - -template -constexpr R invoke_r(F&& f, Args&&... args) noexcept(std::is_nothrow_invocable_r_v) { - if constexpr (std::is_void_v) { - std::forward(f)(std::forward(args)...); - } else { - return static_cast(std::forward(f)(std::forward(args)...)); - } -} - -enum class case_call_t { - index, - value -}; - -template -inline constexpr auto default_result_type_lambda = []() noexcept(std::is_nothrow_default_constructible_v) { return T{}; }; - -template <> -inline constexpr auto default_result_type_lambda = []() noexcept {}; - -template -constexpr bool has_duplicate() noexcept { - using value_t = std::decay_t; - using hash_value_t = std::invoke_result_t; - std::arraysize()> hashes{}; - std::size_t size = 0; - for (auto elem : *Arr) { - hashes[size] = hash_v(elem); - for (auto i = size++; i > 0; --i) { - if (hashes[i] < hashes[i - 1]) { - auto tmp = hashes[i]; - hashes[i] = hashes[i - 1]; - hashes[i - 1] = tmp; - } else if (hashes[i] == hashes[i - 1]) { - return false; - } else { - break; - } - } - } - return true; -} - -#define MAGIC_ENUM_CASE(val) \ - case cases[val]: \ - if constexpr ((val) + Page < size) { \ - if (!pred(values[val + Page], searched)) { \ - break; \ - } \ - if constexpr (CallValue == case_call_t::index) { \ - if constexpr (std::is_invocable_r_v>) { \ - return detail::invoke_r(std::forward(lambda), std::integral_constant{}); \ - } else if constexpr (std::is_invocable_v>) { \ - MAGIC_ENUM_ASSERT(false && "magic_enum::detail::constexpr_switch wrong result type."); \ - } \ - } else if constexpr (CallValue == case_call_t::value) { \ - if constexpr (std::is_invocable_r_v>) { \ - return detail::invoke_r(std::forward(lambda), enum_constant{}); \ - } else if constexpr (std::is_invocable_r_v>) { \ - MAGIC_ENUM_ASSERT(false && "magic_enum::detail::constexpr_switch wrong result type."); \ - } \ - } \ - break; \ - } else [[fallthrough]]; - -template ::value_type>, - typename BinaryPredicate = std::equal_to<>, - typename Lambda, - typename ResultGetterType> -constexpr decltype(auto) constexpr_switch( - Lambda&& lambda, - typename std::decay_t::value_type searched, - ResultGetterType&& def, - BinaryPredicate&& pred = {}) { - using result_t = std::invoke_result_t; - using hash_t = std::conditional_t(), Hash, typename Hash::secondary_hash>; - static_assert(has_duplicate(), "magic_enum::detail::constexpr_switch duplicated hash found, please report it: https://github.com/Neargye/magic_enum/issues."); - constexpr std::array values = *GlobValues; - constexpr std::size_t size = values.size(); - constexpr std::array cases = calculate_cases(Page); - - switch (hash_v(searched)) { - MAGIC_ENUM_FOR_EACH_256(MAGIC_ENUM_CASE) - default: - if constexpr (size > 256 + Page) { - return constexpr_switch(std::forward(lambda), searched, std::forward(def)); - } - break; - } - return def(); -} - -#undef MAGIC_ENUM_CASE - -#endif - -} // namespace magic_enum::detail - -// Checks is magic_enum supported compiler. -inline constexpr bool is_magic_enum_supported = detail::supported::value; - -template -using Enum = detail::enum_concept; - -// Checks whether T is an Unscoped enumeration type. -// Provides the member constant value which is equal to true, if T is an [Unscoped enumeration](https://en.cppreference.com/w/cpp/language/enum#Unscoped_enumeration) type. Otherwise, value is equal to false. -template -struct is_unscoped_enum : detail::is_unscoped_enum {}; - -template -inline constexpr bool is_unscoped_enum_v = is_unscoped_enum::value; - -// Checks whether T is an Scoped enumeration type. -// Provides the member constant value which is equal to true, if T is an [Scoped enumeration](https://en.cppreference.com/w/cpp/language/enum#Scoped_enumerations) type. Otherwise, value is equal to false. -template -struct is_scoped_enum : detail::is_scoped_enum {}; - -template -inline constexpr bool is_scoped_enum_v = is_scoped_enum::value; - -// If T is a complete enumeration type, provides a member typedef type that names the underlying type of T. -// Otherwise, if T is not an enumeration type, there is no member type. Otherwise (T is an incomplete enumeration type), the program is ill-formed. -template -struct underlying_type : detail::underlying_type {}; - -template -using underlying_type_t = typename underlying_type::type; - -template -using enum_constant = detail::enum_constant; - -// Returns type name of enum. -template -[[nodiscard]] constexpr auto enum_type_name() noexcept -> detail::enable_if_t { - constexpr string_view name = detail::type_name_v>; - static_assert(!name.empty(), "magic_enum::enum_type_name enum type does not have a name."); - - return name; -} - -// Returns number of enum values. -template > -[[nodiscard]] constexpr auto enum_count() noexcept -> detail::enable_if_t { - return detail::count_v, S>; -} - -// Returns enum value at specified index. -// No bounds checking is performed: the behavior is undefined if index >= number of enum values. -template > -[[nodiscard]] constexpr auto enum_value(std::size_t index) noexcept -> detail::enable_if_t> { - using D = std::decay_t; - - if constexpr (detail::is_sparse_v) { - return MAGIC_ENUM_ASSERT(index < detail::count_v), detail::values_v[index]; - } else { - constexpr auto min = (S == detail::enum_subtype::flags) ? detail::log2(detail::min_v) : detail::min_v; - - return MAGIC_ENUM_ASSERT(index < detail::count_v), detail::value(index); - } -} - -// Returns enum value at specified index. -template > -[[nodiscard]] constexpr auto enum_value() noexcept -> detail::enable_if_t> { - using D = std::decay_t; - static_assert(I < detail::count_v, "magic_enum::enum_value out of range."); - - return enum_value(I); -} - -// Returns std::array with enum values, sorted by enum value. -template > -[[nodiscard]] constexpr auto enum_values() noexcept -> detail::enable_if_t> { - return detail::values_v, S>; -} - -// Returns integer value from enum value. -template -[[nodiscard]] constexpr auto enum_integer(E value) noexcept -> detail::enable_if_t> { - return static_cast>(value); -} - -// Returns underlying value from enum value. -template -[[nodiscard]] constexpr auto enum_underlying(E value) noexcept -> detail::enable_if_t> { - return static_cast>(value); -} - -// Obtains index in enum values from enum value. -// Returns optional with index. -template > -[[nodiscard]] constexpr auto enum_index(E value) noexcept -> detail::enable_if_t> { - using D = std::decay_t; - using U = underlying_type_t; - - if constexpr (detail::count_v == 0) { - static_cast(value); - return {}; // Empty enum. - } else if constexpr (detail::is_sparse_v || (S == detail::enum_subtype::flags)) { -#if defined(MAGIC_ENUM_ENABLE_HASH) - return detail::constexpr_switch<&detail::values_v, detail::case_call_t::index>( - [](std::size_t i) { return optional{i}; }, - value, - detail::default_result_type_lambda>); -#else - for (std::size_t i = 0; i < detail::count_v; ++i) { - if (enum_value(i) == value) { - return i; - } - } - return {}; // Invalid value or out of range. -#endif - } else { - const auto v = static_cast(value); - if (v >= detail::min_v && v <= detail::max_v) { - return static_cast(v - detail::min_v); - } - return {}; // Invalid value or out of range. - } -} - -// Obtains index in enum values from enum value. -// Returns optional with index. -template -[[nodiscard]] constexpr auto enum_index(E value) noexcept -> detail::enable_if_t> { - using D = std::decay_t; - - return enum_index(value); -} - -// Obtains index in enum values from static storage enum variable. -template >> -[[nodiscard]] constexpr auto enum_index() noexcept -> detail::enable_if_t { - constexpr auto index = enum_index, S>(V); - static_assert(index, "magic_enum::enum_index enum value does not have a index."); - - return *index; -} - -// Returns name from static storage enum variable. -// This version is much lighter on the compile times and is not restricted to the enum_range limitation. -template -[[nodiscard]] constexpr auto enum_name() noexcept -> detail::enable_if_t { - constexpr string_view name = detail::enum_name_v, V>; - static_assert(!name.empty(), "magic_enum::enum_name enum value does not have a name."); - - return name; -} - -// Returns name from enum value. -// If enum value does not have name or value out of range, returns empty string. -template > -[[nodiscard]] constexpr auto enum_name(E value) noexcept -> detail::enable_if_t { - using D = std::decay_t; - - if (const auto i = enum_index(value)) { - return detail::names_v[*i]; - } - return {}; -} - -// Returns name from enum value. -// If enum value does not have name or value out of range, returns empty string. -template -[[nodiscard]] constexpr auto enum_name(E value) -> detail::enable_if_t { - using D = std::decay_t; - - return enum_name(value); -} - -// Returns std::array with names, sorted by enum value. -template > -[[nodiscard]] constexpr auto enum_names() noexcept -> detail::enable_if_t> { - return detail::names_v, S>; -} - -// Returns std::array with pairs (value, name), sorted by enum value. -template > -[[nodiscard]] constexpr auto enum_entries() noexcept -> detail::enable_if_t> { - return detail::entries_v, S>; -} - -// Allows you to write magic_enum::enum_cast("bar", magic_enum::case_insensitive); -inline constexpr auto case_insensitive = detail::case_insensitive<>{}; - -// Obtains enum value from integer value. -// Returns optional with enum value. -template > -[[nodiscard]] constexpr auto enum_cast(underlying_type_t value) noexcept -> detail::enable_if_t>> { - using D = std::decay_t; - - if constexpr (detail::count_v == 0) { - static_cast(value); - return {}; // Empty enum. - } else { - if constexpr (detail::is_sparse_v || (S == detail::enum_subtype::flags)) { -#if defined(MAGIC_ENUM_ENABLE_HASH) - return detail::constexpr_switch<&detail::values_v, detail::case_call_t::value>( - [](D v) { return optional{v}; }, - static_cast(value), - detail::default_result_type_lambda>); -#else - for (std::size_t i = 0; i < detail::count_v; ++i) { - if (value == static_cast>(enum_value(i))) { - return static_cast(value); - } - } - return {}; // Invalid value or out of range. -#endif - } else { - if (value >= detail::min_v && value <= detail::max_v) { - return static_cast(value); - } - return {}; // Invalid value or out of range. - } - } -} - -// Obtains enum value from name. -// Returns optional with enum value. -template , typename BinaryPredicate = std::equal_to<>> -[[nodiscard]] constexpr auto enum_cast(string_view value, [[maybe_unused]] BinaryPredicate p = {}) noexcept(detail::is_nothrow_invocable()) -> detail::enable_if_t>, BinaryPredicate> { - using D = std::decay_t; - - if constexpr (detail::count_v == 0) { - static_cast(value); - return {}; // Empty enum. -#if defined(MAGIC_ENUM_ENABLE_HASH) - } else if constexpr (detail::is_default_predicate()) { - return detail::constexpr_switch<&detail::names_v, detail::case_call_t::index>( - [](std::size_t i) { return optional{detail::values_v[i]}; }, - value, - detail::default_result_type_lambda>, - [&p](string_view lhs, string_view rhs) { return detail::cmp_equal(lhs, rhs, p); }); -#endif - } else { - for (std::size_t i = 0; i < detail::count_v; ++i) { - if (detail::cmp_equal(value, detail::names_v[i], p)) { - return enum_value(i); - } - } - return {}; // Invalid value or out of range. - } -} - -// Checks whether enum contains value with such value. -template > -[[nodiscard]] constexpr auto enum_contains(E value) noexcept -> detail::enable_if_t { - using D = std::decay_t; - using U = underlying_type_t; - - return static_cast(enum_cast(static_cast(value))); -} - -// Checks whether enum contains value with such value. -template -[[nodiscard]] constexpr auto enum_contains(E value) noexcept -> detail::enable_if_t { - using D = std::decay_t; - using U = underlying_type_t; - - return static_cast(enum_cast(static_cast(value))); -} - -// Checks whether enum contains value with such integer value. -template > -[[nodiscard]] constexpr auto enum_contains(underlying_type_t value) noexcept -> detail::enable_if_t { - using D = std::decay_t; - - return static_cast(enum_cast(value)); -} - -// Checks whether enum contains enumerator with such name. -template , typename BinaryPredicate = std::equal_to<>> -[[nodiscard]] constexpr auto enum_contains(string_view value, BinaryPredicate p = {}) noexcept(detail::is_nothrow_invocable()) -> detail::enable_if_t { - using D = std::decay_t; - - return static_cast(enum_cast(value, std::move(p))); -} - -template -inline constexpr auto as_flags = AsFlags ? detail::enum_subtype::flags : detail::enum_subtype::common; - -template -inline constexpr auto as_common = AsFlags ? detail::enum_subtype::common : detail::enum_subtype::flags; - -namespace bitwise_operators { - -template = 0> -constexpr E operator~(E rhs) noexcept { - return static_cast(~static_cast>(rhs)); -} - -template = 0> -constexpr E operator|(E lhs, E rhs) noexcept { - return static_cast(static_cast>(lhs) | static_cast>(rhs)); -} - -template = 0> -constexpr E operator&(E lhs, E rhs) noexcept { - return static_cast(static_cast>(lhs) & static_cast>(rhs)); -} - -template = 0> -constexpr E operator^(E lhs, E rhs) noexcept { - return static_cast(static_cast>(lhs) ^ static_cast>(rhs)); -} - -template = 0> -constexpr E& operator|=(E& lhs, E rhs) noexcept { - return lhs = (lhs | rhs); -} - -template = 0> -constexpr E& operator&=(E& lhs, E rhs) noexcept { - return lhs = (lhs & rhs); -} - -template = 0> -constexpr E& operator^=(E& lhs, E rhs) noexcept { - return lhs = (lhs ^ rhs); -} - -} // namespace magic_enum::bitwise_operators - -} // namespace magic_enum - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#elif defined(_MSC_VER) -# pragma warning(pop) -#endif - -#undef MAGIC_ENUM_GET_ENUM_NAME_BUILTIN -#undef MAGIC_ENUM_GET_TYPE_NAME_BUILTIN -#undef MAGIC_ENUM_VS_2017_WORKAROUND -#undef MAGIC_ENUM_ARRAY_CONSTEXPR -#undef MAGIC_ENUM_FOR_EACH_256 - -#endif // NEARGYE_MAGIC_ENUM_HPP diff --git a/tt_metal/tools/profiler/op_profiler.hpp b/tt_metal/tools/profiler/op_profiler.hpp index 77a50644273..04aeb667af1 100644 --- a/tt_metal/tools/profiler/op_profiler.hpp +++ b/tt_metal/tools/profiler/op_profiler.hpp @@ -11,7 +11,7 @@ #include "ttnn/tensor/tensor.hpp" #include "third_party/json/json.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "tools/profiler/profiler.hpp" #include "tt_metal/impl/kernels/kernel.hpp" #include "ttnn/operation.hpp" diff --git a/tt_metal/tt_stl/reflection.hpp b/tt_metal/tt_stl/reflection.hpp index 7cec7334df7..79036fa0da1 100644 --- a/tt_metal/tt_stl/reflection.hpp +++ b/tt_metal/tt_stl/reflection.hpp @@ -20,7 +20,7 @@ #include "concepts.hpp" #include "third_party/json/json.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "type_name.hpp" #include "tt_metal/common/logger.hpp" diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 2848e52cf2a..2cf781ccfe2 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -549,7 +549,6 @@ set(TTNN_PUBLIC_LINK_DIRS "") set(TTNN_PRECOMPILED_HEADERS ${PROJECT_SOURCE_DIR}/tt_metal/tt_stl/reflection.hpp ${PROJECT_SOURCE_DIR}/ttnn/cpp/ttnn/operation.hpp - ${PROJECT_SOURCE_DIR}/tt_metal/third_party/magic_enum/magic_enum.hpp ${PROJECT_SOURCE_DIR}/tt_metal/third_party/tracy/public/tracy/Tracy.hpp ${PROJECT_SOURCE_DIR}/tt_metal/third_party/fmt/fmt/core.h ${PROJECT_SOURCE_DIR}/tt_metal/third_party/fmt/fmt/format.h diff --git a/ttnn/cpp/pybind11/export_enum.hpp b/ttnn/cpp/pybind11/export_enum.hpp index d930e8aa934..496f23a2992 100644 --- a/ttnn/cpp/pybind11/export_enum.hpp +++ b/ttnn/cpp/pybind11/export_enum.hpp @@ -6,7 +6,7 @@ #include #include -#include "third_party/magic_enum/magic_enum.hpp" +#include namespace py = pybind11; diff --git a/ttnn/cpp/ttnn/core.hpp b/ttnn/cpp/ttnn/core.hpp index 44e62a6696c..b19079d6c28 100644 --- a/ttnn/cpp/ttnn/core.hpp +++ b/ttnn/cpp/ttnn/core.hpp @@ -6,7 +6,7 @@ #include #include -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/tensor/tensor.hpp" #include "ttnn/tensor/tensor_impl.hpp" // TTNN_TENSOR_PRINT_PROFILE #include "ttnn/tensor/types.hpp" diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.cpp index c39c7eb980b..3cbda33472f 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.cpp @@ -5,7 +5,7 @@ #include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" #include "common/constants.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "tt_metal/detail/util.hpp" #include "tt_metal/common/work_split.hpp" diff --git a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.hpp b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.hpp index ee384501a27..2ab6380247f 100644 --- a/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.hpp +++ b/ttnn/cpp/ttnn/operations/core/to_layout/to_layout_op.hpp @@ -8,7 +8,7 @@ #include #include "ttnn/tensor/tensor.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/tensor/host_buffer/functions.hpp" #include "ttnn/tensor/tensor_utils.hpp" #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp index 845840cd2eb..6d28f031836 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/reshard/device/reshard_op.cpp @@ -4,7 +4,7 @@ #include "reshard_op.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "reshard_program_factory.hpp" diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp index d794fc0ed11..b71c58f041b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "binary_composite_op.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/operations/eltwise/binary/binary.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" #include "ttnn/types.hpp" diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp index a88e910f8e1..0c0b7f07314 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.hpp @@ -8,7 +8,7 @@ #include #include #include "ttnn/tensor/tensor.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/operations/core/core.hpp" namespace ttnn::operations::binary{ diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp index 1a886f957df..bef06ac8379 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp @@ -10,7 +10,7 @@ #include "ttnn/common/constants.hpp" #include "ttnn/tensor/tensor.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/tensor/host_buffer/functions.hpp" #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" #include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp" diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp index 2238f055fa2..ae9a5655ac6 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary_backward/binary_backward.cpp @@ -26,7 +26,7 @@ #include "ttnn/operations/creation.hpp" #include "ttnn/common/constants.hpp" #include "ttnn/operations/eltwise/binary_backward/binary_backward.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include namespace ttnn::operations::binary_backward { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp index f2f930bd291..a4cfe450306 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_binary/device/complex_binary_op.hpp @@ -7,7 +7,7 @@ #include #include #include "ttnn/tensor/tensor.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/operations/eltwise/complex/complex.hpp" namespace ttnn::operations::complex_binary { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp index b2d89086a92..1e9301c9bd1 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary/device/complex_unary_op.hpp @@ -7,7 +7,7 @@ #include #include #include "ttnn/tensor/tensor.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/operations/eltwise/complex/complex.hpp" namespace ttnn::operations::complex_unary { diff --git a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp index 3443bdf93f7..5436f5695a6 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/complex_unary_backward/device/complex_unary_backward_op.hpp @@ -7,7 +7,7 @@ #include #include #include "ttnn/tensor/tensor.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/operations/eltwise/complex/complex.hpp" namespace ttnn::operations::complex_unary_backward { diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.hpp index 5c41dca31f7..9265a4ffeac 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/ternary/ternary_composite_op.hpp @@ -7,7 +7,7 @@ #include #include -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/tensor/tensor.hpp" #include "ttnn/operations/core/core.hpp" #include "ttnn/run_operation.hpp" diff --git a/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward.cpp b/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward.cpp index fb35b36ef6a..e6c7a9b4a80 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/ternary_backward/ternary_backward.cpp @@ -12,7 +12,7 @@ #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" #include "ttnn/operations/eltwise/ternary_backward/ternary_backward.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include namespace ttnn::operations::ternary_backward { diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp index a58e45d2a5f..5c52b147acd 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.cpp @@ -8,7 +8,7 @@ #include #include -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "tt_metal/common/bfloat16.hpp" #include "ttnn/operations/data_movement/reshape_on_device/reshape.hpp" #include "ttnn/operations/data_movement/bcast/bcast.hpp" diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp index c2fdb3bcb05..3608be72e02 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_composite_op.hpp @@ -6,7 +6,7 @@ #include #include #include "ttnn/tensor/tensor.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/cpp/ttnn/operations/eltwise/ternary/where.hpp" #include "ttnn/operations/eltwise/unary/unary.hpp" #include "ttnn/operations/eltwise/binary/binary.hpp" diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp index 3bf9cde0656..ef236ac3c43 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp @@ -4,7 +4,7 @@ #include "unary_device_operation.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "tt_metal/common/constants.hpp" #include "tt_metal/host_api.hpp" #include "tt_metal/tools/profiler/op_profiler.hpp" diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp index 079f756203b..ff018c6d23a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary_backward/unary_backward.cpp @@ -2,7 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 -#include "third_party/magic_enum/magic_enum.hpp" +#include #include "ttnn/operations/data_movement/bcast/bcast.hpp" #include "tt_metal/common/constants.hpp" #include "ttnn/common/constants.hpp" diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp index 5322de38995..b791fdb3617 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_post_all_gather_op.cpp @@ -11,7 +11,7 @@ #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include diff --git a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp index 835a699785b..ea0729ed973 100644 --- a/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/layernorm_distributed/device/layernorm_pre_all_gather_op.cpp @@ -11,7 +11,7 @@ #include "tt_metal/common/constants.hpp" #include "tt_metal/detail/util.hpp" -#include "third_party/magic_enum/magic_enum.hpp" +#include #include From fad2a69dd066cbf57af1067dac77f959c1c9c8e1 Mon Sep 17 00:00:00 2001 From: Raymond Kim <109366641+tt-rkim@users.noreply.github.com> Date: Wed, 9 Oct 2024 10:08:01 -0400 Subject: [PATCH 37/58] #0: Increase mamba device perf threshold because looks like there was a slight performance bump [skip ci] (#13635) Force-merging to fix device perf single-card. --- models/demos/wormhole/mamba/tests/test_mamba_perf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/demos/wormhole/mamba/tests/test_mamba_perf.py b/models/demos/wormhole/mamba/tests/test_mamba_perf.py index 1d66d1cbf8e..b515c52e345 100644 --- a/models/demos/wormhole/mamba/tests/test_mamba_perf.py +++ b/models/demos/wormhole/mamba/tests/test_mamba_perf.py @@ -143,7 +143,7 @@ def test_mamba_perf_e2e( @pytest.mark.models_device_performance_bare_metal @pytest.mark.parametrize( "batch, expected_layer_duration_ms", - ((32, 1.71),), + ((32, 1.689),), ) def test_mamba_perf_device(batch, expected_layer_duration_ms): subdir = "ttnn_mamba" From b2f8ce6095d75b9ee3fc3d1311e45355acf754dd Mon Sep 17 00:00:00 2001 From: Pavle Josipovic Date: Wed, 2 Oct 2024 15:52:04 +0000 Subject: [PATCH 38/58] #13541: Conv2d auto shard interleaved tensor Conv2d requires (internally) sharded tensor for in0(activations). Currently if shard_layout is not provided as part of the Conv2dConfig, op will default to height sharding. And op will run an op to convert input tensor to height sharded tensor. This isn't always the best heuristic. This change introduces a logic to select the best sharding layout based on the input tensor shape, in order to increase core count and consequently increase L1 space op can use, and reduce number of out-of-memory errors. There is still no guarantee that op will fit in L1 space, but we can observe high success rate in tests coming out of pytorch2 ttnn integration. This heuristic is not the most optimal one and we will keep iterating on it as well, but given that it resolve a lot of torch tests cases we want to prioritize getting it in. This change is also a BREAKING CHANGE, as it changes the default behavior of Conv2d when user passes in interlaved tensor for in0(activations). This doesn't seem to affect our models given that they already provide sharded tensors as inputs, but it does affect some of our tests. In case users wants to pass in interleaved tensor, and manually pick sharding layout, user has to set shard_layout in Conv2dConfig, which now defaults to std::nullopt instead of HEIGHT_SHARDED. To test this change a variant with auto shard selection has been added to most of tests/ttnn/unit_tests/operations/test_new_conv2d.py. Some tests are skipped (like stable diffusion tests) as this change has exposed few issues with conv WIDTH_SHARDED codepath that we need to address. --- .../ttnn_functional_resnet50_new_conv_api.py | 2 + .../ttnn/unit_tests/operations/test_conv1d.py | 12 +- .../unit_tests/operations/test_new_conv2d.py | 55 +++++++- .../ttnn/operations/conv/conv2d/conv2d.cpp | 122 +++++++++++++----- .../ttnn/operations/conv/conv2d/conv2d.hpp | 10 +- .../operations/conv/conv2d/conv2d_pybind.cpp | 52 ++++---- 6 files changed, 190 insertions(+), 63 deletions(-) diff --git a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py index 9d7f3efed24..d8d09562b8a 100644 --- a/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py +++ b/models/demos/ttnn_resnet/tt/ttnn_functional_resnet50_new_conv_api.py @@ -639,6 +639,8 @@ def __init__( width=self.conv1_output_width, in_channels=self.conv1_input_channels, out_channels=self.conv1_output_channels, + kernel_size=[self.conv1_kernel_size[0], self.conv1_kernel_size[1]], + stride=[self.conv1_stride[0], self.conv1_stride[1]], ) def __del__(self): diff --git a/tests/ttnn/unit_tests/operations/test_conv1d.py b/tests/ttnn/unit_tests/operations/test_conv1d.py index ed9d7cb2ac6..3e7a1496c63 100644 --- a/tests/ttnn/unit_tests/operations/test_conv1d.py +++ b/tests/ttnn/unit_tests/operations/test_conv1d.py @@ -45,6 +45,7 @@ def run_conv( deallocate_activation=True, debug=False, groups=1, + auto_shard=False, ): # has_bias = False has_bias = False @@ -78,13 +79,17 @@ def run_conv( tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) + shard_layout = ( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if use_1d_systolic_array else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ) + if auto_shard: + shard_layout = None + conv_config = ttnn.Conv1dConfig( dtype=output_dtype, weights_dtype=weights_dtype, math_fidelity=math_fidelity, - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if use_1d_systolic_array - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=shard_layout, input_channels_alignment=(16 if use_shallow_conv_variant else 32), deallocate_activation=deallocate_activation, fp32_dest_acc_enabled=fp32_accum, @@ -214,6 +219,7 @@ def test_conv1d_mamba( padded_input_channels=None, output_layout=output_layout, groups=groups, + auto_shard=True, ) diff --git a/tests/ttnn/unit_tests/operations/test_new_conv2d.py b/tests/ttnn/unit_tests/operations/test_new_conv2d.py index d04e1ee385d..0217433c0e2 100644 --- a/tests/ttnn/unit_tests/operations/test_new_conv2d.py +++ b/tests/ttnn/unit_tests/operations/test_new_conv2d.py @@ -71,6 +71,7 @@ def run_conv( groups=1, has_bias=True, shard_layout=None, + auto_shard=False, ): torch.manual_seed(0) conv_input_shape = [batch_size, input_channels, input_height, input_width] @@ -115,7 +116,7 @@ def run_conv( tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) - if shard_layout is None: + if shard_layout is None and not auto_shard: shard_layout = ( ttnn.TensorMemoryLayout.HEIGHT_SHARDED if use_1d_systolic_array else ttnn.TensorMemoryLayout.BLOCK_SHARDED ) @@ -249,13 +250,14 @@ def run_conv_with_split( torch_input2_tensor = torch.permute(split_input_tensors[1], (0, 2, 3, 1)) reader_patterns_cache = {} + shard_layout = ( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if use_1d_systolic_array else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ) conv_config = ttnn.Conv2dConfig( dtype=activations_dtype, weights_dtype=weights_dtype, math_fidelity=math_fidelity, - shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED - if use_1d_systolic_array - else ttnn.TensorMemoryLayout.BLOCK_SHARDED, + shard_layout=shard_layout if use_1d_systolic_array else ttnn.TensorMemoryLayout.BLOCK_SHARDED, fp32_dest_acc_enabled=fp32_accum, packer_l1_accum_enabled=packer_l1_acc, # input_channels_alignment=(16 if use_shallow_conv_variant else 32), @@ -346,6 +348,7 @@ def run_conv_with_split( "activations_dtype", [ttnn.bfloat16], ) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_conv_ws( device, use_program_cache, @@ -362,6 +365,7 @@ def test_conv_ws( has_bias, weights_dtype, activations_dtype, + auto_shard, ): stride_h = stride stride_w = stride @@ -419,7 +423,7 @@ def test_conv_ws( dtype=activations_dtype, weights_dtype=weights_dtype, math_fidelity=ttnn.MathFidelity.HiFi4, - shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED, + shard_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED if not auto_shard else None, input_channels_alignment=32, deallocate_activation=deallocate_activation, fp32_dest_acc_enabled=fp32_accum, @@ -498,6 +502,7 @@ def test_conv_ws( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) @skip_for_grayskull() def test_conv_for_segformer_512x512( device, @@ -521,6 +526,7 @@ def test_conv_for_segformer_512x512( use_shallow_conv_variant, groups, output_layout, + auto_shard, ): run_conv( device, @@ -544,6 +550,7 @@ def test_conv_for_segformer_512x512( groups=groups, output_layout=output_layout, has_bias=False, + auto_shard=auto_shard, ) @@ -585,6 +592,7 @@ def test_conv_for_segformer_512x512( [ttnn.bfloat16, ttnn.bfloat8_b], ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_resnet50_conv_gs( device, use_program_cache, @@ -604,6 +612,7 @@ def test_resnet50_conv_gs( pad_w, use_1d_systolic_array, config_override, + auto_shard, ): if is_blackhole(): pytest.skip("This test is for Grayskull only") @@ -646,6 +655,7 @@ def test_resnet50_conv_gs( use_shallow_conv_variant=input_channels == 16, padded_input_channels=16 if input_channels == 16 else None, debug=not (batch_size == 20 and input_height == 115), + auto_shard=auto_shard, ) @@ -713,6 +723,7 @@ def test_resnet50_conv_gs( @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) @pytest.mark.parametrize("has_bias", [True, False], ids=["with_bias", "no_bias"]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_resnet50_conv_wh( device, use_program_cache, @@ -734,6 +745,7 @@ def test_resnet50_conv_wh( config_override, packer_l1_acc, has_bias, + auto_shard, ): if device.core_grid.y == 7: pytest.skip("Issue #6992: Statically allocated circular buffers in program clash with L1 buffers on core range") @@ -781,6 +793,7 @@ def test_resnet50_conv_wh( packer_l1_acc=packer_l1_acc, fp32_accum=False, has_bias=has_bias, + auto_shard=auto_shard, ) @@ -838,6 +851,7 @@ def test_resnet50_conv_wh( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.HiFi4]) @pytest.mark.parametrize("packer_l1_acc", [True, False], ids=["pack_l1", "no_pack_l1"]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_resnet50_conv_wh_fp32( device, use_program_cache, @@ -859,6 +873,7 @@ def test_resnet50_conv_wh_fp32( use_1d_systolic_array, config_override, packer_l1_acc, + auto_shard, ): if batch_size > 8 and (activations_dtype != ttnn.bfloat8_b or weights_dtype != ttnn.bfloat8_b): pytest.skip("Batch > 8 must be run fully bfp8") @@ -899,6 +914,7 @@ def test_resnet50_conv_wh_fp32( fp32_accum=fp32_accum, packer_l1_acc=packer_l1_acc, transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH + auto_shard=auto_shard, ) @@ -1249,6 +1265,7 @@ def test_sd_conv_wh( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_unet_conv( device, use_program_cache, @@ -1270,6 +1287,7 @@ def test_unet_conv( config_override, use_shallow_conv_variant, output_layout, + auto_shard, ): if is_blackhole(): pytest.skip("This test is for Grayskull only") @@ -1299,6 +1317,7 @@ def test_unet_conv( use_shallow_conv_variant=use_shallow_conv_variant, padded_input_channels=16 if input_channels == 3 else None, output_layout=output_layout, + auto_shard=auto_shard, ) @@ -1339,6 +1358,7 @@ def test_unet_conv( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_unet_conv_wh( device, use_program_cache, @@ -1360,6 +1380,7 @@ def test_unet_conv_wh( config_override, use_shallow_conv_variant, output_layout, + auto_shard, ): if (device.compute_with_storage_grid_size().x, device.compute_with_storage_grid_size().y) == (8, 7): pytest.skip("Test is not supported on n300 (8,7) grid") @@ -1389,6 +1410,7 @@ def test_unet_conv_wh( transpose_mcast=use_1d_systolic_array, ## use RM (transpose_mcast=False) with 2D on WH padded_input_channels=None, output_layout=output_layout, + auto_shard=auto_shard, ) @@ -1406,6 +1428,7 @@ def test_unet_conv_wh( ), ) @pytest.mark.parametrize("use_1d_systolic_array", [False, True]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_halo_reshard_conv( device, use_program_cache, @@ -1422,6 +1445,7 @@ def test_halo_reshard_conv( pad_h, pad_w, config_override, + auto_shard, ): math_fidelity = ttnn.MathFidelity.HiFi4 activations_dtype = ttnn.bfloat16 @@ -1445,6 +1469,7 @@ def test_halo_reshard_conv( pad_w, use_1d_systolic_array, config_override, + auto_shard=auto_shard, ) @@ -1461,6 +1486,7 @@ def test_halo_reshard_conv( ), ) @pytest.mark.parametrize("use_1d_systolic_array", [False, True]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_conv_core_nondivis( device, use_program_cache, @@ -1478,6 +1504,7 @@ def test_conv_core_nondivis( pad_w, config_override, xfail, + auto_shard, ): if xfail: pytest.xfail() @@ -1504,6 +1531,7 @@ def test_conv_core_nondivis( pad_w, use_1d_systolic_array, config_override, + auto_shard=auto_shard, ) @@ -1538,6 +1566,7 @@ def test_conv_core_nondivis( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) @pytest.mark.parametrize( "filter, dilation, pad", [ @@ -1563,6 +1592,7 @@ def test_conv_dilation( pad, output_layout, dilation, + auto_shard, ): config_override = {"act_block_w_div": act_block_w_div} run_conv( @@ -1587,6 +1617,7 @@ def test_conv_dilation( output_layout=output_layout, dilation=dilation, has_bias=False, + auto_shard=auto_shard, ) @@ -1632,6 +1663,8 @@ def test_conv_dilation( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) +# ToDo: Renable this when auto shard heuristic is imporved, currently we run out of L1 in for some test cases +# @pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_conv_groups( device, use_program_cache, @@ -1745,6 +1778,7 @@ def test_conv_groups( @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) # @pytest.mark.parametrize("output_layout", [ttnn.ROW_MAJOR_LAYOUT, ttnn.TILE_LAYOUT]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_yolov4_conv_groups_larger_than_one( device, use_program_cache, @@ -1767,6 +1801,7 @@ def test_yolov4_conv_groups_larger_than_one( use_shallow_conv_variant, groups, output_layout, + auto_shard, ): if output_layout == ttnn.ROW_MAJOR_LAYOUT and activations_dtype == ttnn.bfloat8_b: pytest.skip("Row major layout not compatible with bfloat8_b") @@ -1794,6 +1829,7 @@ def test_yolov4_conv_groups_larger_than_one( groups=groups, padded_input_channels=16 if input_channels == 3 else None, output_layout=output_layout, + auto_shard=auto_shard, ) @@ -1816,6 +1852,7 @@ def test_yolov4_conv_groups_larger_than_one( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_swin_s_conv( device, use_program_cache, @@ -1838,6 +1875,7 @@ def test_swin_s_conv( use_shallow_conv_variant, groups, output_layout, + auto_shard, ): if device.core_grid.y == 7: pytest.skip("This test is not supported for N300") @@ -1864,6 +1902,7 @@ def test_swin_s_conv( use_shallow_conv_variant=use_shallow_conv_variant, groups=groups, output_layout=output_layout, + auto_shard=auto_shard, ) @@ -1893,6 +1932,7 @@ def test_swin_s_conv( ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) @pytest.mark.parametrize("output_layout", [ttnn.TILE_LAYOUT]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) @skip_for_grayskull() def test_conv_for_segformer_512x512( device, @@ -1916,6 +1956,7 @@ def test_conv_for_segformer_512x512( use_shallow_conv_variant, groups, output_layout, + auto_shard, ): run_conv( device, @@ -1939,6 +1980,7 @@ def test_conv_for_segformer_512x512( groups=groups, output_layout=output_layout, shard_layout=shard_layout, + auto_shard=auto_shard, ) @@ -1963,6 +2005,7 @@ def test_conv_for_segformer_512x512( [ttnn.bfloat8_b], ) @pytest.mark.parametrize("math_fidelity", [ttnn.MathFidelity.LoFi]) +@pytest.mark.parametrize("auto_shard", [True, False], ids=["auto_shard", "no_auto_shard"]) def test_model_k_256x256( device, use_program_cache, @@ -1984,6 +2027,7 @@ def test_model_k_256x256( dilation_w, groups, use_1d_systolic_array, + auto_shard, ): run_conv( device, @@ -2004,6 +2048,7 @@ def test_model_k_256x256( use_1d_systolic_array, None, dilation=dilation_h, + auto_shard=auto_shard, ) diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp index 8b73c50692e..9fd54ab64b9 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.cpp @@ -6,12 +6,13 @@ #include #include +#include "impl/buffers/buffer_constants.hpp" #include "ttnn/operations/pool/downsample/device/downsample_op.hpp" #include "tt_metal/detail/reports/memory_reporter.hpp" -#include "ttnn/operations/core/to_dtype/to_dtype_op.hpp" #include "tt_metal/common/work_split.hpp" #include "ttnn/operations/eltwise/unary/common/unary_op_utils.hpp" #include "ttnn/cpp/ttnn/operations/data_movement/reshape_view/reshape.hpp" +#include "ttnn/tensor/tensor.hpp" using namespace tt; namespace ttnn { @@ -313,16 +314,63 @@ OptimizedConvBlockConfig determine_per_core_conv_block_config( .out_subblock_w_ntiles = out_subblock_w_ntiles}; } -template +// Implements a heuristic for selecting shard layout based on the input tensors shapes +// and stride. +static TensorMemoryLayout select_shard_layout( + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + std::array kernel_size, + std::array stride) { + // ToDo: enhance shard layout selection logic to check + // which sharding scheme maximizes core count (and consequently available L1 space) + + TensorMemoryLayout shard_layout; + + // 1d convs support only height sharding + bool is_conv1d = width == 1 && kernel_size[1] == 1; + // block and width sharding support very few configurations of kernel size and stride + // which are encoded below. + bool is_width_or_block_sharding_valid = + (kernel_size[0] == 3 && kernel_size[1] == 3 && (stride[0] == 1 || stride[0] == 2)) || + (kernel_size[0] == 1 && kernel_size[1] == 1 && stride[0] == 2); + + if (is_conv1d || !is_width_or_block_sharding_valid) { + log_debug(LogOp, "Conv which can only be supported by TensorMemoryLayout::HEIGHT_SHARDED"); + shard_layout = TensorMemoryLayout::HEIGHT_SHARDED; + } else { + float nhw = height * width * batch_size; + float ratio = nhw / in_channels; + log_debug(LogOp, "NHW: {}, C: {}, ratio: {}", nhw, in_channels, ratio); + + if (ratio > 8.0f) { + shard_layout = TensorMemoryLayout::HEIGHT_SHARDED; + log_debug(LogOp, "Shard layout: TensorMemoryLayout::HEIGHT_SHARDED"); + } else if (ratio < 0.4f) { + shard_layout = TensorMemoryLayout::WIDTH_SHARDED; + log_debug(LogOp, "Shard layout: TensorMemoryLayout::WIDTH_SHARDED"); + } else { + log_debug(LogOp, "Shard layout: TensorMemoryLayout::BLOCK_SHARDED"); + shard_layout = TensorMemoryLayout::BLOCK_SHARDED; + } + } + + return shard_layout; +} + +template std::tuple get_conv_padded_input_shape_and_mem_config( - T * device, + T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, uint32_t batch_size, uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels) { + uint32_t out_channels, + std::array kernel_size, + std::array stride) { ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); bool needs_shard_or_reshard = false; @@ -331,6 +379,13 @@ std::tuple get_conv_padded_input_shape_an false, "Incorrect config provided: reshard_if_not_optimal and override_sharding_config cannot both be set."); } + + TensorMemoryLayout shard_layout; + if (conv_config.shard_layout.has_value()) { + shard_layout = conv_config.shard_layout.value(); + } else { + shard_layout = select_shard_layout(batch_size, height, width, in_channels, kernel_size, stride); + } ParallelConfig input_tensor_parallel_config; if (!input_tensor_on_device) { needs_shard_or_reshard = true; @@ -357,15 +412,16 @@ std::tuple get_conv_padded_input_shape_an needs_shard_or_reshard = true; } if (conv_config.override_sharding_config) { - TT_FATAL(conv_config.core_grid.has_value(), "Error"); + TT_FATAL(conv_config.core_grid.has_value(), "If override_sharding_config is set, core_grid must be set as well."); + TT_FATAL(conv_config.shard_layout.has_value(), "If override_sharding_config is set, shard_layout must be set as well."); if (conv_config.core_grid.value() != input_shard_grid) { needs_shard_or_reshard = true; } - if(conv_config.shard_layout!=input_shard_scheme) { + if(shard_layout!=input_shard_scheme) { needs_shard_or_reshard = true; } bool input_transpose_shards = input_shard_orientation == ShardOrientation::COL_MAJOR; - if (conv_config.shard_layout == TensorMemoryLayout::BLOCK_SHARDED && conv_config.transpose_shards != input_transpose_shards) { + if (shard_layout == TensorMemoryLayout::BLOCK_SHARDED && conv_config.transpose_shards != input_transpose_shards) { needs_shard_or_reshard = true; } } @@ -376,23 +432,17 @@ std::tuple get_conv_padded_input_shape_an auto block_shard_orientation = conv_config.transpose_shards ? ShardOrientation::COL_MAJOR : ShardOrientation::ROW_MAJOR; const ParallelConfig& optimal_parallel_config = determine_parallel_config( - conv_config.shard_layout, - batch_size, - in_channels, - height, - width, - out_channels, - device, - block_shard_orientation); + shard_layout, batch_size, in_channels, height, width, out_channels, device, block_shard_orientation); if (conv_config.override_sharding_config) { TT_FATAL(conv_config.core_grid.has_value(), "Error"); // override parallel config - auto shard_orientation = - conv_config.shard_layout==TensorMemoryLayout::BLOCK_SHARDED ? block_shard_orientation: ShardOrientation::ROW_MAJOR; + auto shard_orientation = shard_layout == TensorMemoryLayout::BLOCK_SHARDED + ? block_shard_orientation + : ShardOrientation::ROW_MAJOR; parallel_config = { .grid = conv_config.core_grid.value(), - .shard_scheme = conv_config.shard_layout, + .shard_scheme = shard_layout, .shard_orientation = shard_orientation}; } else { parallel_config = optimal_parallel_config; @@ -405,7 +455,7 @@ std::tuple get_conv_padded_input_shape_an uint32_t input_num_cores_nhw = get_num_cores_nhw_from_parallel_config(parallel_config); // TT_ASSERT(input_tensor.get_legacy_shape() == input_tensor.get_shape()); uint32_t tensor_height = input_tensor.get_shape()[0] * input_tensor.get_shape()[1] * input_tensor.get_shape()[2]; - uint32_t input_tensor_height_snapped_to_tile = (conv_config.shard_layout == TensorMemoryLayout::WIDTH_SHARDED)? tensor_height : tt::round_up(tensor_height, input_num_cores_nhw * 32); + uint32_t input_tensor_height_snapped_to_tile = (shard_layout == TensorMemoryLayout::WIDTH_SHARDED)? tensor_height : tt::round_up(tensor_height, input_num_cores_nhw * 32); TT_ASSERT(input_tensor_height_snapped_to_tile >= tensor_height); uint32_t tensor_width = input_tensor.get_shape()[3]; uint32_t input_tensor_width_snapped_to_channels_alignment = @@ -429,20 +479,24 @@ std::tuple get_conv_padded_input_shape_an } } -template +template std::tuple shard_or_reshard_tensor_if_required( - T * device, + T* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, uint32_t batch_size, uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels) { + uint32_t out_channels, + std::array kernel_size, + std::array stride) { ttnn::Tensor input_tensor = input_tensor_; // tensor to return bool input_tensor_on_device = ttnn::is_tensor_on_device_or_multidevice(input_tensor_); - auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard] = get_conv_padded_input_shape_and_mem_config(device, input_tensor_, conv_config, batch_size, height, width, in_channels, out_channels); + auto [input_padded_shape, input_tensor_sharded_memory_config, needs_shard_or_reshard] = + get_conv_padded_input_shape_and_mem_config( + device, input_tensor_, conv_config, batch_size, height, width, in_channels, out_channels, kernel_size, stride); ParallelConfig parallel_config = { .grid = input_tensor_sharded_memory_config.shard_spec.value().grid, .shard_scheme = input_tensor_sharded_memory_config.memory_layout, @@ -676,7 +730,7 @@ std::tuple( bool is_out_tiled); template std::tuple get_conv_padded_input_shape_and_mem_config( - Device * device, + Device* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, uint32_t batch_size, uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels); + uint32_t out_channels, + std::array kernel_size, + std::array stride); template std::tuple get_conv_padded_input_shape_and_mem_config( MeshDevice * device, @@ -927,17 +983,21 @@ template std::tuple get_conv_padded_input uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels); + uint32_t out_channels, + std::array kernel_size, + std::array stride); template std::tuple shard_or_reshard_tensor_if_required( - Device * device, + Device* device, const ttnn::Tensor& input_tensor_, const Conv2dConfig& conv_config, uint32_t batch_size, uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels); + uint32_t out_channels, + std::array kernel_size, + std::array stride); template std::tuple shard_or_reshard_tensor_if_required( MeshDevice * device, @@ -947,7 +1007,9 @@ template std::tuple shard_or_reshard_tensor_ uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels); + uint32_t out_channels, + std::array kernel_size, + std::array stride); template std::pair> prepare_conv_weights_biases_and_move_to_device( const ttnn::Tensor& weight_tensor, diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp index 97302a83727..428b6b11387 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d.hpp @@ -40,7 +40,7 @@ struct Conv2dConfig { uint32_t act_block_w_div = 1; //Amount by which the maximum possible act_block_width is divided. Max act_block_w = (in_channels * window_w * window_h)/total_num_cores; bool reshard_if_not_optimal = false; // if true, override_sharding_config should not be set to true bool override_sharding_config = false; // if true, reshard_if_not_optimal should not be set to true - TensorMemoryLayout shard_layout = TensorMemoryLayout::HEIGHT_SHARDED; // used only if override_sharding_config is true + std::optional shard_layout; std::optional core_grid = std::nullopt; // used only if override_sharding_config is true bool transpose_shards = true; // used only if override_sharding_config is true and if height sharding is false Layout output_layout = Layout::TILE; @@ -134,7 +134,9 @@ std::tuple get_conv_padded_input_shape_an uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels); + uint32_t out_channels, + std::array kernel_size, + std::array stride); template std::tuple shard_or_reshard_tensor_if_required( @@ -145,7 +147,9 @@ std::tuple shard_or_reshard_ uint32_t height, uint32_t width, uint32_t in_channels, - uint32_t out_channels); + uint32_t out_channels, + std::array kernel_size, + std::array stride); void validate_weight_and_bias_tensors(const ttnn::Tensor& weight_tensor, std::optional& bias_tensor); diff --git a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp index 0e5a7711f64..e41d9605e8c 100644 --- a/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp +++ b/ttnn/cpp/ttnn/operations/conv/conv2d/conv2d_pybind.cpp @@ -117,16 +117,18 @@ void py_bind_conv2d(py::module& module) { module.def( "get_conv_padded_input_shape_and_mem_config", - [](ttnn::Device * device, - const ttnn::Tensor& input_tensor, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels) -> std::tuple { + [](ttnn::Device* device, + const ttnn::Tensor& input_tensor, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channels, + std::array kernel_size, + std::array stride) -> std::tuple { return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( - device, input_tensor, conv_config, batch_size, height, width, in_channels, out_channels); + device, input_tensor, conv_config, batch_size, height, width, in_channels, out_channels, kernel_size, stride); }, py::kw_only(), py::arg("device"), @@ -136,20 +138,24 @@ void py_bind_conv2d(py::module& module) { py::arg("height"), py::arg("width"), py::arg("in_channels"), - py::arg("out_channels")); + py::arg("out_channels"), + py::arg("kernel_size"), + py::arg("stride")); module.def( "get_conv_padded_input_shape_and_mem_config", - [](MeshDevice * device, - const ttnn::Tensor& input_tensor, - const Conv2dConfig& conv_config, - uint32_t batch_size, - uint32_t height, - uint32_t width, - uint32_t in_channels, - uint32_t out_channels) -> std::tuple { + [](MeshDevice* device, + const ttnn::Tensor& input_tensor, + const Conv2dConfig& conv_config, + uint32_t batch_size, + uint32_t height, + uint32_t width, + uint32_t in_channels, + uint32_t out_channels, + std::array kernel_size, + std::array stride) -> std::tuple { return ttnn::operations::conv::conv2d::get_conv_padded_input_shape_and_mem_config( - device, input_tensor, conv_config, batch_size, height, width, in_channels, out_channels); + device, input_tensor, conv_config, batch_size, height, width, in_channels, out_channels, kernel_size, stride); }, py::kw_only(), py::arg("device"), @@ -159,7 +165,9 @@ void py_bind_conv2d(py::module& module) { py::arg("height"), py::arg("width"), py::arg("in_channels"), - py::arg("out_channels")); + py::arg("out_channels"), + py::arg("kernel_size"), + py::arg("stride")); module.def( "convert_conv_weight_tensor_to_tiled_layout", @@ -186,7 +194,7 @@ void py_bind_conv2d(py::module& module) { auto py_conv_config = py::class_(module, "Conv2dConfig"); py_conv_config.def( - py::init, bool, Layout, bool, bool, bool>(), + py::init, std::optional, bool, Layout, bool, bool, bool>(), py::kw_only(), py::arg("math_fidelity") = MathFidelity::HiFi4, py::arg("dtype") = DataType::BFLOAT16, @@ -202,7 +210,7 @@ void py_bind_conv2d(py::module& module) { py::arg("act_block_w_div") = 1, py::arg("reshard_if_not_optimal") = false, py::arg("override_sharding_config") = false, - py::arg("shard_layout") = TensorMemoryLayout::HEIGHT_SHARDED, + py::arg("shard_layout") = std::nullopt, py::arg("core_grid") = std::nullopt, py::arg("transpose_shards") = true, py::arg("output_layout") = Layout::TILE, From 2aaf77829d5571f5f1f2286ac812c42475dccf38 Mon Sep 17 00:00:00 2001 From: Bryan Wilder Field Lozano Date: Wed, 9 Oct 2024 11:59:17 -0400 Subject: [PATCH 39/58] #13619: Bring fmt dependency from CPM (#13620) --- cmake/dependencies.cmake | 14 + tests/CMakeLists.txt | 2 +- .../tt_metal/unit_tests_common/CMakeLists.txt | 2 +- tt_metal/CMakeLists.txt | 4 +- tt_metal/common/CMakeLists.txt | 2 +- tt_metal/third_party/fmt/fmt/args.h | 228 - tt_metal/third_party/fmt/fmt/base.h | 3061 ------------ tt_metal/third_party/fmt/fmt/chrono.h | 2432 --------- tt_metal/third_party/fmt/fmt/color.h | 612 --- tt_metal/third_party/fmt/fmt/compile.h | 529 -- tt_metal/third_party/fmt/fmt/core.h | 5 - tt_metal/third_party/fmt/fmt/format-inl.h | 1904 ------- tt_metal/third_party/fmt/fmt/format.h | 4419 ----------------- tt_metal/third_party/fmt/fmt/os.h | 439 -- tt_metal/third_party/fmt/fmt/ostream.h | 211 - tt_metal/third_party/fmt/fmt/printf.h | 656 --- tt_metal/third_party/fmt/fmt/ranges.h | 882 ---- tt_metal/third_party/fmt/fmt/std.h | 699 --- tt_metal/third_party/fmt/fmt/xchar.h | 322 -- ttnn/CMakeLists.txt | 2 - 20 files changed, 18 insertions(+), 16407 deletions(-) delete mode 100644 tt_metal/third_party/fmt/fmt/args.h delete mode 100644 tt_metal/third_party/fmt/fmt/base.h delete mode 100644 tt_metal/third_party/fmt/fmt/chrono.h delete mode 100644 tt_metal/third_party/fmt/fmt/color.h delete mode 100644 tt_metal/third_party/fmt/fmt/compile.h delete mode 100644 tt_metal/third_party/fmt/fmt/core.h delete mode 100644 tt_metal/third_party/fmt/fmt/format-inl.h delete mode 100644 tt_metal/third_party/fmt/fmt/format.h delete mode 100644 tt_metal/third_party/fmt/fmt/os.h delete mode 100644 tt_metal/third_party/fmt/fmt/ostream.h delete mode 100644 tt_metal/third_party/fmt/fmt/printf.h delete mode 100644 tt_metal/third_party/fmt/fmt/ranges.h delete mode 100644 tt_metal/third_party/fmt/fmt/std.h delete mode 100644 tt_metal/third_party/fmt/fmt/xchar.h diff --git a/cmake/dependencies.cmake b/cmake/dependencies.cmake index c8b7384d75e..a598b632fc7 100644 --- a/cmake/dependencies.cmake +++ b/cmake/dependencies.cmake @@ -56,8 +56,22 @@ CPMAddPackage( GIT_TAG v1.1.1 ) +############################################################################################################################ +# magic_enum : https://github.com/Neargye/magic_enum +############################################################################################################################ + CPMAddPackage( NAME magic_enum GITHUB_REPOSITORY Neargye/magic_enum GIT_TAG v0.9.6 ) + +############################################################################################################################ +# fmt : https://github.com/fmtlib/fmt +############################################################################################################################ + +CPMAddPackage( + NAME fmt + GITHUB_REPOSITORY fmtlib/fmt + GIT_TAG 11.0.1 +) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1738b45d300..1474dc932c1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,7 +2,7 @@ enable_testing() include(GoogleTest) add_library(test_common_libs INTERFACE) -target_link_libraries(test_common_libs INTERFACE pthread gtest gtest_main magic_enum) +target_link_libraries(test_common_libs INTERFACE pthread gtest gtest_main magic_enum fmt) if(TT_METAL_BUILD_TESTS) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/tt_metal/tt_metal) diff --git a/tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt b/tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt index d488cd254b1..b9d67fb184d 100644 --- a/tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt +++ b/tests/tt_metal/tt_metal/unit_tests_common/CMakeLists.txt @@ -29,7 +29,7 @@ set(UNIT_TESTS_COMMON_SRC ${CMAKE_CURRENT_SOURCE_DIR}/watcher/test_link_training.cpp ) add_library(unit_tests_common_o OBJECT ${UNIT_TESTS_COMMON_SRC}) -target_link_libraries(unit_tests_common_o PUBLIC compiler_flags metal_header_directories gtest gtest_main magic_enum) +target_link_libraries(unit_tests_common_o PUBLIC compiler_flags metal_header_directories gtest gtest_main magic_enum fmt) target_include_directories(unit_tests_common_o PUBLIC ${UMD_HOME} ${PROJECT_SOURCE_DIR} diff --git a/tt_metal/CMakeLists.txt b/tt_metal/CMakeLists.txt index 32d09bcc0e9..cfecfa2cdb7 100644 --- a/tt_metal/CMakeLists.txt +++ b/tt_metal/CMakeLists.txt @@ -23,12 +23,10 @@ set(TT_METAL_OBJECTS add_library(tt_metal ${TT_METAL_OBJECTS}) -target_link_libraries(tt_metal PUBLIC metal_header_directories umd_device metal_common_libs magic_enum) +target_link_libraries(tt_metal PUBLIC metal_header_directories umd_device metal_common_libs magic_enum fmt) target_precompile_headers(tt_metal PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/third_party/tracy/public/tracy/Tracy.hpp - ${CMAKE_CURRENT_SOURCE_DIR}/third_party/fmt/fmt/core.h - ${CMAKE_CURRENT_SOURCE_DIR}/third_party/fmt/fmt/format.h diff --git a/tt_metal/common/CMakeLists.txt b/tt_metal/common/CMakeLists.txt index d5621964729..2de753f9fed 100644 --- a/tt_metal/common/CMakeLists.txt +++ b/tt_metal/common/CMakeLists.txt @@ -8,7 +8,7 @@ set(COMMON_SRCS add_library(common OBJECT ${COMMON_SRCS}) target_link_libraries(common PRIVATE yaml-cpp::yaml-cpp) -target_link_libraries(common PUBLIC compiler_flags metal_header_directories magic_enum) +target_link_libraries(common PUBLIC compiler_flags metal_header_directories magic_enum fmt) target_include_directories(common PUBLIC ${UMD_HOME} diff --git a/tt_metal/third_party/fmt/fmt/args.h b/tt_metal/third_party/fmt/fmt/args.h deleted file mode 100644 index 31a60e8faf1..00000000000 --- a/tt_metal/third_party/fmt/fmt/args.h +++ /dev/null @@ -1,228 +0,0 @@ -// Formatting library for C++ - dynamic argument lists -// -// Copyright (c) 2012 - present, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_ARGS_H_ -#define FMT_ARGS_H_ - -#ifndef FMT_MODULE -# include // std::reference_wrapper -# include // std::unique_ptr -# include -#endif - -#include "format.h" // std_string_view - -FMT_BEGIN_NAMESPACE - -namespace detail { - -template struct is_reference_wrapper : std::false_type {}; -template -struct is_reference_wrapper> : std::true_type {}; - -template auto unwrap(const T& v) -> const T& { return v; } -template -auto unwrap(const std::reference_wrapper& v) -> const T& { - return static_cast(v); -} - -// node is defined outside dynamic_arg_list to workaround a C2504 bug in MSVC -// 2022 (v17.10.0). -// -// Workaround for clang's -Wweak-vtables. Unlike for regular classes, for -// templates it doesn't complain about inability to deduce single translation -// unit for placing vtable. So node is made a fake template. -template struct node { - virtual ~node() = default; - std::unique_ptr> next; -}; - -class dynamic_arg_list { - template struct typed_node : node<> { - T value; - - template - FMT_CONSTEXPR typed_node(const Arg& arg) : value(arg) {} - - template - FMT_CONSTEXPR typed_node(const basic_string_view& arg) - : value(arg.data(), arg.size()) {} - }; - - std::unique_ptr> head_; - - public: - template auto push(const Arg& arg) -> const T& { - auto new_node = std::unique_ptr>(new typed_node(arg)); - auto& value = new_node->value; - new_node->next = std::move(head_); - head_ = std::move(new_node); - return value; - } -}; -} // namespace detail - -/** - * A dynamic list of formatting arguments with storage. - * - * It can be implicitly converted into `fmt::basic_format_args` for passing - * into type-erased formatting functions such as `fmt::vformat`. - */ -template -class dynamic_format_arg_store -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 409 - // Workaround a GCC template argument substitution bug. - : public basic_format_args -#endif -{ - private: - using char_type = typename Context::char_type; - - template struct need_copy { - static constexpr detail::type mapped_type = - detail::mapped_type_constant::value; - - enum { - value = !(detail::is_reference_wrapper::value || - std::is_same>::value || - std::is_same>::value || - (mapped_type != detail::type::cstring_type && - mapped_type != detail::type::string_type && - mapped_type != detail::type::custom_type)) - }; - }; - - template - using stored_type = conditional_t< - std::is_convertible>::value && - !detail::is_reference_wrapper::value, - std::basic_string, T>; - - // Storage of basic_format_arg must be contiguous. - std::vector> data_; - std::vector> named_info_; - - // Storage of arguments not fitting into basic_format_arg must grow - // without relocation because items in data_ refer to it. - detail::dynamic_arg_list dynamic_args_; - - friend class basic_format_args; - - auto get_types() const -> unsigned long long { - return detail::is_unpacked_bit | data_.size() | - (named_info_.empty() - ? 0ULL - : static_cast(detail::has_named_args_bit)); - } - - auto data() const -> const basic_format_arg* { - return named_info_.empty() ? data_.data() : data_.data() + 1; - } - - template void emplace_arg(const T& arg) { - data_.emplace_back(detail::make_arg(arg)); - } - - template - void emplace_arg(const detail::named_arg& arg) { - if (named_info_.empty()) { - constexpr const detail::named_arg_info* zero_ptr{nullptr}; - data_.insert(data_.begin(), {zero_ptr, 0}); - } - data_.emplace_back(detail::make_arg(detail::unwrap(arg.value))); - auto pop_one = [](std::vector>* data) { - data->pop_back(); - }; - std::unique_ptr>, decltype(pop_one)> - guard{&data_, pop_one}; - named_info_.push_back({arg.name, static_cast(data_.size() - 2u)}); - data_[0].value_.named_args = {named_info_.data(), named_info_.size()}; - guard.release(); - } - - public: - constexpr dynamic_format_arg_store() = default; - - /** - * Adds an argument into the dynamic store for later passing to a formatting - * function. - * - * Note that custom types and string types (but not string views) are copied - * into the store dynamically allocating memory if necessary. - * - * **Example**: - * - * fmt::dynamic_format_arg_store store; - * store.push_back(42); - * store.push_back("abc"); - * store.push_back(1.5f); - * std::string result = fmt::vformat("{} and {} and {}", store); - */ - template void push_back(const T& arg) { - if (detail::const_check(need_copy::value)) - emplace_arg(dynamic_args_.push>(arg)); - else - emplace_arg(detail::unwrap(arg)); - } - - /** - * Adds a reference to the argument into the dynamic store for later passing - * to a formatting function. - * - * **Example**: - * - * fmt::dynamic_format_arg_store store; - * char band[] = "Rolling Stones"; - * store.push_back(std::cref(band)); - * band[9] = 'c'; // Changing str affects the output. - * std::string result = fmt::vformat("{}", store); - * // result == "Rolling Scones" - */ - template void push_back(std::reference_wrapper arg) { - static_assert( - need_copy::value, - "objects of built-in types and string views are always copied"); - emplace_arg(arg.get()); - } - - /** - * Adds named argument into the dynamic store for later passing to a - * formatting function. `std::reference_wrapper` is supported to avoid - * copying of the argument. The name is always copied into the store. - */ - template - void push_back(const detail::named_arg& arg) { - const char_type* arg_name = - dynamic_args_.push>(arg.name).c_str(); - if (detail::const_check(need_copy::value)) { - emplace_arg( - fmt::arg(arg_name, dynamic_args_.push>(arg.value))); - } else { - emplace_arg(fmt::arg(arg_name, arg.value)); - } - } - - /// Erase all elements from the store. - void clear() { - data_.clear(); - named_info_.clear(); - dynamic_args_ = detail::dynamic_arg_list(); - } - - /// Reserves space to store at least `new_cap` arguments including - /// `new_cap_named` named arguments. - void reserve(size_t new_cap, size_t new_cap_named) { - FMT_ASSERT(new_cap >= new_cap_named, - "Set of arguments includes set of named arguments"); - data_.reserve(new_cap); - named_info_.reserve(new_cap_named); - } -}; - -FMT_END_NAMESPACE - -#endif // FMT_ARGS_H_ diff --git a/tt_metal/third_party/fmt/fmt/base.h b/tt_metal/third_party/fmt/fmt/base.h deleted file mode 100644 index f440cffd20d..00000000000 --- a/tt_metal/third_party/fmt/fmt/base.h +++ /dev/null @@ -1,3061 +0,0 @@ -// Formatting library for C++ - the base API for char/UTF-8 -// -// Copyright (c) 2012 - present, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_BASE_H_ -#define FMT_BASE_H_ - -#if defined(FMT_IMPORT_STD) && !defined(FMT_MODULE) -# define FMT_MODULE -#endif - -#ifndef FMT_MODULE -# include // CHAR_BIT -# include // FILE -# include // strlen - -// is also included transitively from . -# include // std::byte -# include // std::enable_if -#endif - -// The fmt library version in the form major * 10000 + minor * 100 + patch. -#define FMT_VERSION 110001 - -// Detect compiler versions. -#if defined(__clang__) && !defined(__ibmxl__) -# define FMT_CLANG_VERSION (__clang_major__ * 100 + __clang_minor__) -#else -# define FMT_CLANG_VERSION 0 -#endif -#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER) -# define FMT_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) -#else -# define FMT_GCC_VERSION 0 -#endif -#if defined(__ICL) -# define FMT_ICC_VERSION __ICL -#elif defined(__INTEL_COMPILER) -# define FMT_ICC_VERSION __INTEL_COMPILER -#else -# define FMT_ICC_VERSION 0 -#endif -#if defined(_MSC_VER) -# define FMT_MSC_VERSION _MSC_VER -#else -# define FMT_MSC_VERSION 0 -#endif - -// Detect standard library versions. -#ifdef _GLIBCXX_RELEASE -# define FMT_GLIBCXX_RELEASE _GLIBCXX_RELEASE -#else -# define FMT_GLIBCXX_RELEASE 0 -#endif -#ifdef _LIBCPP_VERSION -# define FMT_LIBCPP_VERSION _LIBCPP_VERSION -#else -# define FMT_LIBCPP_VERSION 0 -#endif - -#ifdef _MSVC_LANG -# define FMT_CPLUSPLUS _MSVC_LANG -#else -# define FMT_CPLUSPLUS __cplusplus -#endif - -// Detect __has_*. -#ifdef __has_feature -# define FMT_HAS_FEATURE(x) __has_feature(x) -#else -# define FMT_HAS_FEATURE(x) 0 -#endif -#ifdef __has_include -# define FMT_HAS_INCLUDE(x) __has_include(x) -#else -# define FMT_HAS_INCLUDE(x) 0 -#endif -#ifdef __has_cpp_attribute -# define FMT_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) -#else -# define FMT_HAS_CPP_ATTRIBUTE(x) 0 -#endif - -#define FMT_HAS_CPP14_ATTRIBUTE(attribute) \ - (FMT_CPLUSPLUS >= 201402L && FMT_HAS_CPP_ATTRIBUTE(attribute)) - -#define FMT_HAS_CPP17_ATTRIBUTE(attribute) \ - (FMT_CPLUSPLUS >= 201703L && FMT_HAS_CPP_ATTRIBUTE(attribute)) - -// Detect C++14 relaxed constexpr. -#ifdef FMT_USE_CONSTEXPR -// Use the provided definition. -#elif FMT_GCC_VERSION >= 600 && FMT_CPLUSPLUS >= 201402L -// GCC only allows throw in constexpr since version 6: -// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=67371. -# define FMT_USE_CONSTEXPR 1 -#elif FMT_ICC_VERSION -# define FMT_USE_CONSTEXPR 0 // https://github.com/fmtlib/fmt/issues/1628 -#elif FMT_HAS_FEATURE(cxx_relaxed_constexpr) || FMT_MSC_VERSION >= 1912 -# define FMT_USE_CONSTEXPR 1 -#else -# define FMT_USE_CONSTEXPR 0 -#endif -#if FMT_USE_CONSTEXPR -# define FMT_CONSTEXPR constexpr -#else -# define FMT_CONSTEXPR -#endif - -// Detect consteval, C++20 constexpr extensions and std::is_constant_evaluated. -#if !defined(__cpp_lib_is_constant_evaluated) -# define FMT_USE_CONSTEVAL 0 -#elif FMT_CPLUSPLUS < 201709L -# define FMT_USE_CONSTEVAL 0 -#elif FMT_GLIBCXX_RELEASE && FMT_GLIBCXX_RELEASE < 10 -# define FMT_USE_CONSTEVAL 0 -#elif FMT_LIBCPP_VERSION && FMT_LIBCPP_VERSION < 10000 -# define FMT_USE_CONSTEVAL 0 -#elif defined(__apple_build_version__) && __apple_build_version__ < 14000029L -# define FMT_USE_CONSTEVAL 0 // consteval is broken in Apple clang < 14. -#elif FMT_MSC_VERSION && FMT_MSC_VERSION < 1929 -# define FMT_USE_CONSTEVAL 0 // consteval is broken in MSVC VS2019 < 16.10. -#elif defined(__cpp_consteval) -# define FMT_USE_CONSTEVAL 1 -#elif FMT_GCC_VERSION >= 1002 || FMT_CLANG_VERSION >= 1101 -# define FMT_USE_CONSTEVAL 1 -#else -# define FMT_USE_CONSTEVAL 0 -#endif -#if FMT_USE_CONSTEVAL -# define FMT_CONSTEVAL consteval -# define FMT_CONSTEXPR20 constexpr -#else -# define FMT_CONSTEVAL -# define FMT_CONSTEXPR20 -#endif - -#if defined(FMT_USE_NONTYPE_TEMPLATE_ARGS) -// Use the provided definition. -#elif defined(__NVCOMPILER) -# define FMT_USE_NONTYPE_TEMPLATE_ARGS 0 -#elif FMT_GCC_VERSION >= 903 && FMT_CPLUSPLUS >= 201709L -# define FMT_USE_NONTYPE_TEMPLATE_ARGS 1 -#elif defined(__cpp_nontype_template_args) && \ - __cpp_nontype_template_args >= 201911L -# define FMT_USE_NONTYPE_TEMPLATE_ARGS 1 -#elif FMT_CLANG_VERSION >= 1200 && FMT_CPLUSPLUS >= 202002L -# define FMT_USE_NONTYPE_TEMPLATE_ARGS 1 -#else -# define FMT_USE_NONTYPE_TEMPLATE_ARGS 0 -#endif - -#ifdef FMT_USE_CONCEPTS -// Use the provided definition. -#elif defined(__cpp_concepts) -# define FMT_USE_CONCEPTS 1 -#else -# define FMT_USE_CONCEPTS 0 -#endif - -// Check if exceptions are disabled. -#ifdef FMT_EXCEPTIONS -// Use the provided definition. -#elif defined(__GNUC__) && !defined(__EXCEPTIONS) -# define FMT_EXCEPTIONS 0 -#elif FMT_MSC_VERSION && !_HAS_EXCEPTIONS -# define FMT_EXCEPTIONS 0 -#else -# define FMT_EXCEPTIONS 1 -#endif -#if FMT_EXCEPTIONS -# define FMT_TRY try -# define FMT_CATCH(x) catch (x) -#else -# define FMT_TRY if (true) -# define FMT_CATCH(x) if (false) -#endif - -#if FMT_HAS_CPP17_ATTRIBUTE(fallthrough) -# define FMT_FALLTHROUGH [[fallthrough]] -#elif defined(__clang__) -# define FMT_FALLTHROUGH [[clang::fallthrough]] -#elif FMT_GCC_VERSION >= 700 && \ - (!defined(__EDG_VERSION__) || __EDG_VERSION__ >= 520) -# define FMT_FALLTHROUGH [[gnu::fallthrough]] -#else -# define FMT_FALLTHROUGH -#endif - -// Disable [[noreturn]] on MSVC/NVCC because of bogus unreachable code warnings. -#if FMT_HAS_CPP_ATTRIBUTE(noreturn) && !FMT_MSC_VERSION && !defined(__NVCC__) -# define FMT_NORETURN [[noreturn]] -#else -# define FMT_NORETURN -#endif - -#ifndef FMT_NODISCARD -# if FMT_HAS_CPP17_ATTRIBUTE(nodiscard) -# define FMT_NODISCARD [[nodiscard]] -# else -# define FMT_NODISCARD -# endif -#endif - -#ifdef FMT_DEPRECATED -// Use the provided definition. -#elif FMT_HAS_CPP14_ATTRIBUTE(deprecated) -# define FMT_DEPRECATED [[deprecated]] -#else -# define FMT_DEPRECATED /* deprecated */ -#endif - -#ifdef FMT_INLINE -// Use the provided definition. -#elif FMT_GCC_VERSION || FMT_CLANG_VERSION -# define FMT_ALWAYS_INLINE inline __attribute__((always_inline)) -#else -# define FMT_ALWAYS_INLINE inline -#endif -// A version of FMT_INLINE to prevent code bloat in debug mode. -#ifdef NDEBUG -# define FMT_INLINE FMT_ALWAYS_INLINE -#else -# define FMT_INLINE inline -#endif - -#if FMT_GCC_VERSION || FMT_CLANG_VERSION -# define FMT_VISIBILITY(value) __attribute__((visibility(value))) -#else -# define FMT_VISIBILITY(value) -#endif - -#ifndef FMT_GCC_PRAGMA -// Workaround a _Pragma bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59884 -// and an nvhpc warning: https://github.com/fmtlib/fmt/pull/2582. -# if FMT_GCC_VERSION >= 504 && !defined(__NVCOMPILER) -# define FMT_GCC_PRAGMA(arg) _Pragma(arg) -# else -# define FMT_GCC_PRAGMA(arg) -# endif -#endif - -// GCC < 5 requires this-> in decltype. -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 500 -# define FMT_DECLTYPE_THIS this-> -#else -# define FMT_DECLTYPE_THIS -#endif - -#if FMT_MSC_VERSION -# define FMT_MSC_WARNING(...) __pragma(warning(__VA_ARGS__)) -# define FMT_UNCHECKED_ITERATOR(It) \ - using _Unchecked_type = It // Mark iterator as checked. -#else -# define FMT_MSC_WARNING(...) -# define FMT_UNCHECKED_ITERATOR(It) using unchecked_type = It -#endif - -#ifndef FMT_BEGIN_NAMESPACE -# define FMT_BEGIN_NAMESPACE \ - namespace fmt { \ - inline namespace v11 { -# define FMT_END_NAMESPACE \ - } \ - } -#endif - -#ifndef FMT_EXPORT -# define FMT_EXPORT -# define FMT_BEGIN_EXPORT -# define FMT_END_EXPORT -#endif - -#if !defined(FMT_HEADER_ONLY) && defined(_WIN32) -# if defined(FMT_LIB_EXPORT) -# define FMT_API __declspec(dllexport) -# elif defined(FMT_SHARED) -# define FMT_API __declspec(dllimport) -# endif -#elif defined(FMT_LIB_EXPORT) || defined(FMT_SHARED) -# define FMT_API FMT_VISIBILITY("default") -#endif -#ifndef FMT_API -# define FMT_API -#endif - -#ifndef FMT_UNICODE -# define FMT_UNICODE 1 -#endif - -// Check if rtti is available. -#ifndef FMT_USE_RTTI -// __RTTI is for EDG compilers. _CPPRTTI is for MSVC. -# if defined(__GXX_RTTI) || FMT_HAS_FEATURE(cxx_rtti) || defined(_CPPRTTI) || \ - defined(__INTEL_RTTI__) || defined(__RTTI) -# define FMT_USE_RTTI 1 -# else -# define FMT_USE_RTTI 0 -# endif -#endif - -#define FMT_FWD(...) static_cast(__VA_ARGS__) - -// Enable minimal optimizations for more compact code in debug mode. -FMT_GCC_PRAGMA("GCC push_options") -#if !defined(__OPTIMIZE__) && !defined(__CUDACC__) -FMT_GCC_PRAGMA("GCC optimize(\"Og\")") -#endif - -FMT_BEGIN_NAMESPACE - -// Implementations of enable_if_t and other metafunctions for older systems. -template -using enable_if_t = typename std::enable_if::type; -template -using conditional_t = typename std::conditional::type; -template using bool_constant = std::integral_constant; -template -using remove_reference_t = typename std::remove_reference::type; -template -using remove_const_t = typename std::remove_const::type; -template -using remove_cvref_t = typename std::remove_cv>::type; -template struct type_identity { - using type = T; -}; -template using type_identity_t = typename type_identity::type; -template -using make_unsigned_t = typename std::make_unsigned::type; -template -using underlying_t = typename std::underlying_type::type; - -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 500 -// A workaround for gcc 4.8 to make void_t work in a SFINAE context. -template struct void_t_impl { - using type = void; -}; -template using void_t = typename void_t_impl::type; -#else -template using void_t = void; -#endif - -struct monostate { - constexpr monostate() {} -}; - -// An enable_if helper to be used in template parameters which results in much -// shorter symbols: https://godbolt.org/z/sWw4vP. Extra parentheses are needed -// to workaround a bug in MSVC 2019 (see #1140 and #1186). -#ifdef FMT_DOC -# define FMT_ENABLE_IF(...) -#else -# define FMT_ENABLE_IF(...) fmt::enable_if_t<(__VA_ARGS__), int> = 0 -#endif - -// This is defined in base.h instead of format.h to avoid injecting in std. -// It is a template to avoid undesirable implicit conversions to std::byte. -#ifdef __cpp_lib_byte -template ::value)> -inline auto format_as(T b) -> unsigned char { - return static_cast(b); -} -#endif - -namespace detail { -// Suppresses "unused variable" warnings with the method described in -// https://herbsutter.com/2009/10/18/mailbag-shutting-up-compiler-warnings/. -// (void)var does not work on many Intel compilers. -template FMT_CONSTEXPR void ignore_unused(const T&...) {} - -constexpr auto is_constant_evaluated(bool default_value = false) noexcept - -> bool { -// Workaround for incompatibility between libstdc++ consteval-based -// std::is_constant_evaluated() implementation and clang-14: -// https://github.com/fmtlib/fmt/issues/3247. -#if FMT_CPLUSPLUS >= 202002L && FMT_GLIBCXX_RELEASE >= 12 && \ - (FMT_CLANG_VERSION >= 1400 && FMT_CLANG_VERSION < 1500) - ignore_unused(default_value); - return __builtin_is_constant_evaluated(); -#elif defined(__cpp_lib_is_constant_evaluated) - ignore_unused(default_value); - return std::is_constant_evaluated(); -#else - return default_value; -#endif -} - -// Suppresses "conditional expression is constant" warnings. -template constexpr auto const_check(T value) -> T { return value; } - -FMT_NORETURN FMT_API void assert_fail(const char* file, int line, - const char* message); - -#if defined(FMT_ASSERT) -// Use the provided definition. -#elif defined(NDEBUG) -// FMT_ASSERT is not empty to avoid -Wempty-body. -# define FMT_ASSERT(condition, message) \ - fmt::detail::ignore_unused((condition), (message)) -#else -# define FMT_ASSERT(condition, message) \ - ((condition) /* void() fails with -Winvalid-constexpr on clang 4.0.1 */ \ - ? (void)0 \ - : fmt::detail::assert_fail(__FILE__, __LINE__, (message))) -#endif - -#ifdef FMT_USE_INT128 -// Do nothing. -#elif defined(__SIZEOF_INT128__) && !defined(__NVCC__) && \ - !(FMT_CLANG_VERSION && FMT_MSC_VERSION) -# define FMT_USE_INT128 1 -using int128_opt = __int128_t; // An optional native 128-bit integer. -using uint128_opt = __uint128_t; -template inline auto convert_for_visit(T value) -> T { - return value; -} -#else -# define FMT_USE_INT128 0 -#endif -#if !FMT_USE_INT128 -enum class int128_opt {}; -enum class uint128_opt {}; -// Reduce template instantiations. -template auto convert_for_visit(T) -> monostate { return {}; } -#endif - -// Casts a nonnegative integer to unsigned. -template -FMT_CONSTEXPR auto to_unsigned(Int value) -> make_unsigned_t { - FMT_ASSERT(std::is_unsigned::value || value >= 0, "negative value"); - return static_cast>(value); -} - -// A heuristic to detect std::string and std::[experimental::]string_view. -// It is mainly used to avoid dependency on <[experimental/]string_view>. -template -struct is_std_string_like : std::false_type {}; -template -struct is_std_string_like().find_first_of( - typename T::value_type(), 0))>> - : std::true_type {}; - -// Returns true iff the literal encoding is UTF-8. -constexpr auto is_utf8_enabled() -> bool { - // Avoid an MSVC sign extension bug: https://github.com/fmtlib/fmt/pull/2297. - using uchar = unsigned char; - return sizeof("\u00A7") == 3 && uchar("\u00A7"[0]) == 0xC2 && - uchar("\u00A7"[1]) == 0xA7; -} -constexpr auto use_utf8() -> bool { - return !FMT_MSC_VERSION || is_utf8_enabled(); -} - -static_assert(!FMT_UNICODE || use_utf8(), - "Unicode support requires compiling with /utf-8"); - -template FMT_CONSTEXPR auto length(const Char* s) -> size_t { - size_t len = 0; - while (*s++) ++len; - return len; -} - -template -FMT_CONSTEXPR auto compare(const Char* s1, const Char* s2, std::size_t n) - -> int { - for (; n != 0; ++s1, ++s2, --n) { - if (*s1 < *s2) return -1; - if (*s1 > *s2) return 1; - } - return 0; -} - -template -struct is_back_insert_iterator : std::false_type {}; -template -struct is_back_insert_iterator< - It, - bool_constant())), - It>::value>> : std::true_type {}; - -// Extracts a reference to the container from *insert_iterator. -template -inline auto get_container(OutputIt it) -> typename OutputIt::container_type& { - struct accessor : OutputIt { - accessor(OutputIt base) : OutputIt(base) {} - using OutputIt::container; - }; - return *accessor(it).container; -} -} // namespace detail - -// Checks whether T is a container with contiguous storage. -template struct is_contiguous : std::false_type {}; - -/** - * An implementation of `std::basic_string_view` for pre-C++17. It provides a - * subset of the API. `fmt::basic_string_view` is used for format strings even - * if `std::basic_string_view` is available to prevent issues when a library is - * compiled with a different `-std` option than the client code (which is not - * recommended). - */ -FMT_EXPORT -template class basic_string_view { - private: - const Char* data_; - size_t size_; - - public: - using value_type = Char; - using iterator = const Char*; - - constexpr basic_string_view() noexcept : data_(nullptr), size_(0) {} - - /// Constructs a string reference object from a C string and a size. - constexpr basic_string_view(const Char* s, size_t count) noexcept - : data_(s), size_(count) {} - - constexpr basic_string_view(std::nullptr_t) = delete; - - /// Constructs a string reference object from a C string. - FMT_CONSTEXPR20 - basic_string_view(const Char* s) - : data_(s), - size_(detail::const_check(std::is_same::value && - !detail::is_constant_evaluated(false)) - ? strlen(reinterpret_cast(s)) - : detail::length(s)) {} - - /// Constructs a string reference from a `std::basic_string` or a - /// `std::basic_string_view` object. - template ::value&& std::is_same< - typename S::value_type, Char>::value)> - FMT_CONSTEXPR basic_string_view(const S& s) noexcept - : data_(s.data()), size_(s.size()) {} - - /// Returns a pointer to the string data. - constexpr auto data() const noexcept -> const Char* { return data_; } - - /// Returns the string size. - constexpr auto size() const noexcept -> size_t { return size_; } - - constexpr auto begin() const noexcept -> iterator { return data_; } - constexpr auto end() const noexcept -> iterator { return data_ + size_; } - - constexpr auto operator[](size_t pos) const noexcept -> const Char& { - return data_[pos]; - } - - FMT_CONSTEXPR void remove_prefix(size_t n) noexcept { - data_ += n; - size_ -= n; - } - - FMT_CONSTEXPR auto starts_with(basic_string_view sv) const noexcept - -> bool { - return size_ >= sv.size_ && detail::compare(data_, sv.data_, sv.size_) == 0; - } - FMT_CONSTEXPR auto starts_with(Char c) const noexcept -> bool { - return size_ >= 1 && *data_ == c; - } - FMT_CONSTEXPR auto starts_with(const Char* s) const -> bool { - return starts_with(basic_string_view(s)); - } - - // Lexicographically compare this string reference to other. - FMT_CONSTEXPR auto compare(basic_string_view other) const -> int { - size_t str_size = size_ < other.size_ ? size_ : other.size_; - int result = detail::compare(data_, other.data_, str_size); - if (result == 0) - result = size_ == other.size_ ? 0 : (size_ < other.size_ ? -1 : 1); - return result; - } - - FMT_CONSTEXPR friend auto operator==(basic_string_view lhs, - basic_string_view rhs) -> bool { - return lhs.compare(rhs) == 0; - } - friend auto operator!=(basic_string_view lhs, basic_string_view rhs) -> bool { - return lhs.compare(rhs) != 0; - } - friend auto operator<(basic_string_view lhs, basic_string_view rhs) -> bool { - return lhs.compare(rhs) < 0; - } - friend auto operator<=(basic_string_view lhs, basic_string_view rhs) -> bool { - return lhs.compare(rhs) <= 0; - } - friend auto operator>(basic_string_view lhs, basic_string_view rhs) -> bool { - return lhs.compare(rhs) > 0; - } - friend auto operator>=(basic_string_view lhs, basic_string_view rhs) -> bool { - return lhs.compare(rhs) >= 0; - } -}; - -FMT_EXPORT -using string_view = basic_string_view; - -/// Specifies if `T` is a character type. Can be specialized by users. -FMT_EXPORT -template struct is_char : std::false_type {}; -template <> struct is_char : std::true_type {}; - -namespace detail { - -// Constructs fmt::basic_string_view from types implicitly convertible -// to it, deducing Char. Explicitly convertible types such as the ones returned -// from FMT_STRING are intentionally excluded. -template ::value)> -auto to_string_view(const Char* s) -> basic_string_view { - return s; -} -template ::value)> -auto to_string_view(const T& s) -> basic_string_view { - return s; -} -template -constexpr auto to_string_view(basic_string_view s) - -> basic_string_view { - return s; -} - -template -struct has_to_string_view : std::false_type {}; -// detail:: is intentional since to_string_view is not an extension point. -template -struct has_to_string_view< - T, void_t()))>> - : std::true_type {}; - -template struct string_literal { - static constexpr Char value[sizeof...(C)] = {C...}; - constexpr operator basic_string_view() const { - return {value, sizeof...(C)}; - } -}; -#if FMT_CPLUSPLUS < 201703L -template -constexpr Char string_literal::value[sizeof...(C)]; -#endif - -enum class type { - none_type, - // Integer types should go first, - int_type, - uint_type, - long_long_type, - ulong_long_type, - int128_type, - uint128_type, - bool_type, - char_type, - last_integer_type = char_type, - // followed by floating-point types. - float_type, - double_type, - long_double_type, - last_numeric_type = long_double_type, - cstring_type, - string_type, - pointer_type, - custom_type -}; - -// Maps core type T to the corresponding type enum constant. -template -struct type_constant : std::integral_constant {}; - -#define FMT_TYPE_CONSTANT(Type, constant) \ - template \ - struct type_constant \ - : std::integral_constant {} - -FMT_TYPE_CONSTANT(int, int_type); -FMT_TYPE_CONSTANT(unsigned, uint_type); -FMT_TYPE_CONSTANT(long long, long_long_type); -FMT_TYPE_CONSTANT(unsigned long long, ulong_long_type); -FMT_TYPE_CONSTANT(int128_opt, int128_type); -FMT_TYPE_CONSTANT(uint128_opt, uint128_type); -FMT_TYPE_CONSTANT(bool, bool_type); -FMT_TYPE_CONSTANT(Char, char_type); -FMT_TYPE_CONSTANT(float, float_type); -FMT_TYPE_CONSTANT(double, double_type); -FMT_TYPE_CONSTANT(long double, long_double_type); -FMT_TYPE_CONSTANT(const Char*, cstring_type); -FMT_TYPE_CONSTANT(basic_string_view, string_type); -FMT_TYPE_CONSTANT(const void*, pointer_type); - -constexpr auto is_integral_type(type t) -> bool { - return t > type::none_type && t <= type::last_integer_type; -} -constexpr auto is_arithmetic_type(type t) -> bool { - return t > type::none_type && t <= type::last_numeric_type; -} - -constexpr auto set(type rhs) -> int { return 1 << static_cast(rhs); } -constexpr auto in(type t, int set) -> bool { - return ((set >> static_cast(t)) & 1) != 0; -} - -// Bitsets of types. -enum { - sint_set = - set(type::int_type) | set(type::long_long_type) | set(type::int128_type), - uint_set = set(type::uint_type) | set(type::ulong_long_type) | - set(type::uint128_type), - bool_set = set(type::bool_type), - char_set = set(type::char_type), - float_set = set(type::float_type) | set(type::double_type) | - set(type::long_double_type), - string_set = set(type::string_type), - cstring_set = set(type::cstring_type), - pointer_set = set(type::pointer_type) -}; -} // namespace detail - -/// Reports a format error at compile time or, via a `format_error` exception, -/// at runtime. -// This function is intentionally not constexpr to give a compile-time error. -FMT_NORETURN FMT_API void report_error(const char* message); - -FMT_DEPRECATED FMT_NORETURN inline void throw_format_error( - const char* message) { - report_error(message); -} - -/// String's character (code unit) type. -template ()))> -using char_t = typename V::value_type; - -/** - * Parsing context consisting of a format string range being parsed and an - * argument counter for automatic indexing. - * You can use the `format_parse_context` type alias for `char` instead. - */ -FMT_EXPORT -template class basic_format_parse_context { - private: - basic_string_view format_str_; - int next_arg_id_; - - FMT_CONSTEXPR void do_check_arg_id(int id); - - public: - using char_type = Char; - using iterator = const Char*; - - explicit constexpr basic_format_parse_context( - basic_string_view format_str, int next_arg_id = 0) - : format_str_(format_str), next_arg_id_(next_arg_id) {} - - /// Returns an iterator to the beginning of the format string range being - /// parsed. - constexpr auto begin() const noexcept -> iterator { - return format_str_.begin(); - } - - /// Returns an iterator past the end of the format string range being parsed. - constexpr auto end() const noexcept -> iterator { return format_str_.end(); } - - /// Advances the begin iterator to `it`. - FMT_CONSTEXPR void advance_to(iterator it) { - format_str_.remove_prefix(detail::to_unsigned(it - begin())); - } - - /// Reports an error if using the manual argument indexing; otherwise returns - /// the next argument index and switches to the automatic indexing. - FMT_CONSTEXPR auto next_arg_id() -> int { - if (next_arg_id_ < 0) { - report_error("cannot switch from manual to automatic argument indexing"); - return 0; - } - int id = next_arg_id_++; - do_check_arg_id(id); - return id; - } - - /// Reports an error if using the automatic argument indexing; otherwise - /// switches to the manual indexing. - FMT_CONSTEXPR void check_arg_id(int id) { - if (next_arg_id_ > 0) { - report_error("cannot switch from automatic to manual argument indexing"); - return; - } - next_arg_id_ = -1; - do_check_arg_id(id); - } - FMT_CONSTEXPR void check_arg_id(basic_string_view) { - next_arg_id_ = -1; - } - FMT_CONSTEXPR void check_dynamic_spec(int arg_id); -}; - -FMT_EXPORT -using format_parse_context = basic_format_parse_context; - -namespace detail { -// A parse context with extra data used only in compile-time checks. -template -class compile_parse_context : public basic_format_parse_context { - private: - int num_args_; - const type* types_; - using base = basic_format_parse_context; - - public: - explicit FMT_CONSTEXPR compile_parse_context( - basic_string_view format_str, int num_args, const type* types, - int next_arg_id = 0) - : base(format_str, next_arg_id), num_args_(num_args), types_(types) {} - - constexpr auto num_args() const -> int { return num_args_; } - constexpr auto arg_type(int id) const -> type { return types_[id]; } - - FMT_CONSTEXPR auto next_arg_id() -> int { - int id = base::next_arg_id(); - if (id >= num_args_) report_error("argument not found"); - return id; - } - - FMT_CONSTEXPR void check_arg_id(int id) { - base::check_arg_id(id); - if (id >= num_args_) report_error("argument not found"); - } - using base::check_arg_id; - - FMT_CONSTEXPR void check_dynamic_spec(int arg_id) { - detail::ignore_unused(arg_id); - if (arg_id < num_args_ && types_ && !is_integral_type(types_[arg_id])) - report_error("width/precision is not integer"); - } -}; - -/// A contiguous memory buffer with an optional growing ability. It is an -/// internal class and shouldn't be used directly, only via `memory_buffer`. -template class buffer { - private: - T* ptr_; - size_t size_; - size_t capacity_; - - using grow_fun = void (*)(buffer& buf, size_t capacity); - grow_fun grow_; - - protected: - // Don't initialize ptr_ since it is not accessed to save a few cycles. - FMT_MSC_WARNING(suppress : 26495) - FMT_CONSTEXPR20 buffer(grow_fun grow, size_t sz) noexcept - : size_(sz), capacity_(sz), grow_(grow) {} - - constexpr buffer(grow_fun grow, T* p = nullptr, size_t sz = 0, - size_t cap = 0) noexcept - : ptr_(p), size_(sz), capacity_(cap), grow_(grow) {} - - FMT_CONSTEXPR20 ~buffer() = default; - buffer(buffer&&) = default; - - /// Sets the buffer data and capacity. - FMT_CONSTEXPR void set(T* buf_data, size_t buf_capacity) noexcept { - ptr_ = buf_data; - capacity_ = buf_capacity; - } - - public: - using value_type = T; - using const_reference = const T&; - - buffer(const buffer&) = delete; - void operator=(const buffer&) = delete; - - auto begin() noexcept -> T* { return ptr_; } - auto end() noexcept -> T* { return ptr_ + size_; } - - auto begin() const noexcept -> const T* { return ptr_; } - auto end() const noexcept -> const T* { return ptr_ + size_; } - - /// Returns the size of this buffer. - constexpr auto size() const noexcept -> size_t { return size_; } - - /// Returns the capacity of this buffer. - constexpr auto capacity() const noexcept -> size_t { return capacity_; } - - /// Returns a pointer to the buffer data (not null-terminated). - FMT_CONSTEXPR auto data() noexcept -> T* { return ptr_; } - FMT_CONSTEXPR auto data() const noexcept -> const T* { return ptr_; } - - /// Clears this buffer. - void clear() { size_ = 0; } - - // Tries resizing the buffer to contain `count` elements. If T is a POD type - // the new elements may not be initialized. - FMT_CONSTEXPR void try_resize(size_t count) { - try_reserve(count); - size_ = count <= capacity_ ? count : capacity_; - } - - // Tries increasing the buffer capacity to `new_capacity`. It can increase the - // capacity by a smaller amount than requested but guarantees there is space - // for at least one additional element either by increasing the capacity or by - // flushing the buffer if it is full. - FMT_CONSTEXPR void try_reserve(size_t new_capacity) { - if (new_capacity > capacity_) grow_(*this, new_capacity); - } - - FMT_CONSTEXPR void push_back(const T& value) { - try_reserve(size_ + 1); - ptr_[size_++] = value; - } - - /// Appends data to the end of the buffer. - template void append(const U* begin, const U* end) { - while (begin != end) { - auto count = to_unsigned(end - begin); - try_reserve(size_ + count); - auto free_cap = capacity_ - size_; - if (free_cap < count) count = free_cap; - if (std::is_same::value) { - memcpy(ptr_ + size_, begin, count * sizeof(T)); - } else { - T* out = ptr_ + size_; - for (size_t i = 0; i < count; ++i) out[i] = begin[i]; - } - size_ += count; - begin += count; - } - } - - template FMT_CONSTEXPR auto operator[](Idx index) -> T& { - return ptr_[index]; - } - template - FMT_CONSTEXPR auto operator[](Idx index) const -> const T& { - return ptr_[index]; - } -}; - -struct buffer_traits { - explicit buffer_traits(size_t) {} - auto count() const -> size_t { return 0; } - auto limit(size_t size) -> size_t { return size; } -}; - -class fixed_buffer_traits { - private: - size_t count_ = 0; - size_t limit_; - - public: - explicit fixed_buffer_traits(size_t limit) : limit_(limit) {} - auto count() const -> size_t { return count_; } - auto limit(size_t size) -> size_t { - size_t n = limit_ > count_ ? limit_ - count_ : 0; - count_ += size; - return size < n ? size : n; - } -}; - -// A buffer that writes to an output iterator when flushed. -template -class iterator_buffer : public Traits, public buffer { - private: - OutputIt out_; - enum { buffer_size = 256 }; - T data_[buffer_size]; - - static FMT_CONSTEXPR void grow(buffer& buf, size_t) { - if (buf.size() == buffer_size) static_cast(buf).flush(); - } - - void flush() { - auto size = this->size(); - this->clear(); - const T* begin = data_; - const T* end = begin + this->limit(size); - while (begin != end) *out_++ = *begin++; - } - - public: - explicit iterator_buffer(OutputIt out, size_t n = buffer_size) - : Traits(n), buffer(grow, data_, 0, buffer_size), out_(out) {} - iterator_buffer(iterator_buffer&& other) noexcept - : Traits(other), - buffer(grow, data_, 0, buffer_size), - out_(other.out_) {} - ~iterator_buffer() { - // Don't crash if flush fails during unwinding. - FMT_TRY { flush(); } - FMT_CATCH(...) {} - } - - auto out() -> OutputIt { - flush(); - return out_; - } - auto count() const -> size_t { return Traits::count() + this->size(); } -}; - -template -class iterator_buffer : public fixed_buffer_traits, - public buffer { - private: - T* out_; - enum { buffer_size = 256 }; - T data_[buffer_size]; - - static FMT_CONSTEXPR void grow(buffer& buf, size_t) { - if (buf.size() == buf.capacity()) - static_cast(buf).flush(); - } - - void flush() { - size_t n = this->limit(this->size()); - if (this->data() == out_) { - out_ += n; - this->set(data_, buffer_size); - } - this->clear(); - } - - public: - explicit iterator_buffer(T* out, size_t n = buffer_size) - : fixed_buffer_traits(n), buffer(grow, out, 0, n), out_(out) {} - iterator_buffer(iterator_buffer&& other) noexcept - : fixed_buffer_traits(other), - buffer(static_cast(other)), - out_(other.out_) { - if (this->data() != out_) { - this->set(data_, buffer_size); - this->clear(); - } - } - ~iterator_buffer() { flush(); } - - auto out() -> T* { - flush(); - return out_; - } - auto count() const -> size_t { - return fixed_buffer_traits::count() + this->size(); - } -}; - -template class iterator_buffer : public buffer { - public: - explicit iterator_buffer(T* out, size_t = 0) - : buffer([](buffer&, size_t) {}, out, 0, ~size_t()) {} - - auto out() -> T* { return &*this->end(); } -}; - -// A buffer that writes to a container with the contiguous storage. -template -class iterator_buffer< - OutputIt, - enable_if_t::value && - is_contiguous::value, - typename OutputIt::container_type::value_type>> - : public buffer { - private: - using container_type = typename OutputIt::container_type; - using value_type = typename container_type::value_type; - container_type& container_; - - static FMT_CONSTEXPR void grow(buffer& buf, size_t capacity) { - auto& self = static_cast(buf); - self.container_.resize(capacity); - self.set(&self.container_[0], capacity); - } - - public: - explicit iterator_buffer(container_type& c) - : buffer(grow, c.size()), container_(c) {} - explicit iterator_buffer(OutputIt out, size_t = 0) - : iterator_buffer(get_container(out)) {} - - auto out() -> OutputIt { return back_inserter(container_); } -}; - -// A buffer that counts the number of code units written discarding the output. -template class counting_buffer : public buffer { - private: - enum { buffer_size = 256 }; - T data_[buffer_size]; - size_t count_ = 0; - - static FMT_CONSTEXPR void grow(buffer& buf, size_t) { - if (buf.size() != buffer_size) return; - static_cast(buf).count_ += buf.size(); - buf.clear(); - } - - public: - counting_buffer() : buffer(grow, data_, 0, buffer_size) {} - - auto count() -> size_t { return count_ + this->size(); } -}; -} // namespace detail - -template -FMT_CONSTEXPR void basic_format_parse_context::do_check_arg_id(int id) { - // Argument id is only checked at compile-time during parsing because - // formatting has its own validation. - if (detail::is_constant_evaluated() && - (!FMT_GCC_VERSION || FMT_GCC_VERSION >= 1200)) { - using context = detail::compile_parse_context; - if (id >= static_cast(this)->num_args()) - report_error("argument not found"); - } -} - -template -FMT_CONSTEXPR void basic_format_parse_context::check_dynamic_spec( - int arg_id) { - if (detail::is_constant_evaluated() && - (!FMT_GCC_VERSION || FMT_GCC_VERSION >= 1200)) { - using context = detail::compile_parse_context; - static_cast(this)->check_dynamic_spec(arg_id); - } -} - -FMT_EXPORT template class basic_format_arg; -FMT_EXPORT template class basic_format_args; -FMT_EXPORT template class dynamic_format_arg_store; - -// A formatter for objects of type T. -FMT_EXPORT -template -struct formatter { - // A deleted default constructor indicates a disabled formatter. - formatter() = delete; -}; - -// Specifies if T has an enabled formatter specialization. A type can be -// formattable even if it doesn't have a formatter e.g. via a conversion. -template -using has_formatter = - std::is_constructible>; - -// An output iterator that appends to a buffer. It is used instead of -// back_insert_iterator to reduce symbol sizes and avoid dependency. -template class basic_appender { - private: - detail::buffer* buffer_; - - friend auto get_container(basic_appender app) -> detail::buffer& { - return *app.buffer_; - } - - public: - using iterator_category = int; - using value_type = T; - using difference_type = ptrdiff_t; - using pointer = T*; - using reference = T&; - FMT_UNCHECKED_ITERATOR(basic_appender); - - FMT_CONSTEXPR basic_appender(detail::buffer& buf) : buffer_(&buf) {} - - auto operator=(T c) -> basic_appender& { - buffer_->push_back(c); - return *this; - } - auto operator*() -> basic_appender& { return *this; } - auto operator++() -> basic_appender& { return *this; } - auto operator++(int) -> basic_appender { return *this; } -}; - -using appender = basic_appender; - -namespace detail { - -template -struct locking : std::true_type {}; -template -struct locking>::nonlocking>> - : std::false_type {}; - -template FMT_CONSTEXPR inline auto is_locking() -> bool { - return locking::value; -} -template -FMT_CONSTEXPR inline auto is_locking() -> bool { - return locking::value || is_locking(); -} - -// An optimized version of std::copy with the output value type (T). -template -auto copy(InputIt begin, InputIt end, appender out) -> appender { - get_container(out).append(begin, end); - return out; -} - -template ::value)> -auto copy(InputIt begin, InputIt end, OutputIt out) -> OutputIt { - get_container(out).append(begin, end); - return out; -} - -template ::value)> -FMT_CONSTEXPR auto copy(InputIt begin, InputIt end, OutputIt out) -> OutputIt { - while (begin != end) *out++ = static_cast(*begin++); - return out; -} - -template -FMT_CONSTEXPR auto copy(const T* begin, const T* end, T* out) -> T* { - if (is_constant_evaluated()) return copy(begin, end, out); - auto size = to_unsigned(end - begin); - if (size > 0) memcpy(out, begin, size * sizeof(T)); - return out + size; -} - -template -FMT_CONSTEXPR auto copy(basic_string_view s, OutputIt out) -> OutputIt { - return copy(s.begin(), s.end(), out); -} - -template -constexpr auto has_const_formatter_impl(T*) - -> decltype(typename Context::template formatter_type().format( - std::declval(), std::declval()), - true) { - return true; -} -template -constexpr auto has_const_formatter_impl(...) -> bool { - return false; -} -template -constexpr auto has_const_formatter() -> bool { - return has_const_formatter_impl(static_cast(nullptr)); -} - -// Maps an output iterator to a buffer. -template -auto get_buffer(OutputIt out) -> iterator_buffer { - return iterator_buffer(out); -} -template auto get_buffer(basic_appender out) -> buffer& { - return get_container(out); -} - -template -auto get_iterator(Buf& buf, OutputIt) -> decltype(buf.out()) { - return buf.out(); -} -template -auto get_iterator(buffer&, OutputIt out) -> OutputIt { - return out; -} - -struct view {}; - -template struct named_arg : view { - const Char* name; - const T& value; - named_arg(const Char* n, const T& v) : name(n), value(v) {} -}; - -template struct named_arg_info { - const Char* name; - int id; -}; - -template struct is_named_arg : std::false_type {}; -template struct is_statically_named_arg : std::false_type {}; - -template -struct is_named_arg> : std::true_type {}; - -template constexpr auto count() -> size_t { return B ? 1 : 0; } -template constexpr auto count() -> size_t { - return (B1 ? 1 : 0) + count(); -} - -template constexpr auto count_named_args() -> size_t { - return count::value...>(); -} - -template -constexpr auto count_statically_named_args() -> size_t { - return count::value...>(); -} - -struct unformattable {}; -struct unformattable_char : unformattable {}; -struct unformattable_pointer : unformattable {}; - -template struct string_value { - const Char* data; - size_t size; -}; - -template struct named_arg_value { - const named_arg_info* data; - size_t size; -}; - -template struct custom_value { - using parse_context = typename Context::parse_context_type; - void* value; - void (*format)(void* arg, parse_context& parse_ctx, Context& ctx); -}; - -// A formatting argument value. -template class value { - public: - using char_type = typename Context::char_type; - - union { - monostate no_value; - int int_value; - unsigned uint_value; - long long long_long_value; - unsigned long long ulong_long_value; - int128_opt int128_value; - uint128_opt uint128_value; - bool bool_value; - char_type char_value; - float float_value; - double double_value; - long double long_double_value; - const void* pointer; - string_value string; - custom_value custom; - named_arg_value named_args; - }; - - constexpr FMT_ALWAYS_INLINE value() : no_value() {} - constexpr FMT_ALWAYS_INLINE value(int val) : int_value(val) {} - constexpr FMT_ALWAYS_INLINE value(unsigned val) : uint_value(val) {} - constexpr FMT_ALWAYS_INLINE value(long long val) : long_long_value(val) {} - constexpr FMT_ALWAYS_INLINE value(unsigned long long val) - : ulong_long_value(val) {} - FMT_ALWAYS_INLINE value(int128_opt val) : int128_value(val) {} - FMT_ALWAYS_INLINE value(uint128_opt val) : uint128_value(val) {} - constexpr FMT_ALWAYS_INLINE value(float val) : float_value(val) {} - constexpr FMT_ALWAYS_INLINE value(double val) : double_value(val) {} - FMT_ALWAYS_INLINE value(long double val) : long_double_value(val) {} - constexpr FMT_ALWAYS_INLINE value(bool val) : bool_value(val) {} - constexpr FMT_ALWAYS_INLINE value(char_type val) : char_value(val) {} - FMT_CONSTEXPR FMT_ALWAYS_INLINE value(const char_type* val) { - string.data = val; - if (is_constant_evaluated()) string.size = {}; - } - FMT_CONSTEXPR FMT_ALWAYS_INLINE value(basic_string_view val) { - string.data = val.data(); - string.size = val.size(); - } - FMT_ALWAYS_INLINE value(const void* val) : pointer(val) {} - FMT_ALWAYS_INLINE value(const named_arg_info* args, size_t size) - : named_args{args, size} {} - - template FMT_CONSTEXPR20 FMT_ALWAYS_INLINE value(T& val) { - using value_type = remove_const_t; - // T may overload operator& e.g. std::vector::reference in libc++. -#if defined(__cpp_if_constexpr) - if constexpr (std::is_same::value) - custom.value = const_cast(&val); -#endif - if (!is_constant_evaluated()) - custom.value = const_cast(&reinterpret_cast(val)); - // Get the formatter type through the context to allow different contexts - // have different extension points, e.g. `formatter` for `format` and - // `printf_formatter` for `printf`. - custom.format = format_custom_arg< - value_type, typename Context::template formatter_type>; - } - value(unformattable); - value(unformattable_char); - value(unformattable_pointer); - - private: - // Formats an argument of a custom type, such as a user-defined class. - template - static void format_custom_arg(void* arg, - typename Context::parse_context_type& parse_ctx, - Context& ctx) { - auto f = Formatter(); - parse_ctx.advance_to(f.parse(parse_ctx)); - using qualified_type = - conditional_t(), const T, T>; - // format must be const for compatibility with std::format and compilation. - const auto& cf = f; - ctx.advance_to(cf.format(*static_cast(arg), ctx)); - } -}; - -// To minimize the number of types we need to deal with, long is translated -// either to int or to long long depending on its size. -enum { long_short = sizeof(long) == sizeof(int) }; -using long_type = conditional_t; -using ulong_type = conditional_t; - -template struct format_as_result { - template ::value || std::is_class::value)> - static auto map(U*) -> remove_cvref_t()))>; - static auto map(...) -> void; - - using type = decltype(map(static_cast(nullptr))); -}; -template using format_as_t = typename format_as_result::type; - -template -struct has_format_as - : bool_constant, void>::value> {}; - -#define FMT_MAP_API FMT_CONSTEXPR FMT_ALWAYS_INLINE - -// Maps formatting arguments to core types. -// arg_mapper reports errors by returning unformattable instead of using -// static_assert because it's used in the is_formattable trait. -template struct arg_mapper { - using char_type = typename Context::char_type; - - FMT_MAP_API auto map(signed char val) -> int { return val; } - FMT_MAP_API auto map(unsigned char val) -> unsigned { return val; } - FMT_MAP_API auto map(short val) -> int { return val; } - FMT_MAP_API auto map(unsigned short val) -> unsigned { return val; } - FMT_MAP_API auto map(int val) -> int { return val; } - FMT_MAP_API auto map(unsigned val) -> unsigned { return val; } - FMT_MAP_API auto map(long val) -> long_type { return val; } - FMT_MAP_API auto map(unsigned long val) -> ulong_type { return val; } - FMT_MAP_API auto map(long long val) -> long long { return val; } - FMT_MAP_API auto map(unsigned long long val) -> unsigned long long { - return val; - } - FMT_MAP_API auto map(int128_opt val) -> int128_opt { return val; } - FMT_MAP_API auto map(uint128_opt val) -> uint128_opt { return val; } - FMT_MAP_API auto map(bool val) -> bool { return val; } - - template ::value || - std::is_same::value)> - FMT_MAP_API auto map(T val) -> char_type { - return val; - } - template ::value || -#ifdef __cpp_char8_t - std::is_same::value || -#endif - std::is_same::value || - std::is_same::value) && - !std::is_same::value, - int> = 0> - FMT_MAP_API auto map(T) -> unformattable_char { - return {}; - } - - FMT_MAP_API auto map(float val) -> float { return val; } - FMT_MAP_API auto map(double val) -> double { return val; } - FMT_MAP_API auto map(long double val) -> long double { return val; } - - FMT_MAP_API auto map(char_type* val) -> const char_type* { return val; } - FMT_MAP_API auto map(const char_type* val) -> const char_type* { return val; } - template , - FMT_ENABLE_IF(std::is_same::value && - !std::is_pointer::value)> - FMT_MAP_API auto map(const T& val) -> basic_string_view { - return to_string_view(val); - } - template , - FMT_ENABLE_IF(!std::is_same::value && - !std::is_pointer::value)> - FMT_MAP_API auto map(const T&) -> unformattable_char { - return {}; - } - - FMT_MAP_API auto map(void* val) -> const void* { return val; } - FMT_MAP_API auto map(const void* val) -> const void* { return val; } - FMT_MAP_API auto map(std::nullptr_t val) -> const void* { return val; } - - // Use SFINAE instead of a const T* parameter to avoid a conflict with the - // array overload. - template < - typename T, - FMT_ENABLE_IF( - std::is_pointer::value || std::is_member_pointer::value || - std::is_function::type>::value || - (std::is_array::value && - !std::is_convertible::value))> - FMT_CONSTEXPR auto map(const T&) -> unformattable_pointer { - return {}; - } - - template ::value)> - FMT_MAP_API auto map(const T (&values)[N]) -> const T (&)[N] { - return values; - } - - // Only map owning types because mapping views can be unsafe. - template , - FMT_ENABLE_IF(std::is_arithmetic::value)> - FMT_MAP_API auto map(const T& val) -> decltype(FMT_DECLTYPE_THIS map(U())) { - return map(format_as(val)); - } - - template > - struct formattable : bool_constant() || - (has_formatter::value && - !std::is_const::value)> {}; - - template ::value)> - FMT_MAP_API auto do_map(T& val) -> T& { - return val; - } - template ::value)> - FMT_MAP_API auto do_map(T&) -> unformattable { - return {}; - } - - // is_fundamental is used to allow formatters for extended FP types. - template , - FMT_ENABLE_IF( - (std::is_class::value || std::is_enum::value || - std::is_union::value || std::is_fundamental::value) && - !has_to_string_view::value && !is_char::value && - !is_named_arg::value && !std::is_integral::value && - !std::is_arithmetic>::value)> - FMT_MAP_API auto map(T& val) -> decltype(FMT_DECLTYPE_THIS do_map(val)) { - return do_map(val); - } - - template ::value)> - FMT_MAP_API auto map(const T& named_arg) - -> decltype(FMT_DECLTYPE_THIS map(named_arg.value)) { - return map(named_arg.value); - } - - auto map(...) -> unformattable { return {}; } -}; - -// A type constant after applying arg_mapper. -template -using mapped_type_constant = - type_constant().map(std::declval())), - typename Context::char_type>; - -enum { packed_arg_bits = 4 }; -// Maximum number of arguments with packed types. -enum { max_packed_args = 62 / packed_arg_bits }; -enum : unsigned long long { is_unpacked_bit = 1ULL << 63 }; -enum : unsigned long long { has_named_args_bit = 1ULL << 62 }; - -template -struct is_output_iterator : std::false_type {}; - -template <> struct is_output_iterator : std::true_type {}; - -template -struct is_output_iterator< - It, T, void_t()++ = std::declval())>> - : std::true_type {}; - -// A type-erased reference to an std::locale to avoid a heavy include. -class locale_ref { - private: - const void* locale_; // A type-erased pointer to std::locale. - - public: - constexpr locale_ref() : locale_(nullptr) {} - template explicit locale_ref(const Locale& loc); - - explicit operator bool() const noexcept { return locale_ != nullptr; } - - template auto get() const -> Locale; -}; - -template constexpr auto encode_types() -> unsigned long long { - return 0; -} - -template -constexpr auto encode_types() -> unsigned long long { - return static_cast(mapped_type_constant::value) | - (encode_types() << packed_arg_bits); -} - -template -constexpr unsigned long long make_descriptor() { - return NUM_ARGS <= max_packed_args ? encode_types() - : is_unpacked_bit | NUM_ARGS; -} - -// This type is intentionally undefined, only used for errors. -template -#if FMT_CLANG_VERSION && FMT_CLANG_VERSION <= 1500 -// https://github.com/fmtlib/fmt/issues/3796 -struct type_is_unformattable_for { -}; -#else -struct type_is_unformattable_for; -#endif - -template -FMT_CONSTEXPR auto make_arg(T& val) -> value { - using arg_type = remove_cvref_t().map(val))>; - - // Use enum instead of constexpr because the latter may generate code. - enum { - formattable_char = !std::is_same::value - }; - static_assert(formattable_char, "Mixing character types is disallowed."); - - // Formatting of arbitrary pointers is disallowed. If you want to format a - // pointer cast it to `void*` or `const void*`. In particular, this forbids - // formatting of `[const] volatile char*` printed as bool by iostreams. - enum { - formattable_pointer = !std::is_same::value - }; - static_assert(formattable_pointer, - "Formatting of non-void pointers is disallowed."); - - enum { formattable = !std::is_same::value }; -#if defined(__cpp_if_constexpr) - if constexpr (!formattable) - type_is_unformattable_for _; -#endif - static_assert( - formattable, - "Cannot format an argument. To make type T formattable provide a " - "formatter specialization: https://fmt.dev/latest/api.html#udt"); - return {arg_mapper().map(val)}; -} - -template -FMT_CONSTEXPR auto make_arg(T& val) -> basic_format_arg { - auto arg = basic_format_arg(); - arg.type_ = mapped_type_constant::value; - arg.value_ = make_arg(val); - return arg; -} - -template -FMT_CONSTEXPR inline auto make_arg(T& val) -> basic_format_arg { - return make_arg(val); -} - -template -using arg_t = conditional_t, - basic_format_arg>; - -template ::value)> -void init_named_arg(named_arg_info*, int& arg_index, int&, const T&) { - ++arg_index; -} -template ::value)> -void init_named_arg(named_arg_info* named_args, int& arg_index, - int& named_arg_index, const T& arg) { - named_args[named_arg_index++] = {arg.name, arg_index++}; -} - -// An array of references to arguments. It can be implicitly converted to -// `fmt::basic_format_args` for passing into type-erased formatting functions -// such as `fmt::vformat`. -template -struct format_arg_store { - // args_[0].named_args points to named_args to avoid bloating format_args. - // +1 to workaround a bug in gcc 7.5 that causes duplicated-branches warning. - static constexpr size_t ARGS_ARR_SIZE = 1 + (NUM_ARGS != 0 ? NUM_ARGS : +1); - - arg_t args[ARGS_ARR_SIZE]; - named_arg_info named_args[NUM_NAMED_ARGS]; - - template - FMT_MAP_API format_arg_store(T&... values) - : args{{named_args, NUM_NAMED_ARGS}, - make_arg(values)...} { - using dummy = int[]; - int arg_index = 0, named_arg_index = 0; - (void)dummy{ - 0, - (init_named_arg(named_args, arg_index, named_arg_index, values), 0)...}; - } - - format_arg_store(format_arg_store&& rhs) { - args[0] = {named_args, NUM_NAMED_ARGS}; - for (size_t i = 1; i < ARGS_ARR_SIZE; ++i) args[i] = rhs.args[i]; - for (size_t i = 0; i < NUM_NAMED_ARGS; ++i) - named_args[i] = rhs.named_args[i]; - } - - format_arg_store(const format_arg_store& rhs) = delete; - format_arg_store& operator=(const format_arg_store& rhs) = delete; - format_arg_store& operator=(format_arg_store&& rhs) = delete; -}; - -// A specialization of format_arg_store without named arguments. -// It is a plain struct to reduce binary size in debug mode. -template -struct format_arg_store { - // +1 to workaround a bug in gcc 7.5 that causes duplicated-branches warning. - arg_t args[NUM_ARGS != 0 ? NUM_ARGS : +1]; -}; - -} // namespace detail -FMT_BEGIN_EXPORT - -// A formatting argument. Context is a template parameter for the compiled API -// where output can be unbuffered. -template class basic_format_arg { - private: - detail::value value_; - detail::type type_; - - template - friend FMT_CONSTEXPR auto detail::make_arg(T& value) - -> basic_format_arg; - - friend class basic_format_args; - friend class dynamic_format_arg_store; - - using char_type = typename Context::char_type; - - template - friend struct detail::format_arg_store; - - basic_format_arg(const detail::named_arg_info* args, size_t size) - : value_(args, size) {} - - public: - class handle { - public: - explicit handle(detail::custom_value custom) : custom_(custom) {} - - void format(typename Context::parse_context_type& parse_ctx, - Context& ctx) const { - custom_.format(custom_.value, parse_ctx, ctx); - } - - private: - detail::custom_value custom_; - }; - - constexpr basic_format_arg() : type_(detail::type::none_type) {} - - constexpr explicit operator bool() const noexcept { - return type_ != detail::type::none_type; - } - - auto type() const -> detail::type { return type_; } - - auto is_integral() const -> bool { return detail::is_integral_type(type_); } - auto is_arithmetic() const -> bool { - return detail::is_arithmetic_type(type_); - } - - /** - * Visits an argument dispatching to the appropriate visit method based on - * the argument type. For example, if the argument type is `double` then - * `vis(value)` will be called with the value of type `double`. - */ - template - FMT_CONSTEXPR auto visit(Visitor&& vis) const -> decltype(vis(0)) { - switch (type_) { - case detail::type::none_type: - break; - case detail::type::int_type: - return vis(value_.int_value); - case detail::type::uint_type: - return vis(value_.uint_value); - case detail::type::long_long_type: - return vis(value_.long_long_value); - case detail::type::ulong_long_type: - return vis(value_.ulong_long_value); - case detail::type::int128_type: - return vis(detail::convert_for_visit(value_.int128_value)); - case detail::type::uint128_type: - return vis(detail::convert_for_visit(value_.uint128_value)); - case detail::type::bool_type: - return vis(value_.bool_value); - case detail::type::char_type: - return vis(value_.char_value); - case detail::type::float_type: - return vis(value_.float_value); - case detail::type::double_type: - return vis(value_.double_value); - case detail::type::long_double_type: - return vis(value_.long_double_value); - case detail::type::cstring_type: - return vis(value_.string.data); - case detail::type::string_type: - using sv = basic_string_view; - return vis(sv(value_.string.data, value_.string.size)); - case detail::type::pointer_type: - return vis(value_.pointer); - case detail::type::custom_type: - return vis(typename basic_format_arg::handle(value_.custom)); - } - return vis(monostate()); - } - - auto format_custom(const char_type* parse_begin, - typename Context::parse_context_type& parse_ctx, - Context& ctx) -> bool { - if (type_ != detail::type::custom_type) return false; - parse_ctx.advance_to(parse_begin); - value_.custom.format(value_.custom.value, parse_ctx, ctx); - return true; - } -}; - -template -FMT_DEPRECATED FMT_CONSTEXPR auto visit_format_arg( - Visitor&& vis, const basic_format_arg& arg) -> decltype(vis(0)) { - return arg.visit(static_cast(vis)); -} - -/** - * A view of a collection of formatting arguments. To avoid lifetime issues it - * should only be used as a parameter type in type-erased functions such as - * `vformat`: - * - * void vlog(fmt::string_view fmt, fmt::format_args args); // OK - * fmt::format_args args = fmt::make_format_args(); // Dangling reference - */ -template class basic_format_args { - public: - using size_type = int; - using format_arg = basic_format_arg; - - private: - // A descriptor that contains information about formatting arguments. - // If the number of arguments is less or equal to max_packed_args then - // argument types are passed in the descriptor. This reduces binary code size - // per formatting function call. - unsigned long long desc_; - union { - // If is_packed() returns true then argument values are stored in values_; - // otherwise they are stored in args_. This is done to improve cache - // locality and reduce compiled code size since storing larger objects - // may require more code (at least on x86-64) even if the same amount of - // data is actually copied to stack. It saves ~10% on the bloat test. - const detail::value* values_; - const format_arg* args_; - }; - - constexpr auto is_packed() const -> bool { - return (desc_ & detail::is_unpacked_bit) == 0; - } - constexpr auto has_named_args() const -> bool { - return (desc_ & detail::has_named_args_bit) != 0; - } - - FMT_CONSTEXPR auto type(int index) const -> detail::type { - int shift = index * detail::packed_arg_bits; - unsigned int mask = (1 << detail::packed_arg_bits) - 1; - return static_cast((desc_ >> shift) & mask); - } - - public: - constexpr basic_format_args() : desc_(0), args_(nullptr) {} - - /// Constructs a `basic_format_args` object from `format_arg_store`. - template - constexpr FMT_ALWAYS_INLINE basic_format_args( - const detail::format_arg_store& - store) - : desc_(DESC), values_(store.args + (NUM_NAMED_ARGS != 0 ? 1 : 0)) {} - - template detail::max_packed_args)> - constexpr basic_format_args( - const detail::format_arg_store& - store) - : desc_(DESC), args_(store.args + (NUM_NAMED_ARGS != 0 ? 1 : 0)) {} - - /// Constructs a `basic_format_args` object from `dynamic_format_arg_store`. - constexpr basic_format_args(const dynamic_format_arg_store& store) - : desc_(store.get_types()), args_(store.data()) {} - - /// Constructs a `basic_format_args` object from a dynamic list of arguments. - constexpr basic_format_args(const format_arg* args, int count) - : desc_(detail::is_unpacked_bit | detail::to_unsigned(count)), - args_(args) {} - - /// Returns the argument with the specified id. - FMT_CONSTEXPR auto get(int id) const -> format_arg { - format_arg arg; - if (!is_packed()) { - if (id < max_size()) arg = args_[id]; - return arg; - } - if (static_cast(id) >= detail::max_packed_args) return arg; - arg.type_ = type(id); - if (arg.type_ == detail::type::none_type) return arg; - arg.value_ = values_[id]; - return arg; - } - - template - auto get(basic_string_view name) const -> format_arg { - int id = get_id(name); - return id >= 0 ? get(id) : format_arg(); - } - - template - FMT_CONSTEXPR auto get_id(basic_string_view name) const -> int { - if (!has_named_args()) return -1; - const auto& named_args = - (is_packed() ? values_[-1] : args_[-1].value_).named_args; - for (size_t i = 0; i < named_args.size; ++i) { - if (named_args.data[i].name == name) return named_args.data[i].id; - } - return -1; - } - - auto max_size() const -> int { - unsigned long long max_packed = detail::max_packed_args; - return static_cast(is_packed() ? max_packed - : desc_ & ~detail::is_unpacked_bit); - } -}; - -// A formatting context. -class context { - private: - appender out_; - basic_format_args args_; - detail::locale_ref loc_; - - public: - /// The character type for the output. - using char_type = char; - - using iterator = appender; - using format_arg = basic_format_arg; - using parse_context_type = basic_format_parse_context; - template using formatter_type = formatter; - - /// Constructs a `basic_format_context` object. References to the arguments - /// are stored in the object so make sure they have appropriate lifetimes. - FMT_CONSTEXPR context(iterator out, basic_format_args ctx_args, - detail::locale_ref loc = {}) - : out_(out), args_(ctx_args), loc_(loc) {} - context(context&&) = default; - context(const context&) = delete; - void operator=(const context&) = delete; - - FMT_CONSTEXPR auto arg(int id) const -> format_arg { return args_.get(id); } - auto arg(string_view name) -> format_arg { return args_.get(name); } - FMT_CONSTEXPR auto arg_id(string_view name) -> int { - return args_.get_id(name); - } - auto args() const -> const basic_format_args& { return args_; } - - // Returns an iterator to the beginning of the output range. - FMT_CONSTEXPR auto out() -> iterator { return out_; } - - // Advances the begin iterator to `it`. - void advance_to(iterator) {} - - FMT_CONSTEXPR auto locale() -> detail::locale_ref { return loc_; } -}; - -template class generic_context; - -// Longer aliases for C++20 compatibility. -template -using basic_format_context = - conditional_t::value, context, - generic_context>; -using format_context = context; - -template -using buffered_context = basic_format_context, Char>; - -template -using is_formattable = bool_constant>() - .map(std::declval()))>::value>; - -#if FMT_USE_CONCEPTS -template -concept formattable = is_formattable, Char>::value; -#endif - -/** - * Constructs an object that stores references to arguments and can be - * implicitly converted to `format_args`. `Context` can be omitted in which case - * it defaults to `format_context`. See `arg` for lifetime considerations. - */ -// Take arguments by lvalue references to avoid some lifetime issues, e.g. -// auto args = make_format_args(std::string()); -template (), - unsigned long long DESC = detail::make_descriptor(), - FMT_ENABLE_IF(NUM_NAMED_ARGS == 0)> -constexpr FMT_ALWAYS_INLINE auto make_format_args(T&... args) - -> detail::format_arg_store { - return {{detail::make_arg( - args)...}}; -} - -#ifndef FMT_DOC -template (), - unsigned long long DESC = - detail::make_descriptor() | - static_cast(detail::has_named_args_bit), - FMT_ENABLE_IF(NUM_NAMED_ARGS != 0)> -constexpr auto make_format_args(T&... args) - -> detail::format_arg_store { - return {args...}; -} -#endif - -/** - * Returns a named argument to be used in a formatting function. - * It should only be used in a call to a formatting function or - * `dynamic_format_arg_store::push_back`. - * - * **Example**: - * - * fmt::print("The answer is {answer}.", fmt::arg("answer", 42)); - */ -template -inline auto arg(const Char* name, const T& arg) -> detail::named_arg { - static_assert(!detail::is_named_arg(), "nested named arguments"); - return {name, arg}; -} -FMT_END_EXPORT - -/// An alias for `basic_format_args`. -// A separate type would result in shorter symbols but break ABI compatibility -// between clang and gcc on ARM (#1919). -FMT_EXPORT using format_args = basic_format_args; - -// We cannot use enum classes as bit fields because of a gcc bug, so we put them -// in namespaces instead (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61414). -// Additionally, if an underlying type is specified, older gcc incorrectly warns -// that the type is too small. Both bugs are fixed in gcc 9.3. -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 903 -# define FMT_ENUM_UNDERLYING_TYPE(type) -#else -# define FMT_ENUM_UNDERLYING_TYPE(type) : type -#endif -namespace align { -enum type FMT_ENUM_UNDERLYING_TYPE(unsigned char){none, left, right, center, - numeric}; -} -using align_t = align::type; -namespace sign { -enum type FMT_ENUM_UNDERLYING_TYPE(unsigned char){none, minus, plus, space}; -} -using sign_t = sign::type; - -namespace detail { - -template -using unsigned_char = typename conditional_t::value, - std::make_unsigned, - type_identity>::type; - -// Character (code unit) type is erased to prevent template bloat. -struct fill_t { - private: - enum { max_size = 4 }; - char data_[max_size] = {' '}; - unsigned char size_ = 1; - - public: - template - FMT_CONSTEXPR void operator=(basic_string_view s) { - auto size = s.size(); - size_ = static_cast(size); - if (size == 1) { - unsigned uchar = static_cast>(s[0]); - data_[0] = static_cast(uchar); - data_[1] = static_cast(uchar >> 8); - return; - } - FMT_ASSERT(size <= max_size, "invalid fill"); - for (size_t i = 0; i < size; ++i) data_[i] = static_cast(s[i]); - } - - FMT_CONSTEXPR void operator=(char c) { - data_[0] = c; - size_ = 1; - } - - constexpr auto size() const -> size_t { return size_; } - - template constexpr auto get() const -> Char { - using uchar = unsigned char; - return static_cast(static_cast(data_[0]) | - (static_cast(data_[1]) << 8)); - } - - template ::value)> - constexpr auto data() const -> const Char* { - return data_; - } - template ::value)> - constexpr auto data() const -> const Char* { - return nullptr; - } -}; -} // namespace detail - -enum class presentation_type : unsigned char { - // Common specifiers: - none = 0, - debug = 1, // '?' - string = 2, // 's' (string, bool) - - // Integral, bool and character specifiers: - dec = 3, // 'd' - hex, // 'x' or 'X' - oct, // 'o' - bin, // 'b' or 'B' - chr, // 'c' - - // String and pointer specifiers: - pointer = 3, // 'p' - - // Floating-point specifiers: - exp = 1, // 'e' or 'E' (1 since there is no FP debug presentation) - fixed, // 'f' or 'F' - general, // 'g' or 'G' - hexfloat // 'a' or 'A' -}; - -// Format specifiers for built-in and string types. -struct format_specs { - int width; - int precision; - presentation_type type; - align_t align : 4; - sign_t sign : 3; - bool upper : 1; // An uppercase version e.g. 'X' for 'x'. - bool alt : 1; // Alternate form ('#'). - bool localized : 1; - detail::fill_t fill; - - constexpr format_specs() - : width(0), - precision(-1), - type(presentation_type::none), - align(align::none), - sign(sign::none), - upper(false), - alt(false), - localized(false) {} -}; - -namespace detail { - -enum class arg_id_kind { none, index, name }; - -// An argument reference. -template struct arg_ref { - FMT_CONSTEXPR arg_ref() : kind(arg_id_kind::none), val() {} - - FMT_CONSTEXPR explicit arg_ref(int index) - : kind(arg_id_kind::index), val(index) {} - FMT_CONSTEXPR explicit arg_ref(basic_string_view name) - : kind(arg_id_kind::name), val(name) {} - - FMT_CONSTEXPR auto operator=(int idx) -> arg_ref& { - kind = arg_id_kind::index; - val.index = idx; - return *this; - } - - arg_id_kind kind; - union value { - FMT_CONSTEXPR value(int idx = 0) : index(idx) {} - FMT_CONSTEXPR value(basic_string_view n) : name(n) {} - - int index; - basic_string_view name; - } val; -}; - -// Format specifiers with width and precision resolved at formatting rather -// than parsing time to allow reusing the same parsed specifiers with -// different sets of arguments (precompilation of format strings). -template struct dynamic_format_specs : format_specs { - arg_ref width_ref; - arg_ref precision_ref; -}; - -// Converts a character to ASCII. Returns '\0' on conversion failure. -template ::value)> -constexpr auto to_ascii(Char c) -> char { - return c <= 0xff ? static_cast(c) : '\0'; -} - -// Returns the number of code units in a code point or 1 on error. -template -FMT_CONSTEXPR auto code_point_length(const Char* begin) -> int { - if (const_check(sizeof(Char) != 1)) return 1; - auto c = static_cast(*begin); - return static_cast((0x3a55000000000000ull >> (2 * (c >> 3))) & 0x3) + 1; -} - -// Return the result via the out param to workaround gcc bug 77539. -template -FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr& out) -> bool { - for (out = first; out != last; ++out) { - if (*out == value) return true; - } - return false; -} - -template <> -inline auto find(const char* first, const char* last, char value, - const char*& out) -> bool { - out = - static_cast(memchr(first, value, to_unsigned(last - first))); - return out != nullptr; -} - -// Parses the range [begin, end) as an unsigned integer. This function assumes -// that the range is non-empty and the first character is a digit. -template -FMT_CONSTEXPR auto parse_nonnegative_int(const Char*& begin, const Char* end, - int error_value) noexcept -> int { - FMT_ASSERT(begin != end && '0' <= *begin && *begin <= '9', ""); - unsigned value = 0, prev = 0; - auto p = begin; - do { - prev = value; - value = value * 10 + unsigned(*p - '0'); - ++p; - } while (p != end && '0' <= *p && *p <= '9'); - auto num_digits = p - begin; - begin = p; - int digits10 = static_cast(sizeof(int) * CHAR_BIT * 3 / 10); - if (num_digits <= digits10) return static_cast(value); - // Check for overflow. - unsigned max = INT_MAX; - return num_digits == digits10 + 1 && - prev * 10ull + unsigned(p[-1] - '0') <= max - ? static_cast(value) - : error_value; -} - -FMT_CONSTEXPR inline auto parse_align(char c) -> align_t { - switch (c) { - case '<': - return align::left; - case '>': - return align::right; - case '^': - return align::center; - } - return align::none; -} - -template constexpr auto is_name_start(Char c) -> bool { - return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '_'; -} - -template -FMT_CONSTEXPR auto do_parse_arg_id(const Char* begin, const Char* end, - Handler&& handler) -> const Char* { - Char c = *begin; - if (c >= '0' && c <= '9') { - int index = 0; - if (c != '0') - index = parse_nonnegative_int(begin, end, INT_MAX); - else - ++begin; - if (begin == end || (*begin != '}' && *begin != ':')) - report_error("invalid format string"); - else - handler.on_index(index); - return begin; - } - if (!is_name_start(c)) { - report_error("invalid format string"); - return begin; - } - auto it = begin; - do { - ++it; - } while (it != end && (is_name_start(*it) || ('0' <= *it && *it <= '9'))); - handler.on_name({begin, to_unsigned(it - begin)}); - return it; -} - -template -FMT_CONSTEXPR auto parse_arg_id(const Char* begin, const Char* end, - Handler&& handler) -> const Char* { - FMT_ASSERT(begin != end, ""); - Char c = *begin; - if (c != '}' && c != ':') return do_parse_arg_id(begin, end, handler); - handler.on_auto(); - return begin; -} - -template struct dynamic_spec_id_handler { - basic_format_parse_context& ctx; - arg_ref& ref; - - FMT_CONSTEXPR void on_auto() { - int id = ctx.next_arg_id(); - ref = arg_ref(id); - ctx.check_dynamic_spec(id); - } - FMT_CONSTEXPR void on_index(int id) { - ref = arg_ref(id); - ctx.check_arg_id(id); - ctx.check_dynamic_spec(id); - } - FMT_CONSTEXPR void on_name(basic_string_view id) { - ref = arg_ref(id); - ctx.check_arg_id(id); - } -}; - -// Parses [integer | "{" [arg_id] "}"]. -template -FMT_CONSTEXPR auto parse_dynamic_spec(const Char* begin, const Char* end, - int& value, arg_ref& ref, - basic_format_parse_context& ctx) - -> const Char* { - FMT_ASSERT(begin != end, ""); - if ('0' <= *begin && *begin <= '9') { - int val = parse_nonnegative_int(begin, end, -1); - if (val != -1) - value = val; - else - report_error("number is too big"); - } else if (*begin == '{') { - ++begin; - auto handler = dynamic_spec_id_handler{ctx, ref}; - if (begin != end) begin = parse_arg_id(begin, end, handler); - if (begin != end && *begin == '}') return ++begin; - report_error("invalid format string"); - } - return begin; -} - -template -FMT_CONSTEXPR auto parse_precision(const Char* begin, const Char* end, - int& value, arg_ref& ref, - basic_format_parse_context& ctx) - -> const Char* { - ++begin; - if (begin == end || *begin == '}') { - report_error("invalid precision"); - return begin; - } - return parse_dynamic_spec(begin, end, value, ref, ctx); -} - -enum class state { start, align, sign, hash, zero, width, precision, locale }; - -// Parses standard format specifiers. -template -FMT_CONSTEXPR auto parse_format_specs(const Char* begin, const Char* end, - dynamic_format_specs& specs, - basic_format_parse_context& ctx, - type arg_type) -> const Char* { - auto c = '\0'; - if (end - begin > 1) { - auto next = to_ascii(begin[1]); - c = parse_align(next) == align::none ? to_ascii(*begin) : '\0'; - } else { - if (begin == end) return begin; - c = to_ascii(*begin); - } - - struct { - state current_state = state::start; - FMT_CONSTEXPR void operator()(state s, bool valid = true) { - if (current_state >= s || !valid) - report_error("invalid format specifier"); - current_state = s; - } - } enter_state; - - using pres = presentation_type; - constexpr auto integral_set = sint_set | uint_set | bool_set | char_set; - struct { - const Char*& begin; - dynamic_format_specs& specs; - type arg_type; - - FMT_CONSTEXPR auto operator()(pres pres_type, int set) -> const Char* { - if (!in(arg_type, set)) { - if (arg_type == type::none_type) return begin; - report_error("invalid format specifier"); - } - specs.type = pres_type; - return begin + 1; - } - } parse_presentation_type{begin, specs, arg_type}; - - for (;;) { - switch (c) { - case '<': - case '>': - case '^': - enter_state(state::align); - specs.align = parse_align(c); - ++begin; - break; - case '+': - case '-': - case ' ': - if (arg_type == type::none_type) return begin; - enter_state(state::sign, in(arg_type, sint_set | float_set)); - switch (c) { - case '+': - specs.sign = sign::plus; - break; - case '-': - specs.sign = sign::minus; - break; - case ' ': - specs.sign = sign::space; - break; - } - ++begin; - break; - case '#': - if (arg_type == type::none_type) return begin; - enter_state(state::hash, is_arithmetic_type(arg_type)); - specs.alt = true; - ++begin; - break; - case '0': - enter_state(state::zero); - if (!is_arithmetic_type(arg_type)) { - if (arg_type == type::none_type) return begin; - report_error("format specifier requires numeric argument"); - } - if (specs.align == align::none) { - // Ignore 0 if align is specified for compatibility with std::format. - specs.align = align::numeric; - specs.fill = '0'; - } - ++begin; - break; - case '1': - case '2': - case '3': - case '4': - case '5': - case '6': - case '7': - case '8': - case '9': - case '{': - enter_state(state::width); - begin = parse_dynamic_spec(begin, end, specs.width, specs.width_ref, ctx); - break; - case '.': - if (arg_type == type::none_type) return begin; - enter_state(state::precision, - in(arg_type, float_set | string_set | cstring_set)); - begin = parse_precision(begin, end, specs.precision, specs.precision_ref, - ctx); - break; - case 'L': - if (arg_type == type::none_type) return begin; - enter_state(state::locale, is_arithmetic_type(arg_type)); - specs.localized = true; - ++begin; - break; - case 'd': - return parse_presentation_type(pres::dec, integral_set); - case 'X': - specs.upper = true; - FMT_FALLTHROUGH; - case 'x': - return parse_presentation_type(pres::hex, integral_set); - case 'o': - return parse_presentation_type(pres::oct, integral_set); - case 'B': - specs.upper = true; - FMT_FALLTHROUGH; - case 'b': - return parse_presentation_type(pres::bin, integral_set); - case 'E': - specs.upper = true; - FMT_FALLTHROUGH; - case 'e': - return parse_presentation_type(pres::exp, float_set); - case 'F': - specs.upper = true; - FMT_FALLTHROUGH; - case 'f': - return parse_presentation_type(pres::fixed, float_set); - case 'G': - specs.upper = true; - FMT_FALLTHROUGH; - case 'g': - return parse_presentation_type(pres::general, float_set); - case 'A': - specs.upper = true; - FMT_FALLTHROUGH; - case 'a': - return parse_presentation_type(pres::hexfloat, float_set); - case 'c': - if (arg_type == type::bool_type) report_error("invalid format specifier"); - return parse_presentation_type(pres::chr, integral_set); - case 's': - return parse_presentation_type(pres::string, - bool_set | string_set | cstring_set); - case 'p': - return parse_presentation_type(pres::pointer, pointer_set | cstring_set); - case '?': - return parse_presentation_type(pres::debug, - char_set | string_set | cstring_set); - case '}': - return begin; - default: { - if (*begin == '}') return begin; - // Parse fill and alignment. - auto fill_end = begin + code_point_length(begin); - if (end - fill_end <= 0) { - report_error("invalid format specifier"); - return begin; - } - if (*begin == '{') { - report_error("invalid fill character '{'"); - return begin; - } - auto align = parse_align(to_ascii(*fill_end)); - enter_state(state::align, align != align::none); - specs.fill = - basic_string_view(begin, to_unsigned(fill_end - begin)); - specs.align = align; - begin = fill_end + 1; - } - } - if (begin == end) return begin; - c = to_ascii(*begin); - } -} - -template -FMT_CONSTEXPR auto parse_replacement_field(const Char* begin, const Char* end, - Handler&& handler) -> const Char* { - struct id_adapter { - Handler& handler; - int arg_id; - - FMT_CONSTEXPR void on_auto() { arg_id = handler.on_arg_id(); } - FMT_CONSTEXPR void on_index(int id) { arg_id = handler.on_arg_id(id); } - FMT_CONSTEXPR void on_name(basic_string_view id) { - arg_id = handler.on_arg_id(id); - } - }; - - ++begin; - if (begin == end) return handler.on_error("invalid format string"), end; - if (*begin == '}') { - handler.on_replacement_field(handler.on_arg_id(), begin); - } else if (*begin == '{') { - handler.on_text(begin, begin + 1); - } else { - auto adapter = id_adapter{handler, 0}; - begin = parse_arg_id(begin, end, adapter); - Char c = begin != end ? *begin : Char(); - if (c == '}') { - handler.on_replacement_field(adapter.arg_id, begin); - } else if (c == ':') { - begin = handler.on_format_specs(adapter.arg_id, begin + 1, end); - if (begin == end || *begin != '}') - return handler.on_error("unknown format specifier"), end; - } else { - return handler.on_error("missing '}' in format string"), end; - } - } - return begin + 1; -} - -template -FMT_CONSTEXPR void parse_format_string(basic_string_view format_str, - Handler&& handler) { - auto begin = format_str.data(); - auto end = begin + format_str.size(); - if (end - begin < 32) { - // Use a simple loop instead of memchr for small strings. - const Char* p = begin; - while (p != end) { - auto c = *p++; - if (c == '{') { - handler.on_text(begin, p - 1); - begin = p = parse_replacement_field(p - 1, end, handler); - } else if (c == '}') { - if (p == end || *p != '}') - return handler.on_error("unmatched '}' in format string"); - handler.on_text(begin, p); - begin = ++p; - } - } - handler.on_text(begin, end); - return; - } - struct writer { - FMT_CONSTEXPR void operator()(const Char* from, const Char* to) { - if (from == to) return; - for (;;) { - const Char* p = nullptr; - if (!find(from, to, Char('}'), p)) - return handler_.on_text(from, to); - ++p; - if (p == to || *p != '}') - return handler_.on_error("unmatched '}' in format string"); - handler_.on_text(from, p); - from = p + 1; - } - } - Handler& handler_; - } write = {handler}; - while (begin != end) { - // Doing two passes with memchr (one for '{' and another for '}') is up to - // 2.5x faster than the naive one-pass implementation on big format strings. - const Char* p = begin; - if (*begin != '{' && !find(begin + 1, end, Char('{'), p)) - return write(begin, end); - write(begin, p); - begin = parse_replacement_field(p, end, handler); - } -} - -template ::value> struct strip_named_arg { - using type = T; -}; -template struct strip_named_arg { - using type = remove_cvref_t; -}; - -template -FMT_VISIBILITY("hidden") // Suppress an ld warning on macOS (#3769). -FMT_CONSTEXPR auto parse_format_specs(ParseContext& ctx) - -> decltype(ctx.begin()) { - using char_type = typename ParseContext::char_type; - using context = buffered_context; - using mapped_type = conditional_t< - mapped_type_constant::value != type::custom_type, - decltype(arg_mapper().map(std::declval())), - typename strip_named_arg::type>; -#if defined(__cpp_if_constexpr) - if constexpr (std::is_default_constructible< - formatter>::value) { - return formatter().parse(ctx); - } else { - type_is_unformattable_for _; - return ctx.begin(); - } -#else - return formatter().parse(ctx); -#endif -} - -// Checks char specs and returns true iff the presentation type is char-like. -FMT_CONSTEXPR inline auto check_char_specs(const format_specs& specs) -> bool { - if (specs.type != presentation_type::none && - specs.type != presentation_type::chr && - specs.type != presentation_type::debug) { - return false; - } - if (specs.align == align::numeric || specs.sign != sign::none || specs.alt) - report_error("invalid format specifier for char"); - return true; -} - -#if FMT_USE_NONTYPE_TEMPLATE_ARGS -template -constexpr auto get_arg_index_by_name(basic_string_view name) -> int { - if constexpr (is_statically_named_arg()) { - if (name == T::name) return N; - } - if constexpr (sizeof...(Args) > 0) - return get_arg_index_by_name(name); - (void)name; // Workaround an MSVC bug about "unused" parameter. - return -1; -} -#endif - -template -FMT_CONSTEXPR auto get_arg_index_by_name(basic_string_view name) -> int { -#if FMT_USE_NONTYPE_TEMPLATE_ARGS - if constexpr (sizeof...(Args) > 0) - return get_arg_index_by_name<0, Args...>(name); -#endif - (void)name; - return -1; -} - -template class format_string_checker { - private: - using parse_context_type = compile_parse_context; - static constexpr int num_args = sizeof...(Args); - - // Format specifier parsing function. - // In the future basic_format_parse_context will replace compile_parse_context - // here and will use is_constant_evaluated and downcasting to access the data - // needed for compile-time checks: https://godbolt.org/z/GvWzcTjh1. - using parse_func = const Char* (*)(parse_context_type&); - - type types_[num_args > 0 ? static_cast(num_args) : 1]; - parse_context_type context_; - parse_func parse_funcs_[num_args > 0 ? static_cast(num_args) : 1]; - - public: - explicit FMT_CONSTEXPR format_string_checker(basic_string_view fmt) - : types_{mapped_type_constant>::value...}, - context_(fmt, num_args, types_), - parse_funcs_{&parse_format_specs...} {} - - FMT_CONSTEXPR void on_text(const Char*, const Char*) {} - - FMT_CONSTEXPR auto on_arg_id() -> int { return context_.next_arg_id(); } - FMT_CONSTEXPR auto on_arg_id(int id) -> int { - return context_.check_arg_id(id), id; - } - FMT_CONSTEXPR auto on_arg_id(basic_string_view id) -> int { -#if FMT_USE_NONTYPE_TEMPLATE_ARGS - auto index = get_arg_index_by_name(id); - if (index < 0) on_error("named argument is not found"); - return index; -#else - (void)id; - on_error("compile-time checks for named arguments require C++20 support"); - return 0; -#endif - } - - FMT_CONSTEXPR void on_replacement_field(int id, const Char* begin) { - on_format_specs(id, begin, begin); // Call parse() on empty specs. - } - - FMT_CONSTEXPR auto on_format_specs(int id, const Char* begin, const Char*) - -> const Char* { - context_.advance_to(begin); - // id >= 0 check is a workaround for gcc 10 bug (#2065). - return id >= 0 && id < num_args ? parse_funcs_[id](context_) : begin; - } - - FMT_NORETURN FMT_CONSTEXPR void on_error(const char* message) { - report_error(message); - } -}; - -// A base class for compile-time strings. -struct compile_string {}; - -template -using is_compile_string = std::is_base_of; - -// Reports a compile-time error if S is not a valid format string. -template ::value)> -FMT_ALWAYS_INLINE void check_format_string(const S&) { -#ifdef FMT_ENFORCE_COMPILE_STRING - static_assert(is_compile_string::value, - "FMT_ENFORCE_COMPILE_STRING requires all format strings to use " - "FMT_STRING."); -#endif -} -template ::value)> -void check_format_string(S format_str) { - using char_t = typename S::char_type; - FMT_CONSTEXPR auto s = basic_string_view(format_str); - using checker = format_string_checker...>; - FMT_CONSTEXPR bool error = (parse_format_string(s, checker(s)), true); - ignore_unused(error); -} - -// Report truncation to prevent silent data loss. -inline void report_truncation(bool truncated) { - if (truncated) report_error("output is truncated"); -} - -// Use vformat_args and avoid type_identity to keep symbols short and workaround -// a GCC <= 4.8 bug. -template struct vformat_args { - using type = basic_format_args>; -}; -template <> struct vformat_args { - using type = format_args; -}; - -template -void vformat_to(buffer& buf, basic_string_view fmt, - typename vformat_args::type args, locale_ref loc = {}); - -FMT_API void vprint_mojibake(FILE*, string_view, format_args, bool = false); -#ifndef _WIN32 -inline void vprint_mojibake(FILE*, string_view, format_args, bool) {} -#endif - -template struct native_formatter { - private: - dynamic_format_specs specs_; - - public: - using nonlocking = void; - - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> const Char* { - if (ctx.begin() == ctx.end() || *ctx.begin() == '}') return ctx.begin(); - auto end = parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, TYPE); - if (const_check(TYPE == type::char_type)) check_char_specs(specs_); - return end; - } - - template - FMT_CONSTEXPR void set_debug_format(bool set = true) { - specs_.type = set ? presentation_type::debug : presentation_type::none; - } - - template - FMT_CONSTEXPR auto format(const T& val, FormatContext& ctx) const - -> decltype(ctx.out()); -}; -} // namespace detail - -FMT_BEGIN_EXPORT - -// A formatter specialization for natively supported types. -template -struct formatter::value != - detail::type::custom_type>> - : detail::native_formatter::value> { -}; - -template struct runtime_format_string { - basic_string_view str; -}; - -/// A compile-time format string. -template class basic_format_string { - private: - basic_string_view str_; - - public: - template < - typename S, - FMT_ENABLE_IF( - std::is_convertible>::value || - (detail::is_compile_string::value && - std::is_constructible, const S&>::value))> - FMT_CONSTEVAL FMT_ALWAYS_INLINE basic_format_string(const S& s) : str_(s) { - static_assert( - detail::count< - (std::is_base_of>::value && - std::is_reference::value)...>() == 0, - "passing views as lvalues is disallowed"); -#if FMT_USE_CONSTEVAL - if constexpr (detail::count_named_args() == - detail::count_statically_named_args()) { - using checker = - detail::format_string_checker...>; - detail::parse_format_string(str_, checker(s)); - } -#else - detail::check_format_string(s); -#endif - } - basic_format_string(runtime_format_string fmt) : str_(fmt.str) {} - - FMT_ALWAYS_INLINE operator basic_string_view() const { return str_; } - auto get() const -> basic_string_view { return str_; } -}; - -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 409 -// Workaround broken conversion on older gcc. -template using format_string = string_view; -inline auto runtime(string_view s) -> string_view { return s; } -#else -template -using format_string = basic_format_string...>; -/** - * Creates a runtime format string. - * - * **Example**: - * - * // Check format string at runtime instead of compile-time. - * fmt::print(fmt::runtime("{:d}"), "I am not a number"); - */ -inline auto runtime(string_view s) -> runtime_format_string<> { return {{s}}; } -#endif - -/// Formats a string and writes the output to `out`. -template , - char>::value)> -auto vformat_to(OutputIt&& out, string_view fmt, format_args args) - -> remove_cvref_t { - auto&& buf = detail::get_buffer(out); - detail::vformat_to(buf, fmt, args, {}); - return detail::get_iterator(buf, out); -} - -/** - * Formats `args` according to specifications in `fmt`, writes the result to - * the output iterator `out` and returns the iterator past the end of the output - * range. `format_to` does not append a terminating null character. - * - * **Example**: - * - * auto out = std::vector(); - * fmt::format_to(std::back_inserter(out), "{}", 42); - */ -template , - char>::value)> -FMT_INLINE auto format_to(OutputIt&& out, format_string fmt, T&&... args) - -> remove_cvref_t { - return vformat_to(FMT_FWD(out), fmt, fmt::make_format_args(args...)); -} - -template struct format_to_n_result { - /// Iterator past the end of the output range. - OutputIt out; - /// Total (not truncated) output size. - size_t size; -}; - -template ::value)> -auto vformat_to_n(OutputIt out, size_t n, string_view fmt, format_args args) - -> format_to_n_result { - using traits = detail::fixed_buffer_traits; - auto buf = detail::iterator_buffer(out, n); - detail::vformat_to(buf, fmt, args, {}); - return {buf.out(), buf.count()}; -} - -/** - * Formats `args` according to specifications in `fmt`, writes up to `n` - * characters of the result to the output iterator `out` and returns the total - * (not truncated) output size and the iterator past the end of the output - * range. `format_to_n` does not append a terminating null character. - */ -template ::value)> -FMT_INLINE auto format_to_n(OutputIt out, size_t n, format_string fmt, - T&&... args) -> format_to_n_result { - return vformat_to_n(out, n, fmt, fmt::make_format_args(args...)); -} - -template -struct format_to_result { - /// Iterator pointing to just after the last successful write in the range. - OutputIt out; - /// Specifies if the output was truncated. - bool truncated; - - FMT_CONSTEXPR operator OutputIt&() & { - detail::report_truncation(truncated); - return out; - } - FMT_CONSTEXPR operator const OutputIt&() const& { - detail::report_truncation(truncated); - return out; - } - FMT_CONSTEXPR operator OutputIt&&() && { - detail::report_truncation(truncated); - return static_cast(out); - } -}; - -template -auto vformat_to(char (&out)[N], string_view fmt, format_args args) - -> format_to_result { - auto result = vformat_to_n(out, N, fmt, args); - return {result.out, result.size > N}; -} - -template -FMT_INLINE auto format_to(char (&out)[N], format_string fmt, T&&... args) - -> format_to_result { - auto result = fmt::format_to_n(out, N, fmt, static_cast(args)...); - return {result.out, result.size > N}; -} - -/// Returns the number of chars in the output of `format(fmt, args...)`. -template -FMT_NODISCARD FMT_INLINE auto formatted_size(format_string fmt, - T&&... args) -> size_t { - auto buf = detail::counting_buffer<>(); - detail::vformat_to(buf, fmt, fmt::make_format_args(args...), {}); - return buf.count(); -} - -FMT_API void vprint(string_view fmt, format_args args); -FMT_API void vprint(FILE* f, string_view fmt, format_args args); -FMT_API void vprint_buffered(FILE* f, string_view fmt, format_args args); -FMT_API void vprintln(FILE* f, string_view fmt, format_args args); - -/** - * Formats `args` according to specifications in `fmt` and writes the output - * to `stdout`. - * - * **Example**: - * - * fmt::print("The answer is {}.", 42); - */ -template -FMT_INLINE void print(format_string fmt, T&&... args) { - const auto& vargs = fmt::make_format_args(args...); - if (!detail::use_utf8()) return detail::vprint_mojibake(stdout, fmt, vargs); - return detail::is_locking() ? vprint_buffered(stdout, fmt, vargs) - : vprint(fmt, vargs); -} - -/** - * Formats `args` according to specifications in `fmt` and writes the - * output to the file `f`. - * - * **Example**: - * - * fmt::print(stderr, "Don't {}!", "panic"); - */ -template -FMT_INLINE void print(FILE* f, format_string fmt, T&&... args) { - const auto& vargs = fmt::make_format_args(args...); - if (!detail::use_utf8()) return detail::vprint_mojibake(f, fmt, vargs); - return detail::is_locking() ? vprint_buffered(f, fmt, vargs) - : vprint(f, fmt, vargs); -} - -/// Formats `args` according to specifications in `fmt` and writes the output -/// to the file `f` followed by a newline. -template -FMT_INLINE void println(FILE* f, format_string fmt, T&&... args) { - const auto& vargs = fmt::make_format_args(args...); - return detail::use_utf8() ? vprintln(f, fmt, vargs) - : detail::vprint_mojibake(f, fmt, vargs, true); -} - -/// Formats `args` according to specifications in `fmt` and writes the output -/// to `stdout` followed by a newline. -template -FMT_INLINE void println(format_string fmt, T&&... args) { - return fmt::println(stdout, fmt, static_cast(args)...); -} - -FMT_END_EXPORT -FMT_GCC_PRAGMA("GCC pop_options") -FMT_END_NAMESPACE - -#ifdef FMT_HEADER_ONLY -# include "format.h" -#endif -#endif // FMT_BASE_H_ diff --git a/tt_metal/third_party/fmt/fmt/chrono.h b/tt_metal/third_party/fmt/fmt/chrono.h deleted file mode 100644 index c93123fd335..00000000000 --- a/tt_metal/third_party/fmt/fmt/chrono.h +++ /dev/null @@ -1,2432 +0,0 @@ -// Formatting library for C++ - chrono support -// -// Copyright (c) 2012 - present, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_CHRONO_H_ -#define FMT_CHRONO_H_ - -#ifndef FMT_MODULE -# include -# include -# include // std::isfinite -# include // std::memcpy -# include -# include -# include -# include -# include -#endif - -#include "format.h" - -FMT_BEGIN_NAMESPACE - -// Check if std::chrono::local_t is available. -#ifndef FMT_USE_LOCAL_TIME -# ifdef __cpp_lib_chrono -# define FMT_USE_LOCAL_TIME (__cpp_lib_chrono >= 201907L) -# else -# define FMT_USE_LOCAL_TIME 0 -# endif -#endif - -// Check if std::chrono::utc_timestamp is available. -#ifndef FMT_USE_UTC_TIME -# ifdef __cpp_lib_chrono -# define FMT_USE_UTC_TIME (__cpp_lib_chrono >= 201907L) -# else -# define FMT_USE_UTC_TIME 0 -# endif -#endif - -// Enable tzset. -#ifndef FMT_USE_TZSET -// UWP doesn't provide _tzset. -# if FMT_HAS_INCLUDE("winapifamily.h") -# include -# endif -# if defined(_WIN32) && (!defined(WINAPI_FAMILY) || \ - (WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP)) -# define FMT_USE_TZSET 1 -# else -# define FMT_USE_TZSET 0 -# endif -#endif - -// Enable safe chrono durations, unless explicitly disabled. -#ifndef FMT_SAFE_DURATION_CAST -# define FMT_SAFE_DURATION_CAST 1 -#endif -#if FMT_SAFE_DURATION_CAST - -// For conversion between std::chrono::durations without undefined -// behaviour or erroneous results. -// This is a stripped down version of duration_cast, for inclusion in fmt. -// See https://github.com/pauldreik/safe_duration_cast -// -// Copyright Paul Dreik 2019 -namespace safe_duration_cast { - -template ::value && - std::numeric_limits::is_signed == - std::numeric_limits::is_signed)> -FMT_CONSTEXPR auto lossless_integral_conversion(const From from, int& ec) - -> To { - ec = 0; - using F = std::numeric_limits; - using T = std::numeric_limits; - static_assert(F::is_integer, "From must be integral"); - static_assert(T::is_integer, "To must be integral"); - - // A and B are both signed, or both unsigned. - if (detail::const_check(F::digits <= T::digits)) { - // From fits in To without any problem. - } else { - // From does not always fit in To, resort to a dynamic check. - if (from < (T::min)() || from > (T::max)()) { - // outside range. - ec = 1; - return {}; - } - } - return static_cast(from); -} - -/// Converts From to To, without loss. If the dynamic value of from -/// can't be converted to To without loss, ec is set. -template ::value && - std::numeric_limits::is_signed != - std::numeric_limits::is_signed)> -FMT_CONSTEXPR auto lossless_integral_conversion(const From from, int& ec) - -> To { - ec = 0; - using F = std::numeric_limits; - using T = std::numeric_limits; - static_assert(F::is_integer, "From must be integral"); - static_assert(T::is_integer, "To must be integral"); - - if (detail::const_check(F::is_signed && !T::is_signed)) { - // From may be negative, not allowed! - if (fmt::detail::is_negative(from)) { - ec = 1; - return {}; - } - // From is positive. Can it always fit in To? - if (detail::const_check(F::digits > T::digits) && - from > static_cast(detail::max_value())) { - ec = 1; - return {}; - } - } - - if (detail::const_check(!F::is_signed && T::is_signed && - F::digits >= T::digits) && - from > static_cast(detail::max_value())) { - ec = 1; - return {}; - } - return static_cast(from); // Lossless conversion. -} - -template ::value)> -FMT_CONSTEXPR auto lossless_integral_conversion(const From from, int& ec) - -> To { - ec = 0; - return from; -} // function - -// clang-format off -/** - * converts From to To if possible, otherwise ec is set. - * - * input | output - * ---------------------------------|--------------- - * NaN | NaN - * Inf | Inf - * normal, fits in output | converted (possibly lossy) - * normal, does not fit in output | ec is set - * subnormal | best effort - * -Inf | -Inf - */ -// clang-format on -template ::value)> -FMT_CONSTEXPR auto safe_float_conversion(const From from, int& ec) -> To { - ec = 0; - using T = std::numeric_limits; - static_assert(std::is_floating_point::value, "From must be floating"); - static_assert(std::is_floating_point::value, "To must be floating"); - - // catch the only happy case - if (std::isfinite(from)) { - if (from >= T::lowest() && from <= (T::max)()) { - return static_cast(from); - } - // not within range. - ec = 1; - return {}; - } - - // nan and inf will be preserved - return static_cast(from); -} // function - -template ::value)> -FMT_CONSTEXPR auto safe_float_conversion(const From from, int& ec) -> To { - ec = 0; - static_assert(std::is_floating_point::value, "From must be floating"); - return from; -} - -/// Safe duration cast between integral durations -template ::value), - FMT_ENABLE_IF(std::is_integral::value)> -auto safe_duration_cast(std::chrono::duration from, - int& ec) -> To { - using From = std::chrono::duration; - ec = 0; - // the basic idea is that we need to convert from count() in the from type - // to count() in the To type, by multiplying it with this: - struct Factor - : std::ratio_divide {}; - - static_assert(Factor::num > 0, "num must be positive"); - static_assert(Factor::den > 0, "den must be positive"); - - // the conversion is like this: multiply from.count() with Factor::num - // /Factor::den and convert it to To::rep, all this without - // overflow/underflow. let's start by finding a suitable type that can hold - // both To, From and Factor::num - using IntermediateRep = - typename std::common_type::type; - - // safe conversion to IntermediateRep - IntermediateRep count = - lossless_integral_conversion(from.count(), ec); - if (ec) return {}; - // multiply with Factor::num without overflow or underflow - if (detail::const_check(Factor::num != 1)) { - const auto max1 = detail::max_value() / Factor::num; - if (count > max1) { - ec = 1; - return {}; - } - const auto min1 = - (std::numeric_limits::min)() / Factor::num; - if (detail::const_check(!std::is_unsigned::value) && - count < min1) { - ec = 1; - return {}; - } - count *= Factor::num; - } - - if (detail::const_check(Factor::den != 1)) count /= Factor::den; - auto tocount = lossless_integral_conversion(count, ec); - return ec ? To() : To(tocount); -} - -/// Safe duration_cast between floating point durations -template ::value), - FMT_ENABLE_IF(std::is_floating_point::value)> -auto safe_duration_cast(std::chrono::duration from, - int& ec) -> To { - using From = std::chrono::duration; - ec = 0; - if (std::isnan(from.count())) { - // nan in, gives nan out. easy. - return To{std::numeric_limits::quiet_NaN()}; - } - // maybe we should also check if from is denormal, and decide what to do about - // it. - - // +-inf should be preserved. - if (std::isinf(from.count())) { - return To{from.count()}; - } - - // the basic idea is that we need to convert from count() in the from type - // to count() in the To type, by multiplying it with this: - struct Factor - : std::ratio_divide {}; - - static_assert(Factor::num > 0, "num must be positive"); - static_assert(Factor::den > 0, "den must be positive"); - - // the conversion is like this: multiply from.count() with Factor::num - // /Factor::den and convert it to To::rep, all this without - // overflow/underflow. let's start by finding a suitable type that can hold - // both To, From and Factor::num - using IntermediateRep = - typename std::common_type::type; - - // force conversion of From::rep -> IntermediateRep to be safe, - // even if it will never happen be narrowing in this context. - IntermediateRep count = - safe_float_conversion(from.count(), ec); - if (ec) { - return {}; - } - - // multiply with Factor::num without overflow or underflow - if (detail::const_check(Factor::num != 1)) { - constexpr auto max1 = detail::max_value() / - static_cast(Factor::num); - if (count > max1) { - ec = 1; - return {}; - } - constexpr auto min1 = std::numeric_limits::lowest() / - static_cast(Factor::num); - if (count < min1) { - ec = 1; - return {}; - } - count *= static_cast(Factor::num); - } - - // this can't go wrong, right? den>0 is checked earlier. - if (detail::const_check(Factor::den != 1)) { - using common_t = typename std::common_type::type; - count /= static_cast(Factor::den); - } - - // convert to the to type, safely - using ToRep = typename To::rep; - - const ToRep tocount = safe_float_conversion(count, ec); - if (ec) { - return {}; - } - return To{tocount}; -} -} // namespace safe_duration_cast -#endif - -// Prevents expansion of a preceding token as a function-style macro. -// Usage: f FMT_NOMACRO() -#define FMT_NOMACRO - -namespace detail { -template struct null {}; -inline auto localtime_r FMT_NOMACRO(...) -> null<> { return null<>(); } -inline auto localtime_s(...) -> null<> { return null<>(); } -inline auto gmtime_r(...) -> null<> { return null<>(); } -inline auto gmtime_s(...) -> null<> { return null<>(); } - -// It is defined here and not in ostream.h because the latter has expensive -// includes. -template class formatbuf : public Streambuf { - private: - using char_type = typename Streambuf::char_type; - using streamsize = decltype(std::declval().sputn(nullptr, 0)); - using int_type = typename Streambuf::int_type; - using traits_type = typename Streambuf::traits_type; - - buffer& buffer_; - - public: - explicit formatbuf(buffer& buf) : buffer_(buf) {} - - protected: - // The put area is always empty. This makes the implementation simpler and has - // the advantage that the streambuf and the buffer are always in sync and - // sputc never writes into uninitialized memory. A disadvantage is that each - // call to sputc always results in a (virtual) call to overflow. There is no - // disadvantage here for sputn since this always results in a call to xsputn. - - auto overflow(int_type ch) -> int_type override { - if (!traits_type::eq_int_type(ch, traits_type::eof())) - buffer_.push_back(static_cast(ch)); - return ch; - } - - auto xsputn(const char_type* s, streamsize count) -> streamsize override { - buffer_.append(s, s + count); - return count; - } -}; - -inline auto get_classic_locale() -> const std::locale& { - static const auto& locale = std::locale::classic(); - return locale; -} - -template struct codecvt_result { - static constexpr const size_t max_size = 32; - CodeUnit buf[max_size]; - CodeUnit* end; -}; - -template -void write_codecvt(codecvt_result& out, string_view in_buf, - const std::locale& loc) { -#if FMT_CLANG_VERSION -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated" - auto& f = std::use_facet>(loc); -# pragma clang diagnostic pop -#else - auto& f = std::use_facet>(loc); -#endif - auto mb = std::mbstate_t(); - const char* from_next = nullptr; - auto result = f.in(mb, in_buf.begin(), in_buf.end(), from_next, - std::begin(out.buf), std::end(out.buf), out.end); - if (result != std::codecvt_base::ok) - FMT_THROW(format_error("failed to format time")); -} - -template -auto write_encoded_tm_str(OutputIt out, string_view in, const std::locale& loc) - -> OutputIt { - if (detail::use_utf8() && loc != get_classic_locale()) { - // char16_t and char32_t codecvts are broken in MSVC (linkage errors) and - // gcc-4. -#if FMT_MSC_VERSION != 0 || \ - (defined(__GLIBCXX__) && \ - (!defined(_GLIBCXX_USE_DUAL_ABI) || _GLIBCXX_USE_DUAL_ABI == 0)) - // The _GLIBCXX_USE_DUAL_ABI macro is always defined in libstdc++ from gcc-5 - // and newer. - using code_unit = wchar_t; -#else - using code_unit = char32_t; -#endif - - using unit_t = codecvt_result; - unit_t unit; - write_codecvt(unit, in, loc); - // In UTF-8 is used one to four one-byte code units. - auto u = - to_utf8>(); - if (!u.convert({unit.buf, to_unsigned(unit.end - unit.buf)})) - FMT_THROW(format_error("failed to format time")); - return copy(u.c_str(), u.c_str() + u.size(), out); - } - return copy(in.data(), in.data() + in.size(), out); -} - -template ::value)> -auto write_tm_str(OutputIt out, string_view sv, const std::locale& loc) - -> OutputIt { - codecvt_result unit; - write_codecvt(unit, sv, loc); - return copy(unit.buf, unit.end, out); -} - -template ::value)> -auto write_tm_str(OutputIt out, string_view sv, const std::locale& loc) - -> OutputIt { - return write_encoded_tm_str(out, sv, loc); -} - -template -inline void do_write(buffer& buf, const std::tm& time, - const std::locale& loc, char format, char modifier) { - auto&& format_buf = formatbuf>(buf); - auto&& os = std::basic_ostream(&format_buf); - os.imbue(loc); - const auto& facet = std::use_facet>(loc); - auto end = facet.put(os, os, Char(' '), &time, format, modifier); - if (end.failed()) FMT_THROW(format_error("failed to format time")); -} - -template ::value)> -auto write(OutputIt out, const std::tm& time, const std::locale& loc, - char format, char modifier = 0) -> OutputIt { - auto&& buf = get_buffer(out); - do_write(buf, time, loc, format, modifier); - return get_iterator(buf, out); -} - -template ::value)> -auto write(OutputIt out, const std::tm& time, const std::locale& loc, - char format, char modifier = 0) -> OutputIt { - auto&& buf = basic_memory_buffer(); - do_write(buf, time, loc, format, modifier); - return write_encoded_tm_str(out, string_view(buf.data(), buf.size()), loc); -} - -template -struct is_same_arithmetic_type - : public std::integral_constant::value && - std::is_integral::value) || - (std::is_floating_point::value && - std::is_floating_point::value)> { -}; - -template < - typename To, typename FromRep, typename FromPeriod, - FMT_ENABLE_IF(is_same_arithmetic_type::value)> -auto fmt_duration_cast(std::chrono::duration from) -> To { -#if FMT_SAFE_DURATION_CAST - // Throwing version of safe_duration_cast is only available for - // integer to integer or float to float casts. - int ec; - To to = safe_duration_cast::safe_duration_cast(from, ec); - if (ec) FMT_THROW(format_error("cannot format duration")); - return to; -#else - // Standard duration cast, may overflow. - return std::chrono::duration_cast(from); -#endif -} - -template < - typename To, typename FromRep, typename FromPeriod, - FMT_ENABLE_IF(!is_same_arithmetic_type::value)> -auto fmt_duration_cast(std::chrono::duration from) -> To { - // Mixed integer <-> float cast is not supported by safe_duration_cast. - return std::chrono::duration_cast(from); -} - -template -auto to_time_t( - std::chrono::time_point time_point) - -> std::time_t { - // Cannot use std::chrono::system_clock::to_time_t since this would first - // require a cast to std::chrono::system_clock::time_point, which could - // overflow. - return fmt_duration_cast>( - time_point.time_since_epoch()) - .count(); -} -} // namespace detail - -FMT_BEGIN_EXPORT - -/** - * Converts given time since epoch as `std::time_t` value into calendar time, - * expressed in local time. Unlike `std::localtime`, this function is - * thread-safe on most platforms. - */ -inline auto localtime(std::time_t time) -> std::tm { - struct dispatcher { - std::time_t time_; - std::tm tm_; - - dispatcher(std::time_t t) : time_(t) {} - - auto run() -> bool { - using namespace fmt::detail; - return handle(localtime_r(&time_, &tm_)); - } - - auto handle(std::tm* tm) -> bool { return tm != nullptr; } - - auto handle(detail::null<>) -> bool { - using namespace fmt::detail; - return fallback(localtime_s(&tm_, &time_)); - } - - auto fallback(int res) -> bool { return res == 0; } - -#if !FMT_MSC_VERSION - auto fallback(detail::null<>) -> bool { - using namespace fmt::detail; - std::tm* tm = std::localtime(&time_); - if (tm) tm_ = *tm; - return tm != nullptr; - } -#endif - }; - dispatcher lt(time); - // Too big time values may be unsupported. - if (!lt.run()) FMT_THROW(format_error("time_t value out of range")); - return lt.tm_; -} - -#if FMT_USE_LOCAL_TIME -template -inline auto localtime(std::chrono::local_time time) -> std::tm { - return localtime( - detail::to_time_t(std::chrono::current_zone()->to_sys(time))); -} -#endif - -/** - * Converts given time since epoch as `std::time_t` value into calendar time, - * expressed in Coordinated Universal Time (UTC). Unlike `std::gmtime`, this - * function is thread-safe on most platforms. - */ -inline auto gmtime(std::time_t time) -> std::tm { - struct dispatcher { - std::time_t time_; - std::tm tm_; - - dispatcher(std::time_t t) : time_(t) {} - - auto run() -> bool { - using namespace fmt::detail; - return handle(gmtime_r(&time_, &tm_)); - } - - auto handle(std::tm* tm) -> bool { return tm != nullptr; } - - auto handle(detail::null<>) -> bool { - using namespace fmt::detail; - return fallback(gmtime_s(&tm_, &time_)); - } - - auto fallback(int res) -> bool { return res == 0; } - -#if !FMT_MSC_VERSION - auto fallback(detail::null<>) -> bool { - std::tm* tm = std::gmtime(&time_); - if (tm) tm_ = *tm; - return tm != nullptr; - } -#endif - }; - auto gt = dispatcher(time); - // Too big time values may be unsupported. - if (!gt.run()) FMT_THROW(format_error("time_t value out of range")); - return gt.tm_; -} - -template -inline auto gmtime( - std::chrono::time_point time_point) - -> std::tm { - return gmtime(detail::to_time_t(time_point)); -} - -namespace detail { - -// Writes two-digit numbers a, b and c separated by sep to buf. -// The method by Pavel Novikov based on -// https://johnnylee-sde.github.io/Fast-unsigned-integer-to-time-string/. -inline void write_digit2_separated(char* buf, unsigned a, unsigned b, - unsigned c, char sep) { - unsigned long long digits = - a | (b << 24) | (static_cast(c) << 48); - // Convert each value to BCD. - // We have x = a * 10 + b and we want to convert it to BCD y = a * 16 + b. - // The difference is - // y - x = a * 6 - // a can be found from x: - // a = floor(x / 10) - // then - // y = x + a * 6 = x + floor(x / 10) * 6 - // floor(x / 10) is (x * 205) >> 11 (needs 16 bits). - digits += (((digits * 205) >> 11) & 0x000f00000f00000f) * 6; - // Put low nibbles to high bytes and high nibbles to low bytes. - digits = ((digits & 0x00f00000f00000f0) >> 4) | - ((digits & 0x000f00000f00000f) << 8); - auto usep = static_cast(sep); - // Add ASCII '0' to each digit byte and insert separators. - digits |= 0x3030003030003030 | (usep << 16) | (usep << 40); - - constexpr const size_t len = 8; - if (const_check(is_big_endian())) { - char tmp[len]; - std::memcpy(tmp, &digits, len); - std::reverse_copy(tmp, tmp + len, buf); - } else { - std::memcpy(buf, &digits, len); - } -} - -template -FMT_CONSTEXPR inline auto get_units() -> const char* { - if (std::is_same::value) return "as"; - if (std::is_same::value) return "fs"; - if (std::is_same::value) return "ps"; - if (std::is_same::value) return "ns"; - if (std::is_same::value) return "µs"; - if (std::is_same::value) return "ms"; - if (std::is_same::value) return "cs"; - if (std::is_same::value) return "ds"; - if (std::is_same>::value) return "s"; - if (std::is_same::value) return "das"; - if (std::is_same::value) return "hs"; - if (std::is_same::value) return "ks"; - if (std::is_same::value) return "Ms"; - if (std::is_same::value) return "Gs"; - if (std::is_same::value) return "Ts"; - if (std::is_same::value) return "Ps"; - if (std::is_same::value) return "Es"; - if (std::is_same>::value) return "min"; - if (std::is_same>::value) return "h"; - if (std::is_same>::value) return "d"; - return nullptr; -} - -enum class numeric_system { - standard, - // Alternative numeric system, e.g. 十二 instead of 12 in ja_JP locale. - alternative -}; - -// Glibc extensions for formatting numeric values. -enum class pad_type { - // Pad a numeric result string with zeros (the default). - zero, - // Do not pad a numeric result string. - none, - // Pad a numeric result string with spaces. - space, -}; - -template -auto write_padding(OutputIt out, pad_type pad, int width) -> OutputIt { - if (pad == pad_type::none) return out; - return detail::fill_n(out, width, pad == pad_type::space ? ' ' : '0'); -} - -template -auto write_padding(OutputIt out, pad_type pad) -> OutputIt { - if (pad != pad_type::none) *out++ = pad == pad_type::space ? ' ' : '0'; - return out; -} - -// Parses a put_time-like format string and invokes handler actions. -template -FMT_CONSTEXPR auto parse_chrono_format(const Char* begin, const Char* end, - Handler&& handler) -> const Char* { - if (begin == end || *begin == '}') return begin; - if (*begin != '%') FMT_THROW(format_error("invalid format")); - auto ptr = begin; - while (ptr != end) { - pad_type pad = pad_type::zero; - auto c = *ptr; - if (c == '}') break; - if (c != '%') { - ++ptr; - continue; - } - if (begin != ptr) handler.on_text(begin, ptr); - ++ptr; // consume '%' - if (ptr == end) FMT_THROW(format_error("invalid format")); - c = *ptr; - switch (c) { - case '_': - pad = pad_type::space; - ++ptr; - break; - case '-': - pad = pad_type::none; - ++ptr; - break; - } - if (ptr == end) FMT_THROW(format_error("invalid format")); - c = *ptr++; - switch (c) { - case '%': - handler.on_text(ptr - 1, ptr); - break; - case 'n': { - const Char newline[] = {'\n'}; - handler.on_text(newline, newline + 1); - break; - } - case 't': { - const Char tab[] = {'\t'}; - handler.on_text(tab, tab + 1); - break; - } - // Year: - case 'Y': - handler.on_year(numeric_system::standard); - break; - case 'y': - handler.on_short_year(numeric_system::standard); - break; - case 'C': - handler.on_century(numeric_system::standard); - break; - case 'G': - handler.on_iso_week_based_year(); - break; - case 'g': - handler.on_iso_week_based_short_year(); - break; - // Day of the week: - case 'a': - handler.on_abbr_weekday(); - break; - case 'A': - handler.on_full_weekday(); - break; - case 'w': - handler.on_dec0_weekday(numeric_system::standard); - break; - case 'u': - handler.on_dec1_weekday(numeric_system::standard); - break; - // Month: - case 'b': - case 'h': - handler.on_abbr_month(); - break; - case 'B': - handler.on_full_month(); - break; - case 'm': - handler.on_dec_month(numeric_system::standard); - break; - // Day of the year/month: - case 'U': - handler.on_dec0_week_of_year(numeric_system::standard, pad); - break; - case 'W': - handler.on_dec1_week_of_year(numeric_system::standard, pad); - break; - case 'V': - handler.on_iso_week_of_year(numeric_system::standard, pad); - break; - case 'j': - handler.on_day_of_year(); - break; - case 'd': - handler.on_day_of_month(numeric_system::standard, pad); - break; - case 'e': - handler.on_day_of_month(numeric_system::standard, pad_type::space); - break; - // Hour, minute, second: - case 'H': - handler.on_24_hour(numeric_system::standard, pad); - break; - case 'I': - handler.on_12_hour(numeric_system::standard, pad); - break; - case 'M': - handler.on_minute(numeric_system::standard, pad); - break; - case 'S': - handler.on_second(numeric_system::standard, pad); - break; - // Other: - case 'c': - handler.on_datetime(numeric_system::standard); - break; - case 'x': - handler.on_loc_date(numeric_system::standard); - break; - case 'X': - handler.on_loc_time(numeric_system::standard); - break; - case 'D': - handler.on_us_date(); - break; - case 'F': - handler.on_iso_date(); - break; - case 'r': - handler.on_12_hour_time(); - break; - case 'R': - handler.on_24_hour_time(); - break; - case 'T': - handler.on_iso_time(); - break; - case 'p': - handler.on_am_pm(); - break; - case 'Q': - handler.on_duration_value(); - break; - case 'q': - handler.on_duration_unit(); - break; - case 'z': - handler.on_utc_offset(numeric_system::standard); - break; - case 'Z': - handler.on_tz_name(); - break; - // Alternative representation: - case 'E': { - if (ptr == end) FMT_THROW(format_error("invalid format")); - c = *ptr++; - switch (c) { - case 'Y': - handler.on_year(numeric_system::alternative); - break; - case 'y': - handler.on_offset_year(); - break; - case 'C': - handler.on_century(numeric_system::alternative); - break; - case 'c': - handler.on_datetime(numeric_system::alternative); - break; - case 'x': - handler.on_loc_date(numeric_system::alternative); - break; - case 'X': - handler.on_loc_time(numeric_system::alternative); - break; - case 'z': - handler.on_utc_offset(numeric_system::alternative); - break; - default: - FMT_THROW(format_error("invalid format")); - } - break; - } - case 'O': - if (ptr == end) FMT_THROW(format_error("invalid format")); - c = *ptr++; - switch (c) { - case 'y': - handler.on_short_year(numeric_system::alternative); - break; - case 'm': - handler.on_dec_month(numeric_system::alternative); - break; - case 'U': - handler.on_dec0_week_of_year(numeric_system::alternative, pad); - break; - case 'W': - handler.on_dec1_week_of_year(numeric_system::alternative, pad); - break; - case 'V': - handler.on_iso_week_of_year(numeric_system::alternative, pad); - break; - case 'd': - handler.on_day_of_month(numeric_system::alternative, pad); - break; - case 'e': - handler.on_day_of_month(numeric_system::alternative, pad_type::space); - break; - case 'w': - handler.on_dec0_weekday(numeric_system::alternative); - break; - case 'u': - handler.on_dec1_weekday(numeric_system::alternative); - break; - case 'H': - handler.on_24_hour(numeric_system::alternative, pad); - break; - case 'I': - handler.on_12_hour(numeric_system::alternative, pad); - break; - case 'M': - handler.on_minute(numeric_system::alternative, pad); - break; - case 'S': - handler.on_second(numeric_system::alternative, pad); - break; - case 'z': - handler.on_utc_offset(numeric_system::alternative); - break; - default: - FMT_THROW(format_error("invalid format")); - } - break; - default: - FMT_THROW(format_error("invalid format")); - } - begin = ptr; - } - if (begin != ptr) handler.on_text(begin, ptr); - return ptr; -} - -template struct null_chrono_spec_handler { - FMT_CONSTEXPR void unsupported() { - static_cast(this)->unsupported(); - } - FMT_CONSTEXPR void on_year(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_short_year(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_offset_year() { unsupported(); } - FMT_CONSTEXPR void on_century(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_iso_week_based_year() { unsupported(); } - FMT_CONSTEXPR void on_iso_week_based_short_year() { unsupported(); } - FMT_CONSTEXPR void on_abbr_weekday() { unsupported(); } - FMT_CONSTEXPR void on_full_weekday() { unsupported(); } - FMT_CONSTEXPR void on_dec0_weekday(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_dec1_weekday(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_abbr_month() { unsupported(); } - FMT_CONSTEXPR void on_full_month() { unsupported(); } - FMT_CONSTEXPR void on_dec_month(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_dec0_week_of_year(numeric_system, pad_type) { - unsupported(); - } - FMT_CONSTEXPR void on_dec1_week_of_year(numeric_system, pad_type) { - unsupported(); - } - FMT_CONSTEXPR void on_iso_week_of_year(numeric_system, pad_type) { - unsupported(); - } - FMT_CONSTEXPR void on_day_of_year() { unsupported(); } - FMT_CONSTEXPR void on_day_of_month(numeric_system, pad_type) { - unsupported(); - } - FMT_CONSTEXPR void on_24_hour(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_12_hour(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_minute(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_second(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_datetime(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_loc_date(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_loc_time(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_us_date() { unsupported(); } - FMT_CONSTEXPR void on_iso_date() { unsupported(); } - FMT_CONSTEXPR void on_12_hour_time() { unsupported(); } - FMT_CONSTEXPR void on_24_hour_time() { unsupported(); } - FMT_CONSTEXPR void on_iso_time() { unsupported(); } - FMT_CONSTEXPR void on_am_pm() { unsupported(); } - FMT_CONSTEXPR void on_duration_value() { unsupported(); } - FMT_CONSTEXPR void on_duration_unit() { unsupported(); } - FMT_CONSTEXPR void on_utc_offset(numeric_system) { unsupported(); } - FMT_CONSTEXPR void on_tz_name() { unsupported(); } -}; - -struct tm_format_checker : null_chrono_spec_handler { - FMT_NORETURN void unsupported() { FMT_THROW(format_error("no format")); } - - template - FMT_CONSTEXPR void on_text(const Char*, const Char*) {} - FMT_CONSTEXPR void on_year(numeric_system) {} - FMT_CONSTEXPR void on_short_year(numeric_system) {} - FMT_CONSTEXPR void on_offset_year() {} - FMT_CONSTEXPR void on_century(numeric_system) {} - FMT_CONSTEXPR void on_iso_week_based_year() {} - FMT_CONSTEXPR void on_iso_week_based_short_year() {} - FMT_CONSTEXPR void on_abbr_weekday() {} - FMT_CONSTEXPR void on_full_weekday() {} - FMT_CONSTEXPR void on_dec0_weekday(numeric_system) {} - FMT_CONSTEXPR void on_dec1_weekday(numeric_system) {} - FMT_CONSTEXPR void on_abbr_month() {} - FMT_CONSTEXPR void on_full_month() {} - FMT_CONSTEXPR void on_dec_month(numeric_system) {} - FMT_CONSTEXPR void on_dec0_week_of_year(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_dec1_week_of_year(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_iso_week_of_year(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_day_of_year() {} - FMT_CONSTEXPR void on_day_of_month(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_24_hour(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_12_hour(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_minute(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_second(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_datetime(numeric_system) {} - FMT_CONSTEXPR void on_loc_date(numeric_system) {} - FMT_CONSTEXPR void on_loc_time(numeric_system) {} - FMT_CONSTEXPR void on_us_date() {} - FMT_CONSTEXPR void on_iso_date() {} - FMT_CONSTEXPR void on_12_hour_time() {} - FMT_CONSTEXPR void on_24_hour_time() {} - FMT_CONSTEXPR void on_iso_time() {} - FMT_CONSTEXPR void on_am_pm() {} - FMT_CONSTEXPR void on_utc_offset(numeric_system) {} - FMT_CONSTEXPR void on_tz_name() {} -}; - -inline auto tm_wday_full_name(int wday) -> const char* { - static constexpr const char* full_name_list[] = { - "Sunday", "Monday", "Tuesday", "Wednesday", - "Thursday", "Friday", "Saturday"}; - return wday >= 0 && wday <= 6 ? full_name_list[wday] : "?"; -} -inline auto tm_wday_short_name(int wday) -> const char* { - static constexpr const char* short_name_list[] = {"Sun", "Mon", "Tue", "Wed", - "Thu", "Fri", "Sat"}; - return wday >= 0 && wday <= 6 ? short_name_list[wday] : "???"; -} - -inline auto tm_mon_full_name(int mon) -> const char* { - static constexpr const char* full_name_list[] = { - "January", "February", "March", "April", "May", "June", - "July", "August", "September", "October", "November", "December"}; - return mon >= 0 && mon <= 11 ? full_name_list[mon] : "?"; -} -inline auto tm_mon_short_name(int mon) -> const char* { - static constexpr const char* short_name_list[] = { - "Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec", - }; - return mon >= 0 && mon <= 11 ? short_name_list[mon] : "???"; -} - -template -struct has_member_data_tm_gmtoff : std::false_type {}; -template -struct has_member_data_tm_gmtoff> - : std::true_type {}; - -template -struct has_member_data_tm_zone : std::false_type {}; -template -struct has_member_data_tm_zone> - : std::true_type {}; - -#if FMT_USE_TZSET -inline void tzset_once() { - static bool init = []() -> bool { - _tzset(); - return true; - }(); - ignore_unused(init); -} -#endif - -// Converts value to Int and checks that it's in the range [0, upper). -template ::value)> -inline auto to_nonnegative_int(T value, Int upper) -> Int { - if (!std::is_unsigned::value && - (value < 0 || to_unsigned(value) > to_unsigned(upper))) { - FMT_THROW(fmt::format_error("chrono value is out of range")); - } - return static_cast(value); -} -template ::value)> -inline auto to_nonnegative_int(T value, Int upper) -> Int { - auto int_value = static_cast(value); - if (int_value < 0 || value > static_cast(upper)) - FMT_THROW(format_error("invalid value")); - return int_value; -} - -constexpr auto pow10(std::uint32_t n) -> long long { - return n == 0 ? 1 : 10 * pow10(n - 1); -} - -// Counts the number of fractional digits in the range [0, 18] according to the -// C++20 spec. If more than 18 fractional digits are required then returns 6 for -// microseconds precision. -template () / 10)> -struct count_fractional_digits { - static constexpr int value = - Num % Den == 0 ? N : count_fractional_digits::value; -}; - -// Base case that doesn't instantiate any more templates -// in order to avoid overflow. -template -struct count_fractional_digits { - static constexpr int value = (Num % Den == 0) ? N : 6; -}; - -// Format subseconds which are given as an integer type with an appropriate -// number of digits. -template -void write_fractional_seconds(OutputIt& out, Duration d, int precision = -1) { - constexpr auto num_fractional_digits = - count_fractional_digits::value; - - using subsecond_precision = std::chrono::duration< - typename std::common_type::type, - std::ratio<1, detail::pow10(num_fractional_digits)>>; - - const auto fractional = d - fmt_duration_cast(d); - const auto subseconds = - std::chrono::treat_as_floating_point< - typename subsecond_precision::rep>::value - ? fractional.count() - : fmt_duration_cast(fractional).count(); - auto n = static_cast>(subseconds); - const int num_digits = detail::count_digits(n); - - int leading_zeroes = (std::max)(0, num_fractional_digits - num_digits); - if (precision < 0) { - FMT_ASSERT(!std::is_floating_point::value, ""); - if (std::ratio_less::value) { - *out++ = '.'; - out = detail::fill_n(out, leading_zeroes, '0'); - out = format_decimal(out, n, num_digits).end; - } - } else if (precision > 0) { - *out++ = '.'; - leading_zeroes = (std::min)(leading_zeroes, precision); - int remaining = precision - leading_zeroes; - out = detail::fill_n(out, leading_zeroes, '0'); - if (remaining < num_digits) { - int num_truncated_digits = num_digits - remaining; - n /= to_unsigned(detail::pow10(to_unsigned(num_truncated_digits))); - if (n) { - out = format_decimal(out, n, remaining).end; - } - return; - } - if (n) { - out = format_decimal(out, n, num_digits).end; - remaining -= num_digits; - } - out = detail::fill_n(out, remaining, '0'); - } -} - -// Format subseconds which are given as a floating point type with an -// appropriate number of digits. We cannot pass the Duration here, as we -// explicitly need to pass the Rep value in the chrono_formatter. -template -void write_floating_seconds(memory_buffer& buf, Duration duration, - int num_fractional_digits = -1) { - using rep = typename Duration::rep; - FMT_ASSERT(std::is_floating_point::value, ""); - - auto val = duration.count(); - - if (num_fractional_digits < 0) { - // For `std::round` with fallback to `round`: - // On some toolchains `std::round` is not available (e.g. GCC 6). - using namespace std; - num_fractional_digits = - count_fractional_digits::value; - if (num_fractional_digits < 6 && static_cast(round(val)) != val) - num_fractional_digits = 6; - } - - fmt::format_to(std::back_inserter(buf), FMT_STRING("{:.{}f}"), - std::fmod(val * static_cast(Duration::period::num) / - static_cast(Duration::period::den), - static_cast(60)), - num_fractional_digits); -} - -template -class tm_writer { - private: - static constexpr int days_per_week = 7; - - const std::locale& loc_; - const bool is_classic_; - OutputIt out_; - const Duration* subsecs_; - const std::tm& tm_; - - auto tm_sec() const noexcept -> int { - FMT_ASSERT(tm_.tm_sec >= 0 && tm_.tm_sec <= 61, ""); - return tm_.tm_sec; - } - auto tm_min() const noexcept -> int { - FMT_ASSERT(tm_.tm_min >= 0 && tm_.tm_min <= 59, ""); - return tm_.tm_min; - } - auto tm_hour() const noexcept -> int { - FMT_ASSERT(tm_.tm_hour >= 0 && tm_.tm_hour <= 23, ""); - return tm_.tm_hour; - } - auto tm_mday() const noexcept -> int { - FMT_ASSERT(tm_.tm_mday >= 1 && tm_.tm_mday <= 31, ""); - return tm_.tm_mday; - } - auto tm_mon() const noexcept -> int { - FMT_ASSERT(tm_.tm_mon >= 0 && tm_.tm_mon <= 11, ""); - return tm_.tm_mon; - } - auto tm_year() const noexcept -> long long { return 1900ll + tm_.tm_year; } - auto tm_wday() const noexcept -> int { - FMT_ASSERT(tm_.tm_wday >= 0 && tm_.tm_wday <= 6, ""); - return tm_.tm_wday; - } - auto tm_yday() const noexcept -> int { - FMT_ASSERT(tm_.tm_yday >= 0 && tm_.tm_yday <= 365, ""); - return tm_.tm_yday; - } - - auto tm_hour12() const noexcept -> int { - const auto h = tm_hour(); - const auto z = h < 12 ? h : h - 12; - return z == 0 ? 12 : z; - } - - // POSIX and the C Standard are unclear or inconsistent about what %C and %y - // do if the year is negative or exceeds 9999. Use the convention that %C - // concatenated with %y yields the same output as %Y, and that %Y contains at - // least 4 characters, with more only if necessary. - auto split_year_lower(long long year) const noexcept -> int { - auto l = year % 100; - if (l < 0) l = -l; // l in [0, 99] - return static_cast(l); - } - - // Algorithm: https://en.wikipedia.org/wiki/ISO_week_date. - auto iso_year_weeks(long long curr_year) const noexcept -> int { - const auto prev_year = curr_year - 1; - const auto curr_p = - (curr_year + curr_year / 4 - curr_year / 100 + curr_year / 400) % - days_per_week; - const auto prev_p = - (prev_year + prev_year / 4 - prev_year / 100 + prev_year / 400) % - days_per_week; - return 52 + ((curr_p == 4 || prev_p == 3) ? 1 : 0); - } - auto iso_week_num(int tm_yday, int tm_wday) const noexcept -> int { - return (tm_yday + 11 - (tm_wday == 0 ? days_per_week : tm_wday)) / - days_per_week; - } - auto tm_iso_week_year() const noexcept -> long long { - const auto year = tm_year(); - const auto w = iso_week_num(tm_yday(), tm_wday()); - if (w < 1) return year - 1; - if (w > iso_year_weeks(year)) return year + 1; - return year; - } - auto tm_iso_week_of_year() const noexcept -> int { - const auto year = tm_year(); - const auto w = iso_week_num(tm_yday(), tm_wday()); - if (w < 1) return iso_year_weeks(year - 1); - if (w > iso_year_weeks(year)) return 1; - return w; - } - - void write1(int value) { - *out_++ = static_cast('0' + to_unsigned(value) % 10); - } - void write2(int value) { - const char* d = digits2(to_unsigned(value) % 100); - *out_++ = *d++; - *out_++ = *d; - } - void write2(int value, pad_type pad) { - unsigned int v = to_unsigned(value) % 100; - if (v >= 10) { - const char* d = digits2(v); - *out_++ = *d++; - *out_++ = *d; - } else { - out_ = detail::write_padding(out_, pad); - *out_++ = static_cast('0' + v); - } - } - - void write_year_extended(long long year) { - // At least 4 characters. - int width = 4; - if (year < 0) { - *out_++ = '-'; - year = 0 - year; - --width; - } - uint32_or_64_or_128_t n = to_unsigned(year); - const int num_digits = count_digits(n); - if (width > num_digits) - out_ = detail::fill_n(out_, width - num_digits, '0'); - out_ = format_decimal(out_, n, num_digits).end; - } - void write_year(long long year) { - if (year >= 0 && year < 10000) { - write2(static_cast(year / 100)); - write2(static_cast(year % 100)); - } else { - write_year_extended(year); - } - } - - void write_utc_offset(long offset, numeric_system ns) { - if (offset < 0) { - *out_++ = '-'; - offset = -offset; - } else { - *out_++ = '+'; - } - offset /= 60; - write2(static_cast(offset / 60)); - if (ns != numeric_system::standard) *out_++ = ':'; - write2(static_cast(offset % 60)); - } - template ::value)> - void format_utc_offset_impl(const T& tm, numeric_system ns) { - write_utc_offset(tm.tm_gmtoff, ns); - } - template ::value)> - void format_utc_offset_impl(const T& tm, numeric_system ns) { -#if defined(_WIN32) && defined(_UCRT) -# if FMT_USE_TZSET - tzset_once(); -# endif - long offset = 0; - _get_timezone(&offset); - if (tm.tm_isdst) { - long dstbias = 0; - _get_dstbias(&dstbias); - offset += dstbias; - } - write_utc_offset(-offset, ns); -#else - if (ns == numeric_system::standard) return format_localized('z'); - - // Extract timezone offset from timezone conversion functions. - std::tm gtm = tm; - std::time_t gt = std::mktime(>m); - std::tm ltm = gmtime(gt); - std::time_t lt = std::mktime(<m); - long offset = gt - lt; - write_utc_offset(offset, ns); -#endif - } - - template ::value)> - void format_tz_name_impl(const T& tm) { - if (is_classic_) - out_ = write_tm_str(out_, tm.tm_zone, loc_); - else - format_localized('Z'); - } - template ::value)> - void format_tz_name_impl(const T&) { - format_localized('Z'); - } - - void format_localized(char format, char modifier = 0) { - out_ = write(out_, tm_, loc_, format, modifier); - } - - public: - tm_writer(const std::locale& loc, OutputIt out, const std::tm& tm, - const Duration* subsecs = nullptr) - : loc_(loc), - is_classic_(loc_ == get_classic_locale()), - out_(out), - subsecs_(subsecs), - tm_(tm) {} - - auto out() const -> OutputIt { return out_; } - - FMT_CONSTEXPR void on_text(const Char* begin, const Char* end) { - out_ = copy(begin, end, out_); - } - - void on_abbr_weekday() { - if (is_classic_) - out_ = write(out_, tm_wday_short_name(tm_wday())); - else - format_localized('a'); - } - void on_full_weekday() { - if (is_classic_) - out_ = write(out_, tm_wday_full_name(tm_wday())); - else - format_localized('A'); - } - void on_dec0_weekday(numeric_system ns) { - if (is_classic_ || ns == numeric_system::standard) return write1(tm_wday()); - format_localized('w', 'O'); - } - void on_dec1_weekday(numeric_system ns) { - if (is_classic_ || ns == numeric_system::standard) { - auto wday = tm_wday(); - write1(wday == 0 ? days_per_week : wday); - } else { - format_localized('u', 'O'); - } - } - - void on_abbr_month() { - if (is_classic_) - out_ = write(out_, tm_mon_short_name(tm_mon())); - else - format_localized('b'); - } - void on_full_month() { - if (is_classic_) - out_ = write(out_, tm_mon_full_name(tm_mon())); - else - format_localized('B'); - } - - void on_datetime(numeric_system ns) { - if (is_classic_) { - on_abbr_weekday(); - *out_++ = ' '; - on_abbr_month(); - *out_++ = ' '; - on_day_of_month(numeric_system::standard, pad_type::space); - *out_++ = ' '; - on_iso_time(); - *out_++ = ' '; - on_year(numeric_system::standard); - } else { - format_localized('c', ns == numeric_system::standard ? '\0' : 'E'); - } - } - void on_loc_date(numeric_system ns) { - if (is_classic_) - on_us_date(); - else - format_localized('x', ns == numeric_system::standard ? '\0' : 'E'); - } - void on_loc_time(numeric_system ns) { - if (is_classic_) - on_iso_time(); - else - format_localized('X', ns == numeric_system::standard ? '\0' : 'E'); - } - void on_us_date() { - char buf[8]; - write_digit2_separated(buf, to_unsigned(tm_mon() + 1), - to_unsigned(tm_mday()), - to_unsigned(split_year_lower(tm_year())), '/'); - out_ = copy(std::begin(buf), std::end(buf), out_); - } - void on_iso_date() { - auto year = tm_year(); - char buf[10]; - size_t offset = 0; - if (year >= 0 && year < 10000) { - copy2(buf, digits2(static_cast(year / 100))); - } else { - offset = 4; - write_year_extended(year); - year = 0; - } - write_digit2_separated(buf + 2, static_cast(year % 100), - to_unsigned(tm_mon() + 1), to_unsigned(tm_mday()), - '-'); - out_ = copy(std::begin(buf) + offset, std::end(buf), out_); - } - - void on_utc_offset(numeric_system ns) { format_utc_offset_impl(tm_, ns); } - void on_tz_name() { format_tz_name_impl(tm_); } - - void on_year(numeric_system ns) { - if (is_classic_ || ns == numeric_system::standard) - return write_year(tm_year()); - format_localized('Y', 'E'); - } - void on_short_year(numeric_system ns) { - if (is_classic_ || ns == numeric_system::standard) - return write2(split_year_lower(tm_year())); - format_localized('y', 'O'); - } - void on_offset_year() { - if (is_classic_) return write2(split_year_lower(tm_year())); - format_localized('y', 'E'); - } - - void on_century(numeric_system ns) { - if (is_classic_ || ns == numeric_system::standard) { - auto year = tm_year(); - auto upper = year / 100; - if (year >= -99 && year < 0) { - // Zero upper on negative year. - *out_++ = '-'; - *out_++ = '0'; - } else if (upper >= 0 && upper < 100) { - write2(static_cast(upper)); - } else { - out_ = write(out_, upper); - } - } else { - format_localized('C', 'E'); - } - } - - void on_dec_month(numeric_system ns) { - if (is_classic_ || ns == numeric_system::standard) - return write2(tm_mon() + 1); - format_localized('m', 'O'); - } - - void on_dec0_week_of_year(numeric_system ns, pad_type pad) { - if (is_classic_ || ns == numeric_system::standard) - return write2((tm_yday() + days_per_week - tm_wday()) / days_per_week, - pad); - format_localized('U', 'O'); - } - void on_dec1_week_of_year(numeric_system ns, pad_type pad) { - if (is_classic_ || ns == numeric_system::standard) { - auto wday = tm_wday(); - write2((tm_yday() + days_per_week - - (wday == 0 ? (days_per_week - 1) : (wday - 1))) / - days_per_week, - pad); - } else { - format_localized('W', 'O'); - } - } - void on_iso_week_of_year(numeric_system ns, pad_type pad) { - if (is_classic_ || ns == numeric_system::standard) - return write2(tm_iso_week_of_year(), pad); - format_localized('V', 'O'); - } - - void on_iso_week_based_year() { write_year(tm_iso_week_year()); } - void on_iso_week_based_short_year() { - write2(split_year_lower(tm_iso_week_year())); - } - - void on_day_of_year() { - auto yday = tm_yday() + 1; - write1(yday / 100); - write2(yday % 100); - } - void on_day_of_month(numeric_system ns, pad_type pad) { - if (is_classic_ || ns == numeric_system::standard) - return write2(tm_mday(), pad); - format_localized('d', 'O'); - } - - void on_24_hour(numeric_system ns, pad_type pad) { - if (is_classic_ || ns == numeric_system::standard) - return write2(tm_hour(), pad); - format_localized('H', 'O'); - } - void on_12_hour(numeric_system ns, pad_type pad) { - if (is_classic_ || ns == numeric_system::standard) - return write2(tm_hour12(), pad); - format_localized('I', 'O'); - } - void on_minute(numeric_system ns, pad_type pad) { - if (is_classic_ || ns == numeric_system::standard) - return write2(tm_min(), pad); - format_localized('M', 'O'); - } - - void on_second(numeric_system ns, pad_type pad) { - if (is_classic_ || ns == numeric_system::standard) { - write2(tm_sec(), pad); - if (subsecs_) { - if (std::is_floating_point::value) { - auto buf = memory_buffer(); - write_floating_seconds(buf, *subsecs_); - if (buf.size() > 1) { - // Remove the leading "0", write something like ".123". - out_ = std::copy(buf.begin() + 1, buf.end(), out_); - } - } else { - write_fractional_seconds(out_, *subsecs_); - } - } - } else { - // Currently no formatting of subseconds when a locale is set. - format_localized('S', 'O'); - } - } - - void on_12_hour_time() { - if (is_classic_) { - char buf[8]; - write_digit2_separated(buf, to_unsigned(tm_hour12()), - to_unsigned(tm_min()), to_unsigned(tm_sec()), ':'); - out_ = copy(std::begin(buf), std::end(buf), out_); - *out_++ = ' '; - on_am_pm(); - } else { - format_localized('r'); - } - } - void on_24_hour_time() { - write2(tm_hour()); - *out_++ = ':'; - write2(tm_min()); - } - void on_iso_time() { - on_24_hour_time(); - *out_++ = ':'; - on_second(numeric_system::standard, pad_type::zero); - } - - void on_am_pm() { - if (is_classic_) { - *out_++ = tm_hour() < 12 ? 'A' : 'P'; - *out_++ = 'M'; - } else { - format_localized('p'); - } - } - - // These apply to chrono durations but not tm. - void on_duration_value() {} - void on_duration_unit() {} -}; - -struct chrono_format_checker : null_chrono_spec_handler { - bool has_precision_integral = false; - - FMT_NORETURN void unsupported() { FMT_THROW(format_error("no date")); } - - template - FMT_CONSTEXPR void on_text(const Char*, const Char*) {} - FMT_CONSTEXPR void on_day_of_year() {} - FMT_CONSTEXPR void on_24_hour(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_12_hour(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_minute(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_second(numeric_system, pad_type) {} - FMT_CONSTEXPR void on_12_hour_time() {} - FMT_CONSTEXPR void on_24_hour_time() {} - FMT_CONSTEXPR void on_iso_time() {} - FMT_CONSTEXPR void on_am_pm() {} - FMT_CONSTEXPR void on_duration_value() const { - if (has_precision_integral) { - FMT_THROW(format_error("precision not allowed for this argument type")); - } - } - FMT_CONSTEXPR void on_duration_unit() {} -}; - -template ::value&& has_isfinite::value)> -inline auto isfinite(T) -> bool { - return true; -} - -template ::value)> -inline auto mod(T x, int y) -> T { - return x % static_cast(y); -} -template ::value)> -inline auto mod(T x, int y) -> T { - return std::fmod(x, static_cast(y)); -} - -// If T is an integral type, maps T to its unsigned counterpart, otherwise -// leaves it unchanged (unlike std::make_unsigned). -template ::value> -struct make_unsigned_or_unchanged { - using type = T; -}; - -template struct make_unsigned_or_unchanged { - using type = typename std::make_unsigned::type; -}; - -template ::value)> -inline auto get_milliseconds(std::chrono::duration d) - -> std::chrono::duration { - // this may overflow and/or the result may not fit in the - // target type. -#if FMT_SAFE_DURATION_CAST - using CommonSecondsType = - typename std::common_type::type; - const auto d_as_common = fmt_duration_cast(d); - const auto d_as_whole_seconds = - fmt_duration_cast(d_as_common); - // this conversion should be nonproblematic - const auto diff = d_as_common - d_as_whole_seconds; - const auto ms = - fmt_duration_cast>(diff); - return ms; -#else - auto s = fmt_duration_cast(d); - return fmt_duration_cast(d - s); -#endif -} - -template ::value)> -auto format_duration_value(OutputIt out, Rep val, int) -> OutputIt { - return write(out, val); -} - -template ::value)> -auto format_duration_value(OutputIt out, Rep val, int precision) -> OutputIt { - auto specs = format_specs(); - specs.precision = precision; - specs.type = - precision >= 0 ? presentation_type::fixed : presentation_type::general; - return write(out, val, specs); -} - -template -auto copy_unit(string_view unit, OutputIt out, Char) -> OutputIt { - return std::copy(unit.begin(), unit.end(), out); -} - -template -auto copy_unit(string_view unit, OutputIt out, wchar_t) -> OutputIt { - // This works when wchar_t is UTF-32 because units only contain characters - // that have the same representation in UTF-16 and UTF-32. - utf8_to_utf16 u(unit); - return std::copy(u.c_str(), u.c_str() + u.size(), out); -} - -template -auto format_duration_unit(OutputIt out) -> OutputIt { - if (const char* unit = get_units()) - return copy_unit(string_view(unit), out, Char()); - *out++ = '['; - out = write(out, Period::num); - if (const_check(Period::den != 1)) { - *out++ = '/'; - out = write(out, Period::den); - } - *out++ = ']'; - *out++ = 's'; - return out; -} - -class get_locale { - private: - union { - std::locale locale_; - }; - bool has_locale_ = false; - - public: - get_locale(bool localized, locale_ref loc) : has_locale_(localized) { -#ifndef FMT_STATIC_THOUSANDS_SEPARATOR - if (localized) - ::new (&locale_) std::locale(loc.template get()); -#endif - } - ~get_locale() { - if (has_locale_) locale_.~locale(); - } - operator const std::locale&() const { - return has_locale_ ? locale_ : get_classic_locale(); - } -}; - -template -struct chrono_formatter { - FormatContext& context; - OutputIt out; - int precision; - bool localized = false; - // rep is unsigned to avoid overflow. - using rep = - conditional_t::value && sizeof(Rep) < sizeof(int), - unsigned, typename make_unsigned_or_unchanged::type>; - rep val; - using seconds = std::chrono::duration; - seconds s; - using milliseconds = std::chrono::duration; - bool negative; - - using char_type = typename FormatContext::char_type; - using tm_writer_type = tm_writer; - - chrono_formatter(FormatContext& ctx, OutputIt o, - std::chrono::duration d) - : context(ctx), - out(o), - val(static_cast(d.count())), - negative(false) { - if (d.count() < 0) { - val = 0 - val; - negative = true; - } - - // this may overflow and/or the result may not fit in the - // target type. - // might need checked conversion (rep!=Rep) - s = fmt_duration_cast(std::chrono::duration(val)); - } - - // returns true if nan or inf, writes to out. - auto handle_nan_inf() -> bool { - if (isfinite(val)) { - return false; - } - if (isnan(val)) { - write_nan(); - return true; - } - // must be +-inf - if (val > 0) { - write_pinf(); - } else { - write_ninf(); - } - return true; - } - - auto days() const -> Rep { return static_cast(s.count() / 86400); } - auto hour() const -> Rep { - return static_cast(mod((s.count() / 3600), 24)); - } - - auto hour12() const -> Rep { - Rep hour = static_cast(mod((s.count() / 3600), 12)); - return hour <= 0 ? 12 : hour; - } - - auto minute() const -> Rep { - return static_cast(mod((s.count() / 60), 60)); - } - auto second() const -> Rep { return static_cast(mod(s.count(), 60)); } - - auto time() const -> std::tm { - auto time = std::tm(); - time.tm_hour = to_nonnegative_int(hour(), 24); - time.tm_min = to_nonnegative_int(minute(), 60); - time.tm_sec = to_nonnegative_int(second(), 60); - return time; - } - - void write_sign() { - if (negative) { - *out++ = '-'; - negative = false; - } - } - - void write(Rep value, int width, pad_type pad = pad_type::zero) { - write_sign(); - if (isnan(value)) return write_nan(); - uint32_or_64_or_128_t n = - to_unsigned(to_nonnegative_int(value, max_value())); - int num_digits = detail::count_digits(n); - if (width > num_digits) { - out = detail::write_padding(out, pad, width - num_digits); - } - out = format_decimal(out, n, num_digits).end; - } - - void write_nan() { std::copy_n("nan", 3, out); } - void write_pinf() { std::copy_n("inf", 3, out); } - void write_ninf() { std::copy_n("-inf", 4, out); } - - template - void format_tm(const tm& time, Callback cb, Args... args) { - if (isnan(val)) return write_nan(); - get_locale loc(localized, context.locale()); - auto w = tm_writer_type(loc, out, time); - (w.*cb)(args...); - out = w.out(); - } - - void on_text(const char_type* begin, const char_type* end) { - std::copy(begin, end, out); - } - - // These are not implemented because durations don't have date information. - void on_abbr_weekday() {} - void on_full_weekday() {} - void on_dec0_weekday(numeric_system) {} - void on_dec1_weekday(numeric_system) {} - void on_abbr_month() {} - void on_full_month() {} - void on_datetime(numeric_system) {} - void on_loc_date(numeric_system) {} - void on_loc_time(numeric_system) {} - void on_us_date() {} - void on_iso_date() {} - void on_utc_offset(numeric_system) {} - void on_tz_name() {} - void on_year(numeric_system) {} - void on_short_year(numeric_system) {} - void on_offset_year() {} - void on_century(numeric_system) {} - void on_iso_week_based_year() {} - void on_iso_week_based_short_year() {} - void on_dec_month(numeric_system) {} - void on_dec0_week_of_year(numeric_system, pad_type) {} - void on_dec1_week_of_year(numeric_system, pad_type) {} - void on_iso_week_of_year(numeric_system, pad_type) {} - void on_day_of_month(numeric_system, pad_type) {} - - void on_day_of_year() { - if (handle_nan_inf()) return; - write(days(), 0); - } - - void on_24_hour(numeric_system ns, pad_type pad) { - if (handle_nan_inf()) return; - - if (ns == numeric_system::standard) return write(hour(), 2, pad); - auto time = tm(); - time.tm_hour = to_nonnegative_int(hour(), 24); - format_tm(time, &tm_writer_type::on_24_hour, ns, pad); - } - - void on_12_hour(numeric_system ns, pad_type pad) { - if (handle_nan_inf()) return; - - if (ns == numeric_system::standard) return write(hour12(), 2, pad); - auto time = tm(); - time.tm_hour = to_nonnegative_int(hour12(), 12); - format_tm(time, &tm_writer_type::on_12_hour, ns, pad); - } - - void on_minute(numeric_system ns, pad_type pad) { - if (handle_nan_inf()) return; - - if (ns == numeric_system::standard) return write(minute(), 2, pad); - auto time = tm(); - time.tm_min = to_nonnegative_int(minute(), 60); - format_tm(time, &tm_writer_type::on_minute, ns, pad); - } - - void on_second(numeric_system ns, pad_type pad) { - if (handle_nan_inf()) return; - - if (ns == numeric_system::standard) { - if (std::is_floating_point::value) { - auto buf = memory_buffer(); - write_floating_seconds(buf, std::chrono::duration(val), - precision); - if (negative) *out++ = '-'; - if (buf.size() < 2 || buf[1] == '.') { - out = detail::write_padding(out, pad); - } - out = std::copy(buf.begin(), buf.end(), out); - } else { - write(second(), 2, pad); - write_fractional_seconds( - out, std::chrono::duration(val), precision); - } - return; - } - auto time = tm(); - time.tm_sec = to_nonnegative_int(second(), 60); - format_tm(time, &tm_writer_type::on_second, ns, pad); - } - - void on_12_hour_time() { - if (handle_nan_inf()) return; - format_tm(time(), &tm_writer_type::on_12_hour_time); - } - - void on_24_hour_time() { - if (handle_nan_inf()) { - *out++ = ':'; - handle_nan_inf(); - return; - } - - write(hour(), 2); - *out++ = ':'; - write(minute(), 2); - } - - void on_iso_time() { - on_24_hour_time(); - *out++ = ':'; - if (handle_nan_inf()) return; - on_second(numeric_system::standard, pad_type::zero); - } - - void on_am_pm() { - if (handle_nan_inf()) return; - format_tm(time(), &tm_writer_type::on_am_pm); - } - - void on_duration_value() { - if (handle_nan_inf()) return; - write_sign(); - out = format_duration_value(out, val, precision); - } - - void on_duration_unit() { - out = format_duration_unit(out); - } -}; - -} // namespace detail - -#if defined(__cpp_lib_chrono) && __cpp_lib_chrono >= 201907 -using weekday = std::chrono::weekday; -using day = std::chrono::day; -using month = std::chrono::month; -using year = std::chrono::year; -using year_month_day = std::chrono::year_month_day; -#else -// A fallback version of weekday. -class weekday { - private: - unsigned char value_; - - public: - weekday() = default; - constexpr explicit weekday(unsigned wd) noexcept - : value_(static_cast(wd != 7 ? wd : 0)) {} - constexpr auto c_encoding() const noexcept -> unsigned { return value_; } -}; - -class day { - private: - unsigned char value_; - - public: - day() = default; - constexpr explicit day(unsigned d) noexcept - : value_(static_cast(d)) {} - constexpr explicit operator unsigned() const noexcept { return value_; } -}; - -class month { - private: - unsigned char value_; - - public: - month() = default; - constexpr explicit month(unsigned m) noexcept - : value_(static_cast(m)) {} - constexpr explicit operator unsigned() const noexcept { return value_; } -}; - -class year { - private: - int value_; - - public: - year() = default; - constexpr explicit year(int y) noexcept : value_(y) {} - constexpr explicit operator int() const noexcept { return value_; } -}; - -class year_month_day { - private: - fmt::year year_; - fmt::month month_; - fmt::day day_; - - public: - year_month_day() = default; - constexpr year_month_day(const year& y, const month& m, const day& d) noexcept - : year_(y), month_(m), day_(d) {} - constexpr auto year() const noexcept -> fmt::year { return year_; } - constexpr auto month() const noexcept -> fmt::month { return month_; } - constexpr auto day() const noexcept -> fmt::day { return day_; } -}; -#endif - -template -struct formatter : private formatter { - private: - bool localized_ = false; - bool use_tm_formatter_ = false; - - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - auto it = ctx.begin(), end = ctx.end(); - if (it != end && *it == 'L') { - ++it; - localized_ = true; - return it; - } - use_tm_formatter_ = it != end && *it != '}'; - return use_tm_formatter_ ? formatter::parse(ctx) : it; - } - - template - auto format(weekday wd, FormatContext& ctx) const -> decltype(ctx.out()) { - auto time = std::tm(); - time.tm_wday = static_cast(wd.c_encoding()); - if (use_tm_formatter_) return formatter::format(time, ctx); - detail::get_locale loc(localized_, ctx.locale()); - auto w = detail::tm_writer(loc, ctx.out(), time); - w.on_abbr_weekday(); - return w.out(); - } -}; - -template -struct formatter : private formatter { - private: - bool use_tm_formatter_ = false; - - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - auto it = ctx.begin(), end = ctx.end(); - use_tm_formatter_ = it != end && *it != '}'; - return use_tm_formatter_ ? formatter::parse(ctx) : it; - } - - template - auto format(day d, FormatContext& ctx) const -> decltype(ctx.out()) { - auto time = std::tm(); - time.tm_mday = static_cast(static_cast(d)); - if (use_tm_formatter_) return formatter::format(time, ctx); - detail::get_locale loc(false, ctx.locale()); - auto w = detail::tm_writer(loc, ctx.out(), time); - w.on_day_of_month(detail::numeric_system::standard, detail::pad_type::zero); - return w.out(); - } -}; - -template -struct formatter : private formatter { - private: - bool localized_ = false; - bool use_tm_formatter_ = false; - - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - auto it = ctx.begin(), end = ctx.end(); - if (it != end && *it == 'L') { - ++it; - localized_ = true; - return it; - } - use_tm_formatter_ = it != end && *it != '}'; - return use_tm_formatter_ ? formatter::parse(ctx) : it; - } - - template - auto format(month m, FormatContext& ctx) const -> decltype(ctx.out()) { - auto time = std::tm(); - time.tm_mon = static_cast(static_cast(m)) - 1; - if (use_tm_formatter_) return formatter::format(time, ctx); - detail::get_locale loc(localized_, ctx.locale()); - auto w = detail::tm_writer(loc, ctx.out(), time); - w.on_abbr_month(); - return w.out(); - } -}; - -template -struct formatter : private formatter { - private: - bool use_tm_formatter_ = false; - - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - auto it = ctx.begin(), end = ctx.end(); - use_tm_formatter_ = it != end && *it != '}'; - return use_tm_formatter_ ? formatter::parse(ctx) : it; - } - - template - auto format(year y, FormatContext& ctx) const -> decltype(ctx.out()) { - auto time = std::tm(); - time.tm_year = static_cast(y) - 1900; - if (use_tm_formatter_) return formatter::format(time, ctx); - detail::get_locale loc(false, ctx.locale()); - auto w = detail::tm_writer(loc, ctx.out(), time); - w.on_year(detail::numeric_system::standard); - return w.out(); - } -}; - -template -struct formatter : private formatter { - private: - bool use_tm_formatter_ = false; - - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - auto it = ctx.begin(), end = ctx.end(); - use_tm_formatter_ = it != end && *it != '}'; - return use_tm_formatter_ ? formatter::parse(ctx) : it; - } - - template - auto format(year_month_day val, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto time = std::tm(); - time.tm_year = static_cast(val.year()) - 1900; - time.tm_mon = static_cast(static_cast(val.month())) - 1; - time.tm_mday = static_cast(static_cast(val.day())); - if (use_tm_formatter_) return formatter::format(time, ctx); - detail::get_locale loc(true, ctx.locale()); - auto w = detail::tm_writer(loc, ctx.out(), time); - w.on_iso_date(); - return w.out(); - } -}; - -template -struct formatter, Char> { - private: - format_specs specs_; - detail::arg_ref width_ref_; - detail::arg_ref precision_ref_; - bool localized_ = false; - basic_string_view format_str_; - - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - auto it = ctx.begin(), end = ctx.end(); - if (it == end || *it == '}') return it; - - it = detail::parse_align(it, end, specs_); - if (it == end) return it; - - it = detail::parse_dynamic_spec(it, end, specs_.width, width_ref_, ctx); - if (it == end) return it; - - auto checker = detail::chrono_format_checker(); - if (*it == '.') { - checker.has_precision_integral = !std::is_floating_point::value; - it = detail::parse_precision(it, end, specs_.precision, precision_ref_, - ctx); - } - if (it != end && *it == 'L') { - localized_ = true; - ++it; - } - end = detail::parse_chrono_format(it, end, checker); - format_str_ = {it, detail::to_unsigned(end - it)}; - return end; - } - - template - auto format(std::chrono::duration d, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto specs = specs_; - auto precision = specs.precision; - specs.precision = -1; - auto begin = format_str_.begin(), end = format_str_.end(); - // As a possible future optimization, we could avoid extra copying if width - // is not specified. - auto buf = basic_memory_buffer(); - auto out = std::back_inserter(buf); - detail::handle_dynamic_spec(specs.width, width_ref_, - ctx); - detail::handle_dynamic_spec(precision, - precision_ref_, ctx); - if (begin == end || *begin == '}') { - out = detail::format_duration_value(out, d.count(), precision); - detail::format_duration_unit(out); - } else { - using chrono_formatter = - detail::chrono_formatter; - auto f = chrono_formatter(ctx, out, d); - f.precision = precision; - f.localized = localized_; - detail::parse_chrono_format(begin, end, f); - } - return detail::write( - ctx.out(), basic_string_view(buf.data(), buf.size()), specs); - } -}; - -template -struct formatter, - Char> : formatter { - FMT_CONSTEXPR formatter() { - this->format_str_ = detail::string_literal{}; - } - - template - auto format(std::chrono::time_point val, - FormatContext& ctx) const -> decltype(ctx.out()) { - std::tm tm = gmtime(val); - using period = typename Duration::period; - if (detail::const_check( - period::num == 1 && period::den == 1 && - !std::is_floating_point::value)) { - return formatter::format(tm, ctx); - } - Duration epoch = val.time_since_epoch(); - Duration subsecs = detail::fmt_duration_cast( - epoch - detail::fmt_duration_cast(epoch)); - if (subsecs.count() < 0) { - auto second = - detail::fmt_duration_cast(std::chrono::seconds(1)); - if (tm.tm_sec != 0) - --tm.tm_sec; - else - tm = gmtime(val - second); - subsecs += detail::fmt_duration_cast(std::chrono::seconds(1)); - } - return formatter::do_format(tm, ctx, &subsecs); - } -}; - -#if FMT_USE_LOCAL_TIME -template -struct formatter, Char> - : formatter { - FMT_CONSTEXPR formatter() { - this->format_str_ = detail::string_literal{}; - } - - template - auto format(std::chrono::local_time val, FormatContext& ctx) const - -> decltype(ctx.out()) { - using period = typename Duration::period; - if (period::num != 1 || period::den != 1 || - std::is_floating_point::value) { - const auto epoch = val.time_since_epoch(); - const auto subsecs = detail::fmt_duration_cast( - epoch - detail::fmt_duration_cast(epoch)); - - return formatter::do_format(localtime(val), ctx, &subsecs); - } - - return formatter::format(localtime(val), ctx); - } -}; -#endif - -#if FMT_USE_UTC_TIME -template -struct formatter, - Char> - : formatter, - Char> { - template - auto format(std::chrono::time_point val, - FormatContext& ctx) const -> decltype(ctx.out()) { - return formatter< - std::chrono::time_point, - Char>::format(std::chrono::utc_clock::to_sys(val), ctx); - } -}; -#endif - -template struct formatter { - private: - format_specs specs_; - detail::arg_ref width_ref_; - - protected: - basic_string_view format_str_; - - template - auto do_format(const std::tm& tm, FormatContext& ctx, - const Duration* subsecs) const -> decltype(ctx.out()) { - auto specs = specs_; - auto buf = basic_memory_buffer(); - auto out = std::back_inserter(buf); - detail::handle_dynamic_spec(specs.width, width_ref_, - ctx); - - auto loc_ref = ctx.locale(); - detail::get_locale loc(static_cast(loc_ref), loc_ref); - auto w = - detail::tm_writer(loc, out, tm, subsecs); - detail::parse_chrono_format(format_str_.begin(), format_str_.end(), w); - return detail::write( - ctx.out(), basic_string_view(buf.data(), buf.size()), specs); - } - - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - auto it = ctx.begin(), end = ctx.end(); - if (it == end || *it == '}') return it; - - it = detail::parse_align(it, end, specs_); - if (it == end) return it; - - it = detail::parse_dynamic_spec(it, end, specs_.width, width_ref_, ctx); - if (it == end) return it; - - end = detail::parse_chrono_format(it, end, detail::tm_format_checker()); - // Replace the default format_str only if the new spec is not empty. - if (end != it) format_str_ = {it, detail::to_unsigned(end - it)}; - return end; - } - - template - auto format(const std::tm& tm, FormatContext& ctx) const - -> decltype(ctx.out()) { - return do_format(tm, ctx, nullptr); - } -}; - -FMT_END_EXPORT -FMT_END_NAMESPACE - -#endif // FMT_CHRONO_H_ diff --git a/tt_metal/third_party/fmt/fmt/color.h b/tt_metal/third_party/fmt/fmt/color.h deleted file mode 100644 index f0e9dd94ef3..00000000000 --- a/tt_metal/third_party/fmt/fmt/color.h +++ /dev/null @@ -1,612 +0,0 @@ -// Formatting library for C++ - color support -// -// Copyright (c) 2018 - present, Victor Zverovich and fmt contributors -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_COLOR_H_ -#define FMT_COLOR_H_ - -#include "format.h" - -FMT_BEGIN_NAMESPACE -FMT_BEGIN_EXPORT - -enum class color : uint32_t { - alice_blue = 0xF0F8FF, // rgb(240,248,255) - antique_white = 0xFAEBD7, // rgb(250,235,215) - aqua = 0x00FFFF, // rgb(0,255,255) - aquamarine = 0x7FFFD4, // rgb(127,255,212) - azure = 0xF0FFFF, // rgb(240,255,255) - beige = 0xF5F5DC, // rgb(245,245,220) - bisque = 0xFFE4C4, // rgb(255,228,196) - black = 0x000000, // rgb(0,0,0) - blanched_almond = 0xFFEBCD, // rgb(255,235,205) - blue = 0x0000FF, // rgb(0,0,255) - blue_violet = 0x8A2BE2, // rgb(138,43,226) - brown = 0xA52A2A, // rgb(165,42,42) - burly_wood = 0xDEB887, // rgb(222,184,135) - cadet_blue = 0x5F9EA0, // rgb(95,158,160) - chartreuse = 0x7FFF00, // rgb(127,255,0) - chocolate = 0xD2691E, // rgb(210,105,30) - coral = 0xFF7F50, // rgb(255,127,80) - cornflower_blue = 0x6495ED, // rgb(100,149,237) - cornsilk = 0xFFF8DC, // rgb(255,248,220) - crimson = 0xDC143C, // rgb(220,20,60) - cyan = 0x00FFFF, // rgb(0,255,255) - dark_blue = 0x00008B, // rgb(0,0,139) - dark_cyan = 0x008B8B, // rgb(0,139,139) - dark_golden_rod = 0xB8860B, // rgb(184,134,11) - dark_gray = 0xA9A9A9, // rgb(169,169,169) - dark_green = 0x006400, // rgb(0,100,0) - dark_khaki = 0xBDB76B, // rgb(189,183,107) - dark_magenta = 0x8B008B, // rgb(139,0,139) - dark_olive_green = 0x556B2F, // rgb(85,107,47) - dark_orange = 0xFF8C00, // rgb(255,140,0) - dark_orchid = 0x9932CC, // rgb(153,50,204) - dark_red = 0x8B0000, // rgb(139,0,0) - dark_salmon = 0xE9967A, // rgb(233,150,122) - dark_sea_green = 0x8FBC8F, // rgb(143,188,143) - dark_slate_blue = 0x483D8B, // rgb(72,61,139) - dark_slate_gray = 0x2F4F4F, // rgb(47,79,79) - dark_turquoise = 0x00CED1, // rgb(0,206,209) - dark_violet = 0x9400D3, // rgb(148,0,211) - deep_pink = 0xFF1493, // rgb(255,20,147) - deep_sky_blue = 0x00BFFF, // rgb(0,191,255) - dim_gray = 0x696969, // rgb(105,105,105) - dodger_blue = 0x1E90FF, // rgb(30,144,255) - fire_brick = 0xB22222, // rgb(178,34,34) - floral_white = 0xFFFAF0, // rgb(255,250,240) - forest_green = 0x228B22, // rgb(34,139,34) - fuchsia = 0xFF00FF, // rgb(255,0,255) - gainsboro = 0xDCDCDC, // rgb(220,220,220) - ghost_white = 0xF8F8FF, // rgb(248,248,255) - gold = 0xFFD700, // rgb(255,215,0) - golden_rod = 0xDAA520, // rgb(218,165,32) - gray = 0x808080, // rgb(128,128,128) - green = 0x008000, // rgb(0,128,0) - green_yellow = 0xADFF2F, // rgb(173,255,47) - honey_dew = 0xF0FFF0, // rgb(240,255,240) - hot_pink = 0xFF69B4, // rgb(255,105,180) - indian_red = 0xCD5C5C, // rgb(205,92,92) - indigo = 0x4B0082, // rgb(75,0,130) - ivory = 0xFFFFF0, // rgb(255,255,240) - khaki = 0xF0E68C, // rgb(240,230,140) - lavender = 0xE6E6FA, // rgb(230,230,250) - lavender_blush = 0xFFF0F5, // rgb(255,240,245) - lawn_green = 0x7CFC00, // rgb(124,252,0) - lemon_chiffon = 0xFFFACD, // rgb(255,250,205) - light_blue = 0xADD8E6, // rgb(173,216,230) - light_coral = 0xF08080, // rgb(240,128,128) - light_cyan = 0xE0FFFF, // rgb(224,255,255) - light_golden_rod_yellow = 0xFAFAD2, // rgb(250,250,210) - light_gray = 0xD3D3D3, // rgb(211,211,211) - light_green = 0x90EE90, // rgb(144,238,144) - light_pink = 0xFFB6C1, // rgb(255,182,193) - light_salmon = 0xFFA07A, // rgb(255,160,122) - light_sea_green = 0x20B2AA, // rgb(32,178,170) - light_sky_blue = 0x87CEFA, // rgb(135,206,250) - light_slate_gray = 0x778899, // rgb(119,136,153) - light_steel_blue = 0xB0C4DE, // rgb(176,196,222) - light_yellow = 0xFFFFE0, // rgb(255,255,224) - lime = 0x00FF00, // rgb(0,255,0) - lime_green = 0x32CD32, // rgb(50,205,50) - linen = 0xFAF0E6, // rgb(250,240,230) - magenta = 0xFF00FF, // rgb(255,0,255) - maroon = 0x800000, // rgb(128,0,0) - medium_aquamarine = 0x66CDAA, // rgb(102,205,170) - medium_blue = 0x0000CD, // rgb(0,0,205) - medium_orchid = 0xBA55D3, // rgb(186,85,211) - medium_purple = 0x9370DB, // rgb(147,112,219) - medium_sea_green = 0x3CB371, // rgb(60,179,113) - medium_slate_blue = 0x7B68EE, // rgb(123,104,238) - medium_spring_green = 0x00FA9A, // rgb(0,250,154) - medium_turquoise = 0x48D1CC, // rgb(72,209,204) - medium_violet_red = 0xC71585, // rgb(199,21,133) - midnight_blue = 0x191970, // rgb(25,25,112) - mint_cream = 0xF5FFFA, // rgb(245,255,250) - misty_rose = 0xFFE4E1, // rgb(255,228,225) - moccasin = 0xFFE4B5, // rgb(255,228,181) - navajo_white = 0xFFDEAD, // rgb(255,222,173) - navy = 0x000080, // rgb(0,0,128) - old_lace = 0xFDF5E6, // rgb(253,245,230) - olive = 0x808000, // rgb(128,128,0) - olive_drab = 0x6B8E23, // rgb(107,142,35) - orange = 0xFFA500, // rgb(255,165,0) - orange_red = 0xFF4500, // rgb(255,69,0) - orchid = 0xDA70D6, // rgb(218,112,214) - pale_golden_rod = 0xEEE8AA, // rgb(238,232,170) - pale_green = 0x98FB98, // rgb(152,251,152) - pale_turquoise = 0xAFEEEE, // rgb(175,238,238) - pale_violet_red = 0xDB7093, // rgb(219,112,147) - papaya_whip = 0xFFEFD5, // rgb(255,239,213) - peach_puff = 0xFFDAB9, // rgb(255,218,185) - peru = 0xCD853F, // rgb(205,133,63) - pink = 0xFFC0CB, // rgb(255,192,203) - plum = 0xDDA0DD, // rgb(221,160,221) - powder_blue = 0xB0E0E6, // rgb(176,224,230) - purple = 0x800080, // rgb(128,0,128) - rebecca_purple = 0x663399, // rgb(102,51,153) - red = 0xFF0000, // rgb(255,0,0) - rosy_brown = 0xBC8F8F, // rgb(188,143,143) - royal_blue = 0x4169E1, // rgb(65,105,225) - saddle_brown = 0x8B4513, // rgb(139,69,19) - salmon = 0xFA8072, // rgb(250,128,114) - sandy_brown = 0xF4A460, // rgb(244,164,96) - sea_green = 0x2E8B57, // rgb(46,139,87) - sea_shell = 0xFFF5EE, // rgb(255,245,238) - sienna = 0xA0522D, // rgb(160,82,45) - silver = 0xC0C0C0, // rgb(192,192,192) - sky_blue = 0x87CEEB, // rgb(135,206,235) - slate_blue = 0x6A5ACD, // rgb(106,90,205) - slate_gray = 0x708090, // rgb(112,128,144) - snow = 0xFFFAFA, // rgb(255,250,250) - spring_green = 0x00FF7F, // rgb(0,255,127) - steel_blue = 0x4682B4, // rgb(70,130,180) - tan = 0xD2B48C, // rgb(210,180,140) - teal = 0x008080, // rgb(0,128,128) - thistle = 0xD8BFD8, // rgb(216,191,216) - tomato = 0xFF6347, // rgb(255,99,71) - turquoise = 0x40E0D0, // rgb(64,224,208) - violet = 0xEE82EE, // rgb(238,130,238) - wheat = 0xF5DEB3, // rgb(245,222,179) - white = 0xFFFFFF, // rgb(255,255,255) - white_smoke = 0xF5F5F5, // rgb(245,245,245) - yellow = 0xFFFF00, // rgb(255,255,0) - yellow_green = 0x9ACD32 // rgb(154,205,50) -}; // enum class color - -enum class terminal_color : uint8_t { - black = 30, - red, - green, - yellow, - blue, - magenta, - cyan, - white, - bright_black = 90, - bright_red, - bright_green, - bright_yellow, - bright_blue, - bright_magenta, - bright_cyan, - bright_white -}; - -enum class emphasis : uint8_t { - bold = 1, - faint = 1 << 1, - italic = 1 << 2, - underline = 1 << 3, - blink = 1 << 4, - reverse = 1 << 5, - conceal = 1 << 6, - strikethrough = 1 << 7, -}; - -// rgb is a struct for red, green and blue colors. -// Using the name "rgb" makes some editors show the color in a tooltip. -struct rgb { - FMT_CONSTEXPR rgb() : r(0), g(0), b(0) {} - FMT_CONSTEXPR rgb(uint8_t r_, uint8_t g_, uint8_t b_) : r(r_), g(g_), b(b_) {} - FMT_CONSTEXPR rgb(uint32_t hex) - : r((hex >> 16) & 0xFF), g((hex >> 8) & 0xFF), b(hex & 0xFF) {} - FMT_CONSTEXPR rgb(color hex) - : r((uint32_t(hex) >> 16) & 0xFF), - g((uint32_t(hex) >> 8) & 0xFF), - b(uint32_t(hex) & 0xFF) {} - uint8_t r; - uint8_t g; - uint8_t b; -}; - -namespace detail { - -// color is a struct of either a rgb color or a terminal color. -struct color_type { - FMT_CONSTEXPR color_type() noexcept : is_rgb(), value{} {} - FMT_CONSTEXPR color_type(color rgb_color) noexcept : is_rgb(true), value{} { - value.rgb_color = static_cast(rgb_color); - } - FMT_CONSTEXPR color_type(rgb rgb_color) noexcept : is_rgb(true), value{} { - value.rgb_color = (static_cast(rgb_color.r) << 16) | - (static_cast(rgb_color.g) << 8) | rgb_color.b; - } - FMT_CONSTEXPR color_type(terminal_color term_color) noexcept - : is_rgb(), value{} { - value.term_color = static_cast(term_color); - } - bool is_rgb; - union color_union { - uint8_t term_color; - uint32_t rgb_color; - } value; -}; -} // namespace detail - -/// A text style consisting of foreground and background colors and emphasis. -class text_style { - public: - FMT_CONSTEXPR text_style(emphasis em = emphasis()) noexcept - : set_foreground_color(), set_background_color(), ems(em) {} - - FMT_CONSTEXPR auto operator|=(const text_style& rhs) -> text_style& { - if (!set_foreground_color) { - set_foreground_color = rhs.set_foreground_color; - foreground_color = rhs.foreground_color; - } else if (rhs.set_foreground_color) { - if (!foreground_color.is_rgb || !rhs.foreground_color.is_rgb) - report_error("can't OR a terminal color"); - foreground_color.value.rgb_color |= rhs.foreground_color.value.rgb_color; - } - - if (!set_background_color) { - set_background_color = rhs.set_background_color; - background_color = rhs.background_color; - } else if (rhs.set_background_color) { - if (!background_color.is_rgb || !rhs.background_color.is_rgb) - report_error("can't OR a terminal color"); - background_color.value.rgb_color |= rhs.background_color.value.rgb_color; - } - - ems = static_cast(static_cast(ems) | - static_cast(rhs.ems)); - return *this; - } - - friend FMT_CONSTEXPR auto operator|(text_style lhs, const text_style& rhs) - -> text_style { - return lhs |= rhs; - } - - FMT_CONSTEXPR auto has_foreground() const noexcept -> bool { - return set_foreground_color; - } - FMT_CONSTEXPR auto has_background() const noexcept -> bool { - return set_background_color; - } - FMT_CONSTEXPR auto has_emphasis() const noexcept -> bool { - return static_cast(ems) != 0; - } - FMT_CONSTEXPR auto get_foreground() const noexcept -> detail::color_type { - FMT_ASSERT(has_foreground(), "no foreground specified for this style"); - return foreground_color; - } - FMT_CONSTEXPR auto get_background() const noexcept -> detail::color_type { - FMT_ASSERT(has_background(), "no background specified for this style"); - return background_color; - } - FMT_CONSTEXPR auto get_emphasis() const noexcept -> emphasis { - FMT_ASSERT(has_emphasis(), "no emphasis specified for this style"); - return ems; - } - - private: - FMT_CONSTEXPR text_style(bool is_foreground, - detail::color_type text_color) noexcept - : set_foreground_color(), set_background_color(), ems() { - if (is_foreground) { - foreground_color = text_color; - set_foreground_color = true; - } else { - background_color = text_color; - set_background_color = true; - } - } - - friend FMT_CONSTEXPR auto fg(detail::color_type foreground) noexcept - -> text_style; - - friend FMT_CONSTEXPR auto bg(detail::color_type background) noexcept - -> text_style; - - detail::color_type foreground_color; - detail::color_type background_color; - bool set_foreground_color; - bool set_background_color; - emphasis ems; -}; - -/// Creates a text style from the foreground (text) color. -FMT_CONSTEXPR inline auto fg(detail::color_type foreground) noexcept - -> text_style { - return text_style(true, foreground); -} - -/// Creates a text style from the background color. -FMT_CONSTEXPR inline auto bg(detail::color_type background) noexcept - -> text_style { - return text_style(false, background); -} - -FMT_CONSTEXPR inline auto operator|(emphasis lhs, emphasis rhs) noexcept - -> text_style { - return text_style(lhs) | rhs; -} - -namespace detail { - -template struct ansi_color_escape { - FMT_CONSTEXPR ansi_color_escape(detail::color_type text_color, - const char* esc) noexcept { - // If we have a terminal color, we need to output another escape code - // sequence. - if (!text_color.is_rgb) { - bool is_background = esc == string_view("\x1b[48;2;"); - uint32_t value = text_color.value.term_color; - // Background ASCII codes are the same as the foreground ones but with - // 10 more. - if (is_background) value += 10u; - - size_t index = 0; - buffer[index++] = static_cast('\x1b'); - buffer[index++] = static_cast('['); - - if (value >= 100u) { - buffer[index++] = static_cast('1'); - value %= 100u; - } - buffer[index++] = static_cast('0' + value / 10u); - buffer[index++] = static_cast('0' + value % 10u); - - buffer[index++] = static_cast('m'); - buffer[index++] = static_cast('\0'); - return; - } - - for (int i = 0; i < 7; i++) { - buffer[i] = static_cast(esc[i]); - } - rgb color(text_color.value.rgb_color); - to_esc(color.r, buffer + 7, ';'); - to_esc(color.g, buffer + 11, ';'); - to_esc(color.b, buffer + 15, 'm'); - buffer[19] = static_cast(0); - } - FMT_CONSTEXPR ansi_color_escape(emphasis em) noexcept { - uint8_t em_codes[num_emphases] = {}; - if (has_emphasis(em, emphasis::bold)) em_codes[0] = 1; - if (has_emphasis(em, emphasis::faint)) em_codes[1] = 2; - if (has_emphasis(em, emphasis::italic)) em_codes[2] = 3; - if (has_emphasis(em, emphasis::underline)) em_codes[3] = 4; - if (has_emphasis(em, emphasis::blink)) em_codes[4] = 5; - if (has_emphasis(em, emphasis::reverse)) em_codes[5] = 7; - if (has_emphasis(em, emphasis::conceal)) em_codes[6] = 8; - if (has_emphasis(em, emphasis::strikethrough)) em_codes[7] = 9; - - size_t index = 0; - for (size_t i = 0; i < num_emphases; ++i) { - if (!em_codes[i]) continue; - buffer[index++] = static_cast('\x1b'); - buffer[index++] = static_cast('['); - buffer[index++] = static_cast('0' + em_codes[i]); - buffer[index++] = static_cast('m'); - } - buffer[index++] = static_cast(0); - } - FMT_CONSTEXPR operator const Char*() const noexcept { return buffer; } - - FMT_CONSTEXPR auto begin() const noexcept -> const Char* { return buffer; } - FMT_CONSTEXPR20 auto end() const noexcept -> const Char* { - return buffer + basic_string_view(buffer).size(); - } - - private: - static constexpr size_t num_emphases = 8; - Char buffer[7u + 3u * num_emphases + 1u]; - - static FMT_CONSTEXPR void to_esc(uint8_t c, Char* out, - char delimiter) noexcept { - out[0] = static_cast('0' + c / 100); - out[1] = static_cast('0' + c / 10 % 10); - out[2] = static_cast('0' + c % 10); - out[3] = static_cast(delimiter); - } - static FMT_CONSTEXPR auto has_emphasis(emphasis em, emphasis mask) noexcept - -> bool { - return static_cast(em) & static_cast(mask); - } -}; - -template -FMT_CONSTEXPR auto make_foreground_color(detail::color_type foreground) noexcept - -> ansi_color_escape { - return ansi_color_escape(foreground, "\x1b[38;2;"); -} - -template -FMT_CONSTEXPR auto make_background_color(detail::color_type background) noexcept - -> ansi_color_escape { - return ansi_color_escape(background, "\x1b[48;2;"); -} - -template -FMT_CONSTEXPR auto make_emphasis(emphasis em) noexcept - -> ansi_color_escape { - return ansi_color_escape(em); -} - -template inline void reset_color(buffer& buffer) { - auto reset_color = string_view("\x1b[0m"); - buffer.append(reset_color.begin(), reset_color.end()); -} - -template struct styled_arg : detail::view { - const T& value; - text_style style; - styled_arg(const T& v, text_style s) : value(v), style(s) {} -}; - -template -void vformat_to( - buffer& buf, const text_style& ts, basic_string_view format_str, - basic_format_args>> args) { - bool has_style = false; - if (ts.has_emphasis()) { - has_style = true; - auto emphasis = detail::make_emphasis(ts.get_emphasis()); - buf.append(emphasis.begin(), emphasis.end()); - } - if (ts.has_foreground()) { - has_style = true; - auto foreground = detail::make_foreground_color(ts.get_foreground()); - buf.append(foreground.begin(), foreground.end()); - } - if (ts.has_background()) { - has_style = true; - auto background = detail::make_background_color(ts.get_background()); - buf.append(background.begin(), background.end()); - } - detail::vformat_to(buf, format_str, args, {}); - if (has_style) detail::reset_color(buf); -} - -} // namespace detail - -inline void vprint(FILE* f, const text_style& ts, string_view fmt, - format_args args) { - auto buf = memory_buffer(); - detail::vformat_to(buf, ts, fmt, args); - print(f, FMT_STRING("{}"), string_view(buf.begin(), buf.size())); -} - -/** - * Formats a string and prints it to the specified file stream using ANSI - * escape sequences to specify text formatting. - * - * **Example**: - * - * fmt::print(fmt::emphasis::bold | fg(fmt::color::red), - * "Elapsed time: {0:.2f} seconds", 1.23); - */ -template -void print(FILE* f, const text_style& ts, format_string fmt, - T&&... args) { - vprint(f, ts, fmt, fmt::make_format_args(args...)); -} - -/** - * Formats a string and prints it to stdout using ANSI escape sequences to - * specify text formatting. - * - * **Example**: - * - * fmt::print(fmt::emphasis::bold | fg(fmt::color::red), - * "Elapsed time: {0:.2f} seconds", 1.23); - */ -template -void print(const text_style& ts, format_string fmt, T&&... args) { - return print(stdout, ts, fmt, std::forward(args)...); -} - -inline auto vformat(const text_style& ts, string_view fmt, format_args args) - -> std::string { - auto buf = memory_buffer(); - detail::vformat_to(buf, ts, fmt, args); - return fmt::to_string(buf); -} - -/** - * Formats arguments and returns the result as a string using ANSI escape - * sequences to specify text formatting. - * - * **Example**: - * - * ``` - * #include - * std::string message = fmt::format(fmt::emphasis::bold | fg(fmt::color::red), - * "The answer is {}", 42); - * ``` - */ -template -inline auto format(const text_style& ts, format_string fmt, T&&... args) - -> std::string { - return fmt::vformat(ts, fmt, fmt::make_format_args(args...)); -} - -/// Formats a string with the given text_style and writes the output to `out`. -template ::value)> -auto vformat_to(OutputIt out, const text_style& ts, string_view fmt, - format_args args) -> OutputIt { - auto&& buf = detail::get_buffer(out); - detail::vformat_to(buf, ts, fmt, args); - return detail::get_iterator(buf, out); -} - -/** - * Formats arguments with the given text style, writes the result to the output - * iterator `out` and returns the iterator past the end of the output range. - * - * **Example**: - * - * std::vector out; - * fmt::format_to(std::back_inserter(out), - * fmt::emphasis::bold | fg(fmt::color::red), "{}", 42); - */ -template ::value)> -inline auto format_to(OutputIt out, const text_style& ts, - format_string fmt, T&&... args) -> OutputIt { - return vformat_to(out, ts, fmt, fmt::make_format_args(args...)); -} - -template -struct formatter, Char> : formatter { - template - auto format(const detail::styled_arg& arg, FormatContext& ctx) const - -> decltype(ctx.out()) { - const auto& ts = arg.style; - const auto& value = arg.value; - auto out = ctx.out(); - - bool has_style = false; - if (ts.has_emphasis()) { - has_style = true; - auto emphasis = detail::make_emphasis(ts.get_emphasis()); - out = std::copy(emphasis.begin(), emphasis.end(), out); - } - if (ts.has_foreground()) { - has_style = true; - auto foreground = - detail::make_foreground_color(ts.get_foreground()); - out = std::copy(foreground.begin(), foreground.end(), out); - } - if (ts.has_background()) { - has_style = true; - auto background = - detail::make_background_color(ts.get_background()); - out = std::copy(background.begin(), background.end(), out); - } - out = formatter::format(value, ctx); - if (has_style) { - auto reset_color = string_view("\x1b[0m"); - out = std::copy(reset_color.begin(), reset_color.end(), out); - } - return out; - } -}; - -/** - * Returns an argument that will be formatted using ANSI escape sequences, - * to be used in a formatting function. - * - * **Example**: - * - * fmt::print("Elapsed time: {0:.2f} seconds", - * fmt::styled(1.23, fmt::fg(fmt::color::green) | - * fmt::bg(fmt::color::blue))); - */ -template -FMT_CONSTEXPR auto styled(const T& value, text_style ts) - -> detail::styled_arg> { - return detail::styled_arg>{value, ts}; -} - -FMT_END_EXPORT -FMT_END_NAMESPACE - -#endif // FMT_COLOR_H_ diff --git a/tt_metal/third_party/fmt/fmt/compile.h b/tt_metal/third_party/fmt/fmt/compile.h deleted file mode 100644 index b2afc2c309f..00000000000 --- a/tt_metal/third_party/fmt/fmt/compile.h +++ /dev/null @@ -1,529 +0,0 @@ -// Formatting library for C++ - experimental format string compilation -// -// Copyright (c) 2012 - present, Victor Zverovich and fmt contributors -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_COMPILE_H_ -#define FMT_COMPILE_H_ - -#ifndef FMT_MODULE -# include // std::back_inserter -#endif - -#include "format.h" - -FMT_BEGIN_NAMESPACE - -// A compile-time string which is compiled into fast formatting code. -FMT_EXPORT class compiled_string {}; - -namespace detail { - -template -FMT_CONSTEXPR inline auto copy(InputIt begin, InputIt end, counting_iterator it) - -> counting_iterator { - return it + (end - begin); -} - -template -struct is_compiled_string : std::is_base_of {}; - -/** - * Converts a string literal `s` into a format string that will be parsed at - * compile time and converted into efficient formatting code. Requires C++17 - * `constexpr if` compiler support. - * - * **Example**: - * - * // Converts 42 into std::string using the most efficient method and no - * // runtime format string processing. - * std::string s = fmt::format(FMT_COMPILE("{}"), 42); - */ -#if defined(__cpp_if_constexpr) && defined(__cpp_return_type_deduction) -# define FMT_COMPILE(s) FMT_STRING_IMPL(s, fmt::compiled_string, explicit) -#else -# define FMT_COMPILE(s) FMT_STRING(s) -#endif - -#if FMT_USE_NONTYPE_TEMPLATE_ARGS -template Str> -struct udl_compiled_string : compiled_string { - using char_type = Char; - explicit constexpr operator basic_string_view() const { - return {Str.data, N - 1}; - } -}; -#endif - -template -auto first(const T& value, const Tail&...) -> const T& { - return value; -} - -#if defined(__cpp_if_constexpr) && defined(__cpp_return_type_deduction) -template struct type_list {}; - -// Returns a reference to the argument at index N from [first, rest...]. -template -constexpr const auto& get([[maybe_unused]] const T& first, - [[maybe_unused]] const Args&... rest) { - static_assert(N < 1 + sizeof...(Args), "index is out of bounds"); - if constexpr (N == 0) - return first; - else - return detail::get(rest...); -} - -template -constexpr int get_arg_index_by_name(basic_string_view name, - type_list) { - return get_arg_index_by_name(name); -} - -template struct get_type_impl; - -template struct get_type_impl> { - using type = - remove_cvref_t(std::declval()...))>; -}; - -template -using get_type = typename get_type_impl::type; - -template struct is_compiled_format : std::false_type {}; - -template struct text { - basic_string_view data; - using char_type = Char; - - template - constexpr OutputIt format(OutputIt out, const Args&...) const { - return write(out, data); - } -}; - -template -struct is_compiled_format> : std::true_type {}; - -template -constexpr text make_text(basic_string_view s, size_t pos, - size_t size) { - return {{&s[pos], size}}; -} - -template struct code_unit { - Char value; - using char_type = Char; - - template - constexpr OutputIt format(OutputIt out, const Args&...) const { - *out++ = value; - return out; - } -}; - -// This ensures that the argument type is convertible to `const T&`. -template -constexpr const T& get_arg_checked(const Args&... args) { - const auto& arg = detail::get(args...); - if constexpr (detail::is_named_arg>()) { - return arg.value; - } else { - return arg; - } -} - -template -struct is_compiled_format> : std::true_type {}; - -// A replacement field that refers to argument N. -template struct field { - using char_type = Char; - - template - constexpr OutputIt format(OutputIt out, const Args&... args) const { - const T& arg = get_arg_checked(args...); - if constexpr (std::is_convertible>::value) { - auto s = basic_string_view(arg); - return copy(s.begin(), s.end(), out); - } - return write(out, arg); - } -}; - -template -struct is_compiled_format> : std::true_type {}; - -// A replacement field that refers to argument with name. -template struct runtime_named_field { - using char_type = Char; - basic_string_view name; - - template - constexpr static bool try_format_argument( - OutputIt& out, - // [[maybe_unused]] due to unused-but-set-parameter warning in GCC 7,8,9 - [[maybe_unused]] basic_string_view arg_name, const T& arg) { - if constexpr (is_named_arg::type>::value) { - if (arg_name == arg.name) { - out = write(out, arg.value); - return true; - } - } - return false; - } - - template - constexpr OutputIt format(OutputIt out, const Args&... args) const { - bool found = (try_format_argument(out, name, args) || ...); - if (!found) { - FMT_THROW(format_error("argument with specified name is not found")); - } - return out; - } -}; - -template -struct is_compiled_format> : std::true_type {}; - -// A replacement field that refers to argument N and has format specifiers. -template struct spec_field { - using char_type = Char; - formatter fmt; - - template - constexpr FMT_INLINE OutputIt format(OutputIt out, - const Args&... args) const { - const auto& vargs = - fmt::make_format_args>(args...); - basic_format_context ctx(out, vargs); - return fmt.format(get_arg_checked(args...), ctx); - } -}; - -template -struct is_compiled_format> : std::true_type {}; - -template struct concat { - L lhs; - R rhs; - using char_type = typename L::char_type; - - template - constexpr OutputIt format(OutputIt out, const Args&... args) const { - out = lhs.format(out, args...); - return rhs.format(out, args...); - } -}; - -template -struct is_compiled_format> : std::true_type {}; - -template -constexpr concat make_concat(L lhs, R rhs) { - return {lhs, rhs}; -} - -struct unknown_format {}; - -template -constexpr size_t parse_text(basic_string_view str, size_t pos) { - for (size_t size = str.size(); pos != size; ++pos) { - if (str[pos] == '{' || str[pos] == '}') break; - } - return pos; -} - -template -constexpr auto compile_format_string(S fmt); - -template -constexpr auto parse_tail(T head, S fmt) { - if constexpr (POS != basic_string_view(fmt).size()) { - constexpr auto tail = compile_format_string(fmt); - if constexpr (std::is_same, - unknown_format>()) - return tail; - else - return make_concat(head, tail); - } else { - return head; - } -} - -template struct parse_specs_result { - formatter fmt; - size_t end; - int next_arg_id; -}; - -enum { manual_indexing_id = -1 }; - -template -constexpr parse_specs_result parse_specs(basic_string_view str, - size_t pos, int next_arg_id) { - str.remove_prefix(pos); - auto ctx = - compile_parse_context(str, max_value(), nullptr, next_arg_id); - auto f = formatter(); - auto end = f.parse(ctx); - return {f, pos + fmt::detail::to_unsigned(end - str.data()), - next_arg_id == 0 ? manual_indexing_id : ctx.next_arg_id()}; -} - -template struct arg_id_handler { - arg_ref arg_id; - - constexpr int on_auto() { - FMT_ASSERT(false, "handler cannot be used with automatic indexing"); - return 0; - } - constexpr int on_index(int id) { - arg_id = arg_ref(id); - return 0; - } - constexpr int on_name(basic_string_view id) { - arg_id = arg_ref(id); - return 0; - } -}; - -template struct parse_arg_id_result { - arg_ref arg_id; - const Char* arg_id_end; -}; - -template -constexpr auto parse_arg_id(const Char* begin, const Char* end) { - auto handler = arg_id_handler{arg_ref{}}; - auto arg_id_end = parse_arg_id(begin, end, handler); - return parse_arg_id_result{handler.arg_id, arg_id_end}; -} - -template struct field_type { - using type = remove_cvref_t; -}; - -template -struct field_type::value>> { - using type = remove_cvref_t; -}; - -template -constexpr auto parse_replacement_field_then_tail(S fmt) { - using char_type = typename S::char_type; - constexpr auto str = basic_string_view(fmt); - constexpr char_type c = END_POS != str.size() ? str[END_POS] : char_type(); - if constexpr (c == '}') { - return parse_tail( - field::type, ARG_INDEX>(), fmt); - } else if constexpr (c != ':') { - FMT_THROW(format_error("expected ':'")); - } else { - constexpr auto result = parse_specs::type>( - str, END_POS + 1, NEXT_ID == manual_indexing_id ? 0 : NEXT_ID); - if constexpr (result.end >= str.size() || str[result.end] != '}') { - FMT_THROW(format_error("expected '}'")); - return 0; - } else { - return parse_tail( - spec_field::type, ARG_INDEX>{ - result.fmt}, - fmt); - } - } -} - -// Compiles a non-empty format string and returns the compiled representation -// or unknown_format() on unrecognized input. -template -constexpr auto compile_format_string(S fmt) { - using char_type = typename S::char_type; - constexpr auto str = basic_string_view(fmt); - if constexpr (str[POS] == '{') { - if constexpr (POS + 1 == str.size()) - FMT_THROW(format_error("unmatched '{' in format string")); - if constexpr (str[POS + 1] == '{') { - return parse_tail(make_text(str, POS, 1), fmt); - } else if constexpr (str[POS + 1] == '}' || str[POS + 1] == ':') { - static_assert(ID != manual_indexing_id, - "cannot switch from manual to automatic argument indexing"); - constexpr auto next_id = - ID != manual_indexing_id ? ID + 1 : manual_indexing_id; - return parse_replacement_field_then_tail, Args, - POS + 1, ID, next_id>(fmt); - } else { - constexpr auto arg_id_result = - parse_arg_id(str.data() + POS + 1, str.data() + str.size()); - constexpr auto arg_id_end_pos = arg_id_result.arg_id_end - str.data(); - constexpr char_type c = - arg_id_end_pos != str.size() ? str[arg_id_end_pos] : char_type(); - static_assert(c == '}' || c == ':', "missing '}' in format string"); - if constexpr (arg_id_result.arg_id.kind == arg_id_kind::index) { - static_assert( - ID == manual_indexing_id || ID == 0, - "cannot switch from automatic to manual argument indexing"); - constexpr auto arg_index = arg_id_result.arg_id.val.index; - return parse_replacement_field_then_tail, - Args, arg_id_end_pos, - arg_index, manual_indexing_id>( - fmt); - } else if constexpr (arg_id_result.arg_id.kind == arg_id_kind::name) { - constexpr auto arg_index = - get_arg_index_by_name(arg_id_result.arg_id.val.name, Args{}); - if constexpr (arg_index >= 0) { - constexpr auto next_id = - ID != manual_indexing_id ? ID + 1 : manual_indexing_id; - return parse_replacement_field_then_tail< - decltype(get_type::value), Args, arg_id_end_pos, - arg_index, next_id>(fmt); - } else if constexpr (c == '}') { - return parse_tail( - runtime_named_field{arg_id_result.arg_id.val.name}, - fmt); - } else if constexpr (c == ':') { - return unknown_format(); // no type info for specs parsing - } - } - } - } else if constexpr (str[POS] == '}') { - if constexpr (POS + 1 == str.size()) - FMT_THROW(format_error("unmatched '}' in format string")); - return parse_tail(make_text(str, POS, 1), fmt); - } else { - constexpr auto end = parse_text(str, POS + 1); - if constexpr (end - POS > 1) { - return parse_tail(make_text(str, POS, end - POS), fmt); - } else { - return parse_tail(code_unit{str[POS]}, fmt); - } - } -} - -template ::value)> -constexpr auto compile(S fmt) { - constexpr auto str = basic_string_view(fmt); - if constexpr (str.size() == 0) { - return detail::make_text(str, 0, 0); - } else { - constexpr auto result = - detail::compile_format_string, 0, 0>(fmt); - return result; - } -} -#endif // defined(__cpp_if_constexpr) && defined(__cpp_return_type_deduction) -} // namespace detail - -FMT_BEGIN_EXPORT - -#if defined(__cpp_if_constexpr) && defined(__cpp_return_type_deduction) - -template ::value)> -FMT_INLINE std::basic_string format(const CompiledFormat& cf, - const Args&... args) { - auto s = std::basic_string(); - cf.format(std::back_inserter(s), args...); - return s; -} - -template ::value)> -constexpr FMT_INLINE OutputIt format_to(OutputIt out, const CompiledFormat& cf, - const Args&... args) { - return cf.format(out, args...); -} - -template ::value)> -FMT_INLINE std::basic_string format(const S&, - Args&&... args) { - if constexpr (std::is_same::value) { - constexpr auto str = basic_string_view(S()); - if constexpr (str.size() == 2 && str[0] == '{' && str[1] == '}') { - const auto& first = detail::first(args...); - if constexpr (detail::is_named_arg< - remove_cvref_t>::value) { - return fmt::to_string(first.value); - } else { - return fmt::to_string(first); - } - } - } - constexpr auto compiled = detail::compile(S()); - if constexpr (std::is_same, - detail::unknown_format>()) { - return fmt::format( - static_cast>(S()), - std::forward(args)...); - } else { - return fmt::format(compiled, std::forward(args)...); - } -} - -template ::value)> -FMT_CONSTEXPR OutputIt format_to(OutputIt out, const S&, Args&&... args) { - constexpr auto compiled = detail::compile(S()); - if constexpr (std::is_same, - detail::unknown_format>()) { - return fmt::format_to( - out, static_cast>(S()), - std::forward(args)...); - } else { - return fmt::format_to(out, compiled, std::forward(args)...); - } -} -#endif - -template ::value)> -auto format_to_n(OutputIt out, size_t n, const S& fmt, Args&&... args) - -> format_to_n_result { - using traits = detail::fixed_buffer_traits; - auto buf = detail::iterator_buffer(out, n); - fmt::format_to(std::back_inserter(buf), fmt, std::forward(args)...); - return {buf.out(), buf.count()}; -} - -template ::value)> -FMT_CONSTEXPR20 auto formatted_size(const S& fmt, const Args&... args) - -> size_t { - return fmt::format_to(detail::counting_iterator(), fmt, args...).count(); -} - -template ::value)> -void print(std::FILE* f, const S& fmt, const Args&... args) { - memory_buffer buffer; - fmt::format_to(std::back_inserter(buffer), fmt, args...); - detail::print(f, {buffer.data(), buffer.size()}); -} - -template ::value)> -void print(const S& fmt, const Args&... args) { - print(stdout, fmt, args...); -} - -#if FMT_USE_NONTYPE_TEMPLATE_ARGS -inline namespace literals { -template constexpr auto operator""_cf() { - using char_t = remove_cvref_t; - return detail::udl_compiled_string(); -} -} // namespace literals -#endif - -FMT_END_EXPORT -FMT_END_NAMESPACE - -#endif // FMT_COMPILE_H_ diff --git a/tt_metal/third_party/fmt/fmt/core.h b/tt_metal/third_party/fmt/fmt/core.h deleted file mode 100644 index 8ca735f0c00..00000000000 --- a/tt_metal/third_party/fmt/fmt/core.h +++ /dev/null @@ -1,5 +0,0 @@ -// This file is only provided for compatibility and may be removed in future -// versions. Use fmt/base.h if you don't need fmt::format and fmt/format.h -// otherwise. - -#include "format.h" diff --git a/tt_metal/third_party/fmt/fmt/format-inl.h b/tt_metal/third_party/fmt/fmt/format-inl.h deleted file mode 100644 index 8d07cc67233..00000000000 --- a/tt_metal/third_party/fmt/fmt/format-inl.h +++ /dev/null @@ -1,1904 +0,0 @@ -// Formatting library for C++ - implementation -// -// Copyright (c) 2012 - 2016, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_FORMAT_INL_H_ -#define FMT_FORMAT_INL_H_ - -#ifndef FMT_MODULE -# include -# include // errno -# include -# include -# include - -# if !defined(FMT_STATIC_THOUSANDS_SEPARATOR) -# include -# endif -#endif - -#if defined(_WIN32) && !defined(FMT_USE_WRITE_CONSOLE) -# include // _isatty -#endif - -#include "format.h" - -FMT_BEGIN_NAMESPACE -namespace detail { - -FMT_FUNC void assert_fail(const char* file, int line, const char* message) { - // Use unchecked std::fprintf to avoid triggering another assertion when - // writing to stderr fails - std::fprintf(stderr, "%s:%d: assertion failed: %s", file, line, message); - // Chosen instead of std::abort to satisfy Clang in CUDA mode during device - // code pass. - std::terminate(); -} - -FMT_FUNC void format_error_code(detail::buffer& out, int error_code, - string_view message) noexcept { - // Report error code making sure that the output fits into - // inline_buffer_size to avoid dynamic memory allocation and potential - // bad_alloc. - out.try_resize(0); - static const char SEP[] = ": "; - static const char ERROR_STR[] = "error "; - // Subtract 2 to account for terminating null characters in SEP and ERROR_STR. - size_t error_code_size = sizeof(SEP) + sizeof(ERROR_STR) - 2; - auto abs_value = static_cast>(error_code); - if (detail::is_negative(error_code)) { - abs_value = 0 - abs_value; - ++error_code_size; - } - error_code_size += detail::to_unsigned(detail::count_digits(abs_value)); - auto it = appender(out); - if (message.size() <= inline_buffer_size - error_code_size) - fmt::format_to(it, FMT_STRING("{}{}"), message, SEP); - fmt::format_to(it, FMT_STRING("{}{}"), ERROR_STR, error_code); - FMT_ASSERT(out.size() <= inline_buffer_size, ""); -} - -FMT_FUNC void report_error(format_func func, int error_code, - const char* message) noexcept { - memory_buffer full_message; - func(full_message, error_code, message); - // Don't use fwrite_fully because the latter may throw. - if (std::fwrite(full_message.data(), full_message.size(), 1, stderr) > 0) - std::fputc('\n', stderr); -} - -// A wrapper around fwrite that throws on error. -inline void fwrite_fully(const void* ptr, size_t count, FILE* stream) { - size_t written = std::fwrite(ptr, 1, count, stream); - if (written < count) - FMT_THROW(system_error(errno, FMT_STRING("cannot write to file"))); -} - -#ifndef FMT_STATIC_THOUSANDS_SEPARATOR -template -locale_ref::locale_ref(const Locale& loc) : locale_(&loc) { - static_assert(std::is_same::value, ""); -} - -template auto locale_ref::get() const -> Locale { - static_assert(std::is_same::value, ""); - return locale_ ? *static_cast(locale_) : std::locale(); -} - -template -FMT_FUNC auto thousands_sep_impl(locale_ref loc) -> thousands_sep_result { - auto& facet = std::use_facet>(loc.get()); - auto grouping = facet.grouping(); - auto thousands_sep = grouping.empty() ? Char() : facet.thousands_sep(); - return {std::move(grouping), thousands_sep}; -} -template -FMT_FUNC auto decimal_point_impl(locale_ref loc) -> Char { - return std::use_facet>(loc.get()) - .decimal_point(); -} -#else -template -FMT_FUNC auto thousands_sep_impl(locale_ref) -> thousands_sep_result { - return {"\03", FMT_STATIC_THOUSANDS_SEPARATOR}; -} -template FMT_FUNC Char decimal_point_impl(locale_ref) { - return '.'; -} -#endif - -FMT_FUNC auto write_loc(appender out, loc_value value, - const format_specs& specs, locale_ref loc) -> bool { -#ifdef FMT_STATIC_THOUSANDS_SEPARATOR - value.visit(loc_writer<>{ - out, specs, std::string(1, FMT_STATIC_THOUSANDS_SEPARATOR), "\3", "."}); - return true; -#else - auto locale = loc.get(); - // We cannot use the num_put facet because it may produce output in - // a wrong encoding. - using facet = format_facet; - if (std::has_facet(locale)) - return std::use_facet(locale).put(out, value, specs); - return facet(locale).put(out, value, specs); -#endif -} -} // namespace detail - -FMT_FUNC void report_error(const char* message) { - FMT_THROW(format_error(message)); -} - -template typename Locale::id format_facet::id; - -#ifndef FMT_STATIC_THOUSANDS_SEPARATOR -template format_facet::format_facet(Locale& loc) { - auto& numpunct = std::use_facet>(loc); - grouping_ = numpunct.grouping(); - if (!grouping_.empty()) separator_ = std::string(1, numpunct.thousands_sep()); -} - -template <> -FMT_API FMT_FUNC auto format_facet::do_put( - appender out, loc_value val, const format_specs& specs) const -> bool { - return val.visit( - detail::loc_writer<>{out, specs, separator_, grouping_, decimal_point_}); -} -#endif - -FMT_FUNC auto vsystem_error(int error_code, string_view fmt, format_args args) - -> std::system_error { - auto ec = std::error_code(error_code, std::generic_category()); - return std::system_error(ec, vformat(fmt, args)); -} - -namespace detail { - -template -inline auto operator==(basic_fp x, basic_fp y) -> bool { - return x.f == y.f && x.e == y.e; -} - -// Compilers should be able to optimize this into the ror instruction. -FMT_CONSTEXPR inline auto rotr(uint32_t n, uint32_t r) noexcept -> uint32_t { - r &= 31; - return (n >> r) | (n << (32 - r)); -} -FMT_CONSTEXPR inline auto rotr(uint64_t n, uint32_t r) noexcept -> uint64_t { - r &= 63; - return (n >> r) | (n << (64 - r)); -} - -// Implementation of Dragonbox algorithm: https://github.com/jk-jeon/dragonbox. -namespace dragonbox { -// Computes upper 64 bits of multiplication of a 32-bit unsigned integer and a -// 64-bit unsigned integer. -inline auto umul96_upper64(uint32_t x, uint64_t y) noexcept -> uint64_t { - return umul128_upper64(static_cast(x) << 32, y); -} - -// Computes lower 128 bits of multiplication of a 64-bit unsigned integer and a -// 128-bit unsigned integer. -inline auto umul192_lower128(uint64_t x, uint128_fallback y) noexcept - -> uint128_fallback { - uint64_t high = x * y.high(); - uint128_fallback high_low = umul128(x, y.low()); - return {high + high_low.high(), high_low.low()}; -} - -// Computes lower 64 bits of multiplication of a 32-bit unsigned integer and a -// 64-bit unsigned integer. -inline auto umul96_lower64(uint32_t x, uint64_t y) noexcept -> uint64_t { - return x * y; -} - -// Various fast log computations. -inline auto floor_log10_pow2_minus_log10_4_over_3(int e) noexcept -> int { - FMT_ASSERT(e <= 2936 && e >= -2985, "too large exponent"); - return (e * 631305 - 261663) >> 21; -} - -FMT_INLINE_VARIABLE constexpr struct { - uint32_t divisor; - int shift_amount; -} div_small_pow10_infos[] = {{10, 16}, {100, 16}}; - -// Replaces n by floor(n / pow(10, N)) returning true if and only if n is -// divisible by pow(10, N). -// Precondition: n <= pow(10, N + 1). -template -auto check_divisibility_and_divide_by_pow10(uint32_t& n) noexcept -> bool { - // The numbers below are chosen such that: - // 1. floor(n/d) = floor(nm / 2^k) where d=10 or d=100, - // 2. nm mod 2^k < m if and only if n is divisible by d, - // where m is magic_number, k is shift_amount - // and d is divisor. - // - // Item 1 is a common technique of replacing division by a constant with - // multiplication, see e.g. "Division by Invariant Integers Using - // Multiplication" by Granlund and Montgomery (1994). magic_number (m) is set - // to ceil(2^k/d) for large enough k. - // The idea for item 2 originates from Schubfach. - constexpr auto info = div_small_pow10_infos[N - 1]; - FMT_ASSERT(n <= info.divisor * 10, "n is too large"); - constexpr uint32_t magic_number = - (1u << info.shift_amount) / info.divisor + 1; - n *= magic_number; - const uint32_t comparison_mask = (1u << info.shift_amount) - 1; - bool result = (n & comparison_mask) < magic_number; - n >>= info.shift_amount; - return result; -} - -// Computes floor(n / pow(10, N)) for small n and N. -// Precondition: n <= pow(10, N + 1). -template auto small_division_by_pow10(uint32_t n) noexcept -> uint32_t { - constexpr auto info = div_small_pow10_infos[N - 1]; - FMT_ASSERT(n <= info.divisor * 10, "n is too large"); - constexpr uint32_t magic_number = - (1u << info.shift_amount) / info.divisor + 1; - return (n * magic_number) >> info.shift_amount; -} - -// Computes floor(n / 10^(kappa + 1)) (float) -inline auto divide_by_10_to_kappa_plus_1(uint32_t n) noexcept -> uint32_t { - // 1374389535 = ceil(2^37/100) - return static_cast((static_cast(n) * 1374389535) >> 37); -} -// Computes floor(n / 10^(kappa + 1)) (double) -inline auto divide_by_10_to_kappa_plus_1(uint64_t n) noexcept -> uint64_t { - // 2361183241434822607 = ceil(2^(64+7)/1000) - return umul128_upper64(n, 2361183241434822607ull) >> 7; -} - -// Various subroutines using pow10 cache -template struct cache_accessor; - -template <> struct cache_accessor { - using carrier_uint = float_info::carrier_uint; - using cache_entry_type = uint64_t; - - static auto get_cached_power(int k) noexcept -> uint64_t { - FMT_ASSERT(k >= float_info::min_k && k <= float_info::max_k, - "k is out of range"); - static constexpr const uint64_t pow10_significands[] = { - 0x81ceb32c4b43fcf5, 0xa2425ff75e14fc32, 0xcad2f7f5359a3b3f, - 0xfd87b5f28300ca0e, 0x9e74d1b791e07e49, 0xc612062576589ddb, - 0xf79687aed3eec552, 0x9abe14cd44753b53, 0xc16d9a0095928a28, - 0xf1c90080baf72cb2, 0x971da05074da7bef, 0xbce5086492111aeb, - 0xec1e4a7db69561a6, 0x9392ee8e921d5d08, 0xb877aa3236a4b44a, - 0xe69594bec44de15c, 0x901d7cf73ab0acda, 0xb424dc35095cd810, - 0xe12e13424bb40e14, 0x8cbccc096f5088cc, 0xafebff0bcb24aaff, - 0xdbe6fecebdedd5bf, 0x89705f4136b4a598, 0xabcc77118461cefd, - 0xd6bf94d5e57a42bd, 0x8637bd05af6c69b6, 0xa7c5ac471b478424, - 0xd1b71758e219652c, 0x83126e978d4fdf3c, 0xa3d70a3d70a3d70b, - 0xcccccccccccccccd, 0x8000000000000000, 0xa000000000000000, - 0xc800000000000000, 0xfa00000000000000, 0x9c40000000000000, - 0xc350000000000000, 0xf424000000000000, 0x9896800000000000, - 0xbebc200000000000, 0xee6b280000000000, 0x9502f90000000000, - 0xba43b74000000000, 0xe8d4a51000000000, 0x9184e72a00000000, - 0xb5e620f480000000, 0xe35fa931a0000000, 0x8e1bc9bf04000000, - 0xb1a2bc2ec5000000, 0xde0b6b3a76400000, 0x8ac7230489e80000, - 0xad78ebc5ac620000, 0xd8d726b7177a8000, 0x878678326eac9000, - 0xa968163f0a57b400, 0xd3c21bcecceda100, 0x84595161401484a0, - 0xa56fa5b99019a5c8, 0xcecb8f27f4200f3a, 0x813f3978f8940985, - 0xa18f07d736b90be6, 0xc9f2c9cd04674edf, 0xfc6f7c4045812297, - 0x9dc5ada82b70b59e, 0xc5371912364ce306, 0xf684df56c3e01bc7, - 0x9a130b963a6c115d, 0xc097ce7bc90715b4, 0xf0bdc21abb48db21, - 0x96769950b50d88f5, 0xbc143fa4e250eb32, 0xeb194f8e1ae525fe, - 0x92efd1b8d0cf37bf, 0xb7abc627050305ae, 0xe596b7b0c643c71a, - 0x8f7e32ce7bea5c70, 0xb35dbf821ae4f38c, 0xe0352f62a19e306f}; - return pow10_significands[k - float_info::min_k]; - } - - struct compute_mul_result { - carrier_uint result; - bool is_integer; - }; - struct compute_mul_parity_result { - bool parity; - bool is_integer; - }; - - static auto compute_mul(carrier_uint u, - const cache_entry_type& cache) noexcept - -> compute_mul_result { - auto r = umul96_upper64(u, cache); - return {static_cast(r >> 32), - static_cast(r) == 0}; - } - - static auto compute_delta(const cache_entry_type& cache, int beta) noexcept - -> uint32_t { - return static_cast(cache >> (64 - 1 - beta)); - } - - static auto compute_mul_parity(carrier_uint two_f, - const cache_entry_type& cache, - int beta) noexcept - -> compute_mul_parity_result { - FMT_ASSERT(beta >= 1, ""); - FMT_ASSERT(beta < 64, ""); - - auto r = umul96_lower64(two_f, cache); - return {((r >> (64 - beta)) & 1) != 0, - static_cast(r >> (32 - beta)) == 0}; - } - - static auto compute_left_endpoint_for_shorter_interval_case( - const cache_entry_type& cache, int beta) noexcept -> carrier_uint { - return static_cast( - (cache - (cache >> (num_significand_bits() + 2))) >> - (64 - num_significand_bits() - 1 - beta)); - } - - static auto compute_right_endpoint_for_shorter_interval_case( - const cache_entry_type& cache, int beta) noexcept -> carrier_uint { - return static_cast( - (cache + (cache >> (num_significand_bits() + 1))) >> - (64 - num_significand_bits() - 1 - beta)); - } - - static auto compute_round_up_for_shorter_interval_case( - const cache_entry_type& cache, int beta) noexcept -> carrier_uint { - return (static_cast( - cache >> (64 - num_significand_bits() - 2 - beta)) + - 1) / - 2; - } -}; - -template <> struct cache_accessor { - using carrier_uint = float_info::carrier_uint; - using cache_entry_type = uint128_fallback; - - static auto get_cached_power(int k) noexcept -> uint128_fallback { - FMT_ASSERT(k >= float_info::min_k && k <= float_info::max_k, - "k is out of range"); - - static constexpr const uint128_fallback pow10_significands[] = { -#if FMT_USE_FULL_CACHE_DRAGONBOX - {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b}, - {0x9faacf3df73609b1, 0x77b191618c54e9ad}, - {0xc795830d75038c1d, 0xd59df5b9ef6a2418}, - {0xf97ae3d0d2446f25, 0x4b0573286b44ad1e}, - {0x9becce62836ac577, 0x4ee367f9430aec33}, - {0xc2e801fb244576d5, 0x229c41f793cda740}, - {0xf3a20279ed56d48a, 0x6b43527578c11110}, - {0x9845418c345644d6, 0x830a13896b78aaaa}, - {0xbe5691ef416bd60c, 0x23cc986bc656d554}, - {0xedec366b11c6cb8f, 0x2cbfbe86b7ec8aa9}, - {0x94b3a202eb1c3f39, 0x7bf7d71432f3d6aa}, - {0xb9e08a83a5e34f07, 0xdaf5ccd93fb0cc54}, - {0xe858ad248f5c22c9, 0xd1b3400f8f9cff69}, - {0x91376c36d99995be, 0x23100809b9c21fa2}, - {0xb58547448ffffb2d, 0xabd40a0c2832a78b}, - {0xe2e69915b3fff9f9, 0x16c90c8f323f516d}, - {0x8dd01fad907ffc3b, 0xae3da7d97f6792e4}, - {0xb1442798f49ffb4a, 0x99cd11cfdf41779d}, - {0xdd95317f31c7fa1d, 0x40405643d711d584}, - {0x8a7d3eef7f1cfc52, 0x482835ea666b2573}, - {0xad1c8eab5ee43b66, 0xda3243650005eed0}, - {0xd863b256369d4a40, 0x90bed43e40076a83}, - {0x873e4f75e2224e68, 0x5a7744a6e804a292}, - {0xa90de3535aaae202, 0x711515d0a205cb37}, - {0xd3515c2831559a83, 0x0d5a5b44ca873e04}, - {0x8412d9991ed58091, 0xe858790afe9486c3}, - {0xa5178fff668ae0b6, 0x626e974dbe39a873}, - {0xce5d73ff402d98e3, 0xfb0a3d212dc81290}, - {0x80fa687f881c7f8e, 0x7ce66634bc9d0b9a}, - {0xa139029f6a239f72, 0x1c1fffc1ebc44e81}, - {0xc987434744ac874e, 0xa327ffb266b56221}, - {0xfbe9141915d7a922, 0x4bf1ff9f0062baa9}, - {0x9d71ac8fada6c9b5, 0x6f773fc3603db4aa}, - {0xc4ce17b399107c22, 0xcb550fb4384d21d4}, - {0xf6019da07f549b2b, 0x7e2a53a146606a49}, - {0x99c102844f94e0fb, 0x2eda7444cbfc426e}, - {0xc0314325637a1939, 0xfa911155fefb5309}, - {0xf03d93eebc589f88, 0x793555ab7eba27cb}, - {0x96267c7535b763b5, 0x4bc1558b2f3458df}, - {0xbbb01b9283253ca2, 0x9eb1aaedfb016f17}, - {0xea9c227723ee8bcb, 0x465e15a979c1cadd}, - {0x92a1958a7675175f, 0x0bfacd89ec191eca}, - {0xb749faed14125d36, 0xcef980ec671f667c}, - {0xe51c79a85916f484, 0x82b7e12780e7401b}, - {0x8f31cc0937ae58d2, 0xd1b2ecb8b0908811}, - {0xb2fe3f0b8599ef07, 0x861fa7e6dcb4aa16}, - {0xdfbdcece67006ac9, 0x67a791e093e1d49b}, - {0x8bd6a141006042bd, 0xe0c8bb2c5c6d24e1}, - {0xaecc49914078536d, 0x58fae9f773886e19}, - {0xda7f5bf590966848, 0xaf39a475506a899f}, - {0x888f99797a5e012d, 0x6d8406c952429604}, - {0xaab37fd7d8f58178, 0xc8e5087ba6d33b84}, - {0xd5605fcdcf32e1d6, 0xfb1e4a9a90880a65}, - {0x855c3be0a17fcd26, 0x5cf2eea09a550680}, - {0xa6b34ad8c9dfc06f, 0xf42faa48c0ea481f}, - {0xd0601d8efc57b08b, 0xf13b94daf124da27}, - {0x823c12795db6ce57, 0x76c53d08d6b70859}, - {0xa2cb1717b52481ed, 0x54768c4b0c64ca6f}, - {0xcb7ddcdda26da268, 0xa9942f5dcf7dfd0a}, - {0xfe5d54150b090b02, 0xd3f93b35435d7c4d}, - {0x9efa548d26e5a6e1, 0xc47bc5014a1a6db0}, - {0xc6b8e9b0709f109a, 0x359ab6419ca1091c}, - {0xf867241c8cc6d4c0, 0xc30163d203c94b63}, - {0x9b407691d7fc44f8, 0x79e0de63425dcf1e}, - {0xc21094364dfb5636, 0x985915fc12f542e5}, - {0xf294b943e17a2bc4, 0x3e6f5b7b17b2939e}, - {0x979cf3ca6cec5b5a, 0xa705992ceecf9c43}, - {0xbd8430bd08277231, 0x50c6ff782a838354}, - {0xece53cec4a314ebd, 0xa4f8bf5635246429}, - {0x940f4613ae5ed136, 0x871b7795e136be9a}, - {0xb913179899f68584, 0x28e2557b59846e40}, - {0xe757dd7ec07426e5, 0x331aeada2fe589d0}, - {0x9096ea6f3848984f, 0x3ff0d2c85def7622}, - {0xb4bca50b065abe63, 0x0fed077a756b53aa}, - {0xe1ebce4dc7f16dfb, 0xd3e8495912c62895}, - {0x8d3360f09cf6e4bd, 0x64712dd7abbbd95d}, - {0xb080392cc4349dec, 0xbd8d794d96aacfb4}, - {0xdca04777f541c567, 0xecf0d7a0fc5583a1}, - {0x89e42caaf9491b60, 0xf41686c49db57245}, - {0xac5d37d5b79b6239, 0x311c2875c522ced6}, - {0xd77485cb25823ac7, 0x7d633293366b828c}, - {0x86a8d39ef77164bc, 0xae5dff9c02033198}, - {0xa8530886b54dbdeb, 0xd9f57f830283fdfd}, - {0xd267caa862a12d66, 0xd072df63c324fd7c}, - {0x8380dea93da4bc60, 0x4247cb9e59f71e6e}, - {0xa46116538d0deb78, 0x52d9be85f074e609}, - {0xcd795be870516656, 0x67902e276c921f8c}, - {0x806bd9714632dff6, 0x00ba1cd8a3db53b7}, - {0xa086cfcd97bf97f3, 0x80e8a40eccd228a5}, - {0xc8a883c0fdaf7df0, 0x6122cd128006b2ce}, - {0xfad2a4b13d1b5d6c, 0x796b805720085f82}, - {0x9cc3a6eec6311a63, 0xcbe3303674053bb1}, - {0xc3f490aa77bd60fc, 0xbedbfc4411068a9d}, - {0xf4f1b4d515acb93b, 0xee92fb5515482d45}, - {0x991711052d8bf3c5, 0x751bdd152d4d1c4b}, - {0xbf5cd54678eef0b6, 0xd262d45a78a0635e}, - {0xef340a98172aace4, 0x86fb897116c87c35}, - {0x9580869f0e7aac0e, 0xd45d35e6ae3d4da1}, - {0xbae0a846d2195712, 0x8974836059cca10a}, - {0xe998d258869facd7, 0x2bd1a438703fc94c}, - {0x91ff83775423cc06, 0x7b6306a34627ddd0}, - {0xb67f6455292cbf08, 0x1a3bc84c17b1d543}, - {0xe41f3d6a7377eeca, 0x20caba5f1d9e4a94}, - {0x8e938662882af53e, 0x547eb47b7282ee9d}, - {0xb23867fb2a35b28d, 0xe99e619a4f23aa44}, - {0xdec681f9f4c31f31, 0x6405fa00e2ec94d5}, - {0x8b3c113c38f9f37e, 0xde83bc408dd3dd05}, - {0xae0b158b4738705e, 0x9624ab50b148d446}, - {0xd98ddaee19068c76, 0x3badd624dd9b0958}, - {0x87f8a8d4cfa417c9, 0xe54ca5d70a80e5d7}, - {0xa9f6d30a038d1dbc, 0x5e9fcf4ccd211f4d}, - {0xd47487cc8470652b, 0x7647c32000696720}, - {0x84c8d4dfd2c63f3b, 0x29ecd9f40041e074}, - {0xa5fb0a17c777cf09, 0xf468107100525891}, - {0xcf79cc9db955c2cc, 0x7182148d4066eeb5}, - {0x81ac1fe293d599bf, 0xc6f14cd848405531}, - {0xa21727db38cb002f, 0xb8ada00e5a506a7d}, - {0xca9cf1d206fdc03b, 0xa6d90811f0e4851d}, - {0xfd442e4688bd304a, 0x908f4a166d1da664}, - {0x9e4a9cec15763e2e, 0x9a598e4e043287ff}, - {0xc5dd44271ad3cdba, 0x40eff1e1853f29fe}, - {0xf7549530e188c128, 0xd12bee59e68ef47d}, - {0x9a94dd3e8cf578b9, 0x82bb74f8301958cf}, - {0xc13a148e3032d6e7, 0xe36a52363c1faf02}, - {0xf18899b1bc3f8ca1, 0xdc44e6c3cb279ac2}, - {0x96f5600f15a7b7e5, 0x29ab103a5ef8c0ba}, - {0xbcb2b812db11a5de, 0x7415d448f6b6f0e8}, - {0xebdf661791d60f56, 0x111b495b3464ad22}, - {0x936b9fcebb25c995, 0xcab10dd900beec35}, - {0xb84687c269ef3bfb, 0x3d5d514f40eea743}, - {0xe65829b3046b0afa, 0x0cb4a5a3112a5113}, - {0x8ff71a0fe2c2e6dc, 0x47f0e785eaba72ac}, - {0xb3f4e093db73a093, 0x59ed216765690f57}, - {0xe0f218b8d25088b8, 0x306869c13ec3532d}, - {0x8c974f7383725573, 0x1e414218c73a13fc}, - {0xafbd2350644eeacf, 0xe5d1929ef90898fb}, - {0xdbac6c247d62a583, 0xdf45f746b74abf3a}, - {0x894bc396ce5da772, 0x6b8bba8c328eb784}, - {0xab9eb47c81f5114f, 0x066ea92f3f326565}, - {0xd686619ba27255a2, 0xc80a537b0efefebe}, - {0x8613fd0145877585, 0xbd06742ce95f5f37}, - {0xa798fc4196e952e7, 0x2c48113823b73705}, - {0xd17f3b51fca3a7a0, 0xf75a15862ca504c6}, - {0x82ef85133de648c4, 0x9a984d73dbe722fc}, - {0xa3ab66580d5fdaf5, 0xc13e60d0d2e0ebbb}, - {0xcc963fee10b7d1b3, 0x318df905079926a9}, - {0xffbbcfe994e5c61f, 0xfdf17746497f7053}, - {0x9fd561f1fd0f9bd3, 0xfeb6ea8bedefa634}, - {0xc7caba6e7c5382c8, 0xfe64a52ee96b8fc1}, - {0xf9bd690a1b68637b, 0x3dfdce7aa3c673b1}, - {0x9c1661a651213e2d, 0x06bea10ca65c084f}, - {0xc31bfa0fe5698db8, 0x486e494fcff30a63}, - {0xf3e2f893dec3f126, 0x5a89dba3c3efccfb}, - {0x986ddb5c6b3a76b7, 0xf89629465a75e01d}, - {0xbe89523386091465, 0xf6bbb397f1135824}, - {0xee2ba6c0678b597f, 0x746aa07ded582e2d}, - {0x94db483840b717ef, 0xa8c2a44eb4571cdd}, - {0xba121a4650e4ddeb, 0x92f34d62616ce414}, - {0xe896a0d7e51e1566, 0x77b020baf9c81d18}, - {0x915e2486ef32cd60, 0x0ace1474dc1d122f}, - {0xb5b5ada8aaff80b8, 0x0d819992132456bb}, - {0xe3231912d5bf60e6, 0x10e1fff697ed6c6a}, - {0x8df5efabc5979c8f, 0xca8d3ffa1ef463c2}, - {0xb1736b96b6fd83b3, 0xbd308ff8a6b17cb3}, - {0xddd0467c64bce4a0, 0xac7cb3f6d05ddbdf}, - {0x8aa22c0dbef60ee4, 0x6bcdf07a423aa96c}, - {0xad4ab7112eb3929d, 0x86c16c98d2c953c7}, - {0xd89d64d57a607744, 0xe871c7bf077ba8b8}, - {0x87625f056c7c4a8b, 0x11471cd764ad4973}, - {0xa93af6c6c79b5d2d, 0xd598e40d3dd89bd0}, - {0xd389b47879823479, 0x4aff1d108d4ec2c4}, - {0x843610cb4bf160cb, 0xcedf722a585139bb}, - {0xa54394fe1eedb8fe, 0xc2974eb4ee658829}, - {0xce947a3da6a9273e, 0x733d226229feea33}, - {0x811ccc668829b887, 0x0806357d5a3f5260}, - {0xa163ff802a3426a8, 0xca07c2dcb0cf26f8}, - {0xc9bcff6034c13052, 0xfc89b393dd02f0b6}, - {0xfc2c3f3841f17c67, 0xbbac2078d443ace3}, - {0x9d9ba7832936edc0, 0xd54b944b84aa4c0e}, - {0xc5029163f384a931, 0x0a9e795e65d4df12}, - {0xf64335bcf065d37d, 0x4d4617b5ff4a16d6}, - {0x99ea0196163fa42e, 0x504bced1bf8e4e46}, - {0xc06481fb9bcf8d39, 0xe45ec2862f71e1d7}, - {0xf07da27a82c37088, 0x5d767327bb4e5a4d}, - {0x964e858c91ba2655, 0x3a6a07f8d510f870}, - {0xbbe226efb628afea, 0x890489f70a55368c}, - {0xeadab0aba3b2dbe5, 0x2b45ac74ccea842f}, - {0x92c8ae6b464fc96f, 0x3b0b8bc90012929e}, - {0xb77ada0617e3bbcb, 0x09ce6ebb40173745}, - {0xe55990879ddcaabd, 0xcc420a6a101d0516}, - {0x8f57fa54c2a9eab6, 0x9fa946824a12232e}, - {0xb32df8e9f3546564, 0x47939822dc96abfa}, - {0xdff9772470297ebd, 0x59787e2b93bc56f8}, - {0x8bfbea76c619ef36, 0x57eb4edb3c55b65b}, - {0xaefae51477a06b03, 0xede622920b6b23f2}, - {0xdab99e59958885c4, 0xe95fab368e45ecee}, - {0x88b402f7fd75539b, 0x11dbcb0218ebb415}, - {0xaae103b5fcd2a881, 0xd652bdc29f26a11a}, - {0xd59944a37c0752a2, 0x4be76d3346f04960}, - {0x857fcae62d8493a5, 0x6f70a4400c562ddc}, - {0xa6dfbd9fb8e5b88e, 0xcb4ccd500f6bb953}, - {0xd097ad07a71f26b2, 0x7e2000a41346a7a8}, - {0x825ecc24c873782f, 0x8ed400668c0c28c9}, - {0xa2f67f2dfa90563b, 0x728900802f0f32fb}, - {0xcbb41ef979346bca, 0x4f2b40a03ad2ffba}, - {0xfea126b7d78186bc, 0xe2f610c84987bfa9}, - {0x9f24b832e6b0f436, 0x0dd9ca7d2df4d7ca}, - {0xc6ede63fa05d3143, 0x91503d1c79720dbc}, - {0xf8a95fcf88747d94, 0x75a44c6397ce912b}, - {0x9b69dbe1b548ce7c, 0xc986afbe3ee11abb}, - {0xc24452da229b021b, 0xfbe85badce996169}, - {0xf2d56790ab41c2a2, 0xfae27299423fb9c4}, - {0x97c560ba6b0919a5, 0xdccd879fc967d41b}, - {0xbdb6b8e905cb600f, 0x5400e987bbc1c921}, - {0xed246723473e3813, 0x290123e9aab23b69}, - {0x9436c0760c86e30b, 0xf9a0b6720aaf6522}, - {0xb94470938fa89bce, 0xf808e40e8d5b3e6a}, - {0xe7958cb87392c2c2, 0xb60b1d1230b20e05}, - {0x90bd77f3483bb9b9, 0xb1c6f22b5e6f48c3}, - {0xb4ecd5f01a4aa828, 0x1e38aeb6360b1af4}, - {0xe2280b6c20dd5232, 0x25c6da63c38de1b1}, - {0x8d590723948a535f, 0x579c487e5a38ad0f}, - {0xb0af48ec79ace837, 0x2d835a9df0c6d852}, - {0xdcdb1b2798182244, 0xf8e431456cf88e66}, - {0x8a08f0f8bf0f156b, 0x1b8e9ecb641b5900}, - {0xac8b2d36eed2dac5, 0xe272467e3d222f40}, - {0xd7adf884aa879177, 0x5b0ed81dcc6abb10}, - {0x86ccbb52ea94baea, 0x98e947129fc2b4ea}, - {0xa87fea27a539e9a5, 0x3f2398d747b36225}, - {0xd29fe4b18e88640e, 0x8eec7f0d19a03aae}, - {0x83a3eeeef9153e89, 0x1953cf68300424ad}, - {0xa48ceaaab75a8e2b, 0x5fa8c3423c052dd8}, - {0xcdb02555653131b6, 0x3792f412cb06794e}, - {0x808e17555f3ebf11, 0xe2bbd88bbee40bd1}, - {0xa0b19d2ab70e6ed6, 0x5b6aceaeae9d0ec5}, - {0xc8de047564d20a8b, 0xf245825a5a445276}, - {0xfb158592be068d2e, 0xeed6e2f0f0d56713}, - {0x9ced737bb6c4183d, 0x55464dd69685606c}, - {0xc428d05aa4751e4c, 0xaa97e14c3c26b887}, - {0xf53304714d9265df, 0xd53dd99f4b3066a9}, - {0x993fe2c6d07b7fab, 0xe546a8038efe402a}, - {0xbf8fdb78849a5f96, 0xde98520472bdd034}, - {0xef73d256a5c0f77c, 0x963e66858f6d4441}, - {0x95a8637627989aad, 0xdde7001379a44aa9}, - {0xbb127c53b17ec159, 0x5560c018580d5d53}, - {0xe9d71b689dde71af, 0xaab8f01e6e10b4a7}, - {0x9226712162ab070d, 0xcab3961304ca70e9}, - {0xb6b00d69bb55c8d1, 0x3d607b97c5fd0d23}, - {0xe45c10c42a2b3b05, 0x8cb89a7db77c506b}, - {0x8eb98a7a9a5b04e3, 0x77f3608e92adb243}, - {0xb267ed1940f1c61c, 0x55f038b237591ed4}, - {0xdf01e85f912e37a3, 0x6b6c46dec52f6689}, - {0x8b61313bbabce2c6, 0x2323ac4b3b3da016}, - {0xae397d8aa96c1b77, 0xabec975e0a0d081b}, - {0xd9c7dced53c72255, 0x96e7bd358c904a22}, - {0x881cea14545c7575, 0x7e50d64177da2e55}, - {0xaa242499697392d2, 0xdde50bd1d5d0b9ea}, - {0xd4ad2dbfc3d07787, 0x955e4ec64b44e865}, - {0x84ec3c97da624ab4, 0xbd5af13bef0b113f}, - {0xa6274bbdd0fadd61, 0xecb1ad8aeacdd58f}, - {0xcfb11ead453994ba, 0x67de18eda5814af3}, - {0x81ceb32c4b43fcf4, 0x80eacf948770ced8}, - {0xa2425ff75e14fc31, 0xa1258379a94d028e}, - {0xcad2f7f5359a3b3e, 0x096ee45813a04331}, - {0xfd87b5f28300ca0d, 0x8bca9d6e188853fd}, - {0x9e74d1b791e07e48, 0x775ea264cf55347e}, - {0xc612062576589dda, 0x95364afe032a819e}, - {0xf79687aed3eec551, 0x3a83ddbd83f52205}, - {0x9abe14cd44753b52, 0xc4926a9672793543}, - {0xc16d9a0095928a27, 0x75b7053c0f178294}, - {0xf1c90080baf72cb1, 0x5324c68b12dd6339}, - {0x971da05074da7bee, 0xd3f6fc16ebca5e04}, - {0xbce5086492111aea, 0x88f4bb1ca6bcf585}, - {0xec1e4a7db69561a5, 0x2b31e9e3d06c32e6}, - {0x9392ee8e921d5d07, 0x3aff322e62439fd0}, - {0xb877aa3236a4b449, 0x09befeb9fad487c3}, - {0xe69594bec44de15b, 0x4c2ebe687989a9b4}, - {0x901d7cf73ab0acd9, 0x0f9d37014bf60a11}, - {0xb424dc35095cd80f, 0x538484c19ef38c95}, - {0xe12e13424bb40e13, 0x2865a5f206b06fba}, - {0x8cbccc096f5088cb, 0xf93f87b7442e45d4}, - {0xafebff0bcb24aafe, 0xf78f69a51539d749}, - {0xdbe6fecebdedd5be, 0xb573440e5a884d1c}, - {0x89705f4136b4a597, 0x31680a88f8953031}, - {0xabcc77118461cefc, 0xfdc20d2b36ba7c3e}, - {0xd6bf94d5e57a42bc, 0x3d32907604691b4d}, - {0x8637bd05af6c69b5, 0xa63f9a49c2c1b110}, - {0xa7c5ac471b478423, 0x0fcf80dc33721d54}, - {0xd1b71758e219652b, 0xd3c36113404ea4a9}, - {0x83126e978d4fdf3b, 0x645a1cac083126ea}, - {0xa3d70a3d70a3d70a, 0x3d70a3d70a3d70a4}, - {0xcccccccccccccccc, 0xcccccccccccccccd}, - {0x8000000000000000, 0x0000000000000000}, - {0xa000000000000000, 0x0000000000000000}, - {0xc800000000000000, 0x0000000000000000}, - {0xfa00000000000000, 0x0000000000000000}, - {0x9c40000000000000, 0x0000000000000000}, - {0xc350000000000000, 0x0000000000000000}, - {0xf424000000000000, 0x0000000000000000}, - {0x9896800000000000, 0x0000000000000000}, - {0xbebc200000000000, 0x0000000000000000}, - {0xee6b280000000000, 0x0000000000000000}, - {0x9502f90000000000, 0x0000000000000000}, - {0xba43b74000000000, 0x0000000000000000}, - {0xe8d4a51000000000, 0x0000000000000000}, - {0x9184e72a00000000, 0x0000000000000000}, - {0xb5e620f480000000, 0x0000000000000000}, - {0xe35fa931a0000000, 0x0000000000000000}, - {0x8e1bc9bf04000000, 0x0000000000000000}, - {0xb1a2bc2ec5000000, 0x0000000000000000}, - {0xde0b6b3a76400000, 0x0000000000000000}, - {0x8ac7230489e80000, 0x0000000000000000}, - {0xad78ebc5ac620000, 0x0000000000000000}, - {0xd8d726b7177a8000, 0x0000000000000000}, - {0x878678326eac9000, 0x0000000000000000}, - {0xa968163f0a57b400, 0x0000000000000000}, - {0xd3c21bcecceda100, 0x0000000000000000}, - {0x84595161401484a0, 0x0000000000000000}, - {0xa56fa5b99019a5c8, 0x0000000000000000}, - {0xcecb8f27f4200f3a, 0x0000000000000000}, - {0x813f3978f8940984, 0x4000000000000000}, - {0xa18f07d736b90be5, 0x5000000000000000}, - {0xc9f2c9cd04674ede, 0xa400000000000000}, - {0xfc6f7c4045812296, 0x4d00000000000000}, - {0x9dc5ada82b70b59d, 0xf020000000000000}, - {0xc5371912364ce305, 0x6c28000000000000}, - {0xf684df56c3e01bc6, 0xc732000000000000}, - {0x9a130b963a6c115c, 0x3c7f400000000000}, - {0xc097ce7bc90715b3, 0x4b9f100000000000}, - {0xf0bdc21abb48db20, 0x1e86d40000000000}, - {0x96769950b50d88f4, 0x1314448000000000}, - {0xbc143fa4e250eb31, 0x17d955a000000000}, - {0xeb194f8e1ae525fd, 0x5dcfab0800000000}, - {0x92efd1b8d0cf37be, 0x5aa1cae500000000}, - {0xb7abc627050305ad, 0xf14a3d9e40000000}, - {0xe596b7b0c643c719, 0x6d9ccd05d0000000}, - {0x8f7e32ce7bea5c6f, 0xe4820023a2000000}, - {0xb35dbf821ae4f38b, 0xdda2802c8a800000}, - {0xe0352f62a19e306e, 0xd50b2037ad200000}, - {0x8c213d9da502de45, 0x4526f422cc340000}, - {0xaf298d050e4395d6, 0x9670b12b7f410000}, - {0xdaf3f04651d47b4c, 0x3c0cdd765f114000}, - {0x88d8762bf324cd0f, 0xa5880a69fb6ac800}, - {0xab0e93b6efee0053, 0x8eea0d047a457a00}, - {0xd5d238a4abe98068, 0x72a4904598d6d880}, - {0x85a36366eb71f041, 0x47a6da2b7f864750}, - {0xa70c3c40a64e6c51, 0x999090b65f67d924}, - {0xd0cf4b50cfe20765, 0xfff4b4e3f741cf6d}, - {0x82818f1281ed449f, 0xbff8f10e7a8921a5}, - {0xa321f2d7226895c7, 0xaff72d52192b6a0e}, - {0xcbea6f8ceb02bb39, 0x9bf4f8a69f764491}, - {0xfee50b7025c36a08, 0x02f236d04753d5b5}, - {0x9f4f2726179a2245, 0x01d762422c946591}, - {0xc722f0ef9d80aad6, 0x424d3ad2b7b97ef6}, - {0xf8ebad2b84e0d58b, 0xd2e0898765a7deb3}, - {0x9b934c3b330c8577, 0x63cc55f49f88eb30}, - {0xc2781f49ffcfa6d5, 0x3cbf6b71c76b25fc}, - {0xf316271c7fc3908a, 0x8bef464e3945ef7b}, - {0x97edd871cfda3a56, 0x97758bf0e3cbb5ad}, - {0xbde94e8e43d0c8ec, 0x3d52eeed1cbea318}, - {0xed63a231d4c4fb27, 0x4ca7aaa863ee4bde}, - {0x945e455f24fb1cf8, 0x8fe8caa93e74ef6b}, - {0xb975d6b6ee39e436, 0xb3e2fd538e122b45}, - {0xe7d34c64a9c85d44, 0x60dbbca87196b617}, - {0x90e40fbeea1d3a4a, 0xbc8955e946fe31ce}, - {0xb51d13aea4a488dd, 0x6babab6398bdbe42}, - {0xe264589a4dcdab14, 0xc696963c7eed2dd2}, - {0x8d7eb76070a08aec, 0xfc1e1de5cf543ca3}, - {0xb0de65388cc8ada8, 0x3b25a55f43294bcc}, - {0xdd15fe86affad912, 0x49ef0eb713f39ebf}, - {0x8a2dbf142dfcc7ab, 0x6e3569326c784338}, - {0xacb92ed9397bf996, 0x49c2c37f07965405}, - {0xd7e77a8f87daf7fb, 0xdc33745ec97be907}, - {0x86f0ac99b4e8dafd, 0x69a028bb3ded71a4}, - {0xa8acd7c0222311bc, 0xc40832ea0d68ce0d}, - {0xd2d80db02aabd62b, 0xf50a3fa490c30191}, - {0x83c7088e1aab65db, 0x792667c6da79e0fb}, - {0xa4b8cab1a1563f52, 0x577001b891185939}, - {0xcde6fd5e09abcf26, 0xed4c0226b55e6f87}, - {0x80b05e5ac60b6178, 0x544f8158315b05b5}, - {0xa0dc75f1778e39d6, 0x696361ae3db1c722}, - {0xc913936dd571c84c, 0x03bc3a19cd1e38ea}, - {0xfb5878494ace3a5f, 0x04ab48a04065c724}, - {0x9d174b2dcec0e47b, 0x62eb0d64283f9c77}, - {0xc45d1df942711d9a, 0x3ba5d0bd324f8395}, - {0xf5746577930d6500, 0xca8f44ec7ee3647a}, - {0x9968bf6abbe85f20, 0x7e998b13cf4e1ecc}, - {0xbfc2ef456ae276e8, 0x9e3fedd8c321a67f}, - {0xefb3ab16c59b14a2, 0xc5cfe94ef3ea101f}, - {0x95d04aee3b80ece5, 0xbba1f1d158724a13}, - {0xbb445da9ca61281f, 0x2a8a6e45ae8edc98}, - {0xea1575143cf97226, 0xf52d09d71a3293be}, - {0x924d692ca61be758, 0x593c2626705f9c57}, - {0xb6e0c377cfa2e12e, 0x6f8b2fb00c77836d}, - {0xe498f455c38b997a, 0x0b6dfb9c0f956448}, - {0x8edf98b59a373fec, 0x4724bd4189bd5ead}, - {0xb2977ee300c50fe7, 0x58edec91ec2cb658}, - {0xdf3d5e9bc0f653e1, 0x2f2967b66737e3ee}, - {0x8b865b215899f46c, 0xbd79e0d20082ee75}, - {0xae67f1e9aec07187, 0xecd8590680a3aa12}, - {0xda01ee641a708de9, 0xe80e6f4820cc9496}, - {0x884134fe908658b2, 0x3109058d147fdcde}, - {0xaa51823e34a7eede, 0xbd4b46f0599fd416}, - {0xd4e5e2cdc1d1ea96, 0x6c9e18ac7007c91b}, - {0x850fadc09923329e, 0x03e2cf6bc604ddb1}, - {0xa6539930bf6bff45, 0x84db8346b786151d}, - {0xcfe87f7cef46ff16, 0xe612641865679a64}, - {0x81f14fae158c5f6e, 0x4fcb7e8f3f60c07f}, - {0xa26da3999aef7749, 0xe3be5e330f38f09e}, - {0xcb090c8001ab551c, 0x5cadf5bfd3072cc6}, - {0xfdcb4fa002162a63, 0x73d9732fc7c8f7f7}, - {0x9e9f11c4014dda7e, 0x2867e7fddcdd9afb}, - {0xc646d63501a1511d, 0xb281e1fd541501b9}, - {0xf7d88bc24209a565, 0x1f225a7ca91a4227}, - {0x9ae757596946075f, 0x3375788de9b06959}, - {0xc1a12d2fc3978937, 0x0052d6b1641c83af}, - {0xf209787bb47d6b84, 0xc0678c5dbd23a49b}, - {0x9745eb4d50ce6332, 0xf840b7ba963646e1}, - {0xbd176620a501fbff, 0xb650e5a93bc3d899}, - {0xec5d3fa8ce427aff, 0xa3e51f138ab4cebf}, - {0x93ba47c980e98cdf, 0xc66f336c36b10138}, - {0xb8a8d9bbe123f017, 0xb80b0047445d4185}, - {0xe6d3102ad96cec1d, 0xa60dc059157491e6}, - {0x9043ea1ac7e41392, 0x87c89837ad68db30}, - {0xb454e4a179dd1877, 0x29babe4598c311fc}, - {0xe16a1dc9d8545e94, 0xf4296dd6fef3d67b}, - {0x8ce2529e2734bb1d, 0x1899e4a65f58660d}, - {0xb01ae745b101e9e4, 0x5ec05dcff72e7f90}, - {0xdc21a1171d42645d, 0x76707543f4fa1f74}, - {0x899504ae72497eba, 0x6a06494a791c53a9}, - {0xabfa45da0edbde69, 0x0487db9d17636893}, - {0xd6f8d7509292d603, 0x45a9d2845d3c42b7}, - {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3}, - {0xa7f26836f282b732, 0x8e6cac7768d7141f}, - {0xd1ef0244af2364ff, 0x3207d795430cd927}, - {0x8335616aed761f1f, 0x7f44e6bd49e807b9}, - {0xa402b9c5a8d3a6e7, 0x5f16206c9c6209a7}, - {0xcd036837130890a1, 0x36dba887c37a8c10}, - {0x802221226be55a64, 0xc2494954da2c978a}, - {0xa02aa96b06deb0fd, 0xf2db9baa10b7bd6d}, - {0xc83553c5c8965d3d, 0x6f92829494e5acc8}, - {0xfa42a8b73abbf48c, 0xcb772339ba1f17fa}, - {0x9c69a97284b578d7, 0xff2a760414536efc}, - {0xc38413cf25e2d70d, 0xfef5138519684abb}, - {0xf46518c2ef5b8cd1, 0x7eb258665fc25d6a}, - {0x98bf2f79d5993802, 0xef2f773ffbd97a62}, - {0xbeeefb584aff8603, 0xaafb550ffacfd8fb}, - {0xeeaaba2e5dbf6784, 0x95ba2a53f983cf39}, - {0x952ab45cfa97a0b2, 0xdd945a747bf26184}, - {0xba756174393d88df, 0x94f971119aeef9e5}, - {0xe912b9d1478ceb17, 0x7a37cd5601aab85e}, - {0x91abb422ccb812ee, 0xac62e055c10ab33b}, - {0xb616a12b7fe617aa, 0x577b986b314d600a}, - {0xe39c49765fdf9d94, 0xed5a7e85fda0b80c}, - {0x8e41ade9fbebc27d, 0x14588f13be847308}, - {0xb1d219647ae6b31c, 0x596eb2d8ae258fc9}, - {0xde469fbd99a05fe3, 0x6fca5f8ed9aef3bc}, - {0x8aec23d680043bee, 0x25de7bb9480d5855}, - {0xada72ccc20054ae9, 0xaf561aa79a10ae6b}, - {0xd910f7ff28069da4, 0x1b2ba1518094da05}, - {0x87aa9aff79042286, 0x90fb44d2f05d0843}, - {0xa99541bf57452b28, 0x353a1607ac744a54}, - {0xd3fa922f2d1675f2, 0x42889b8997915ce9}, - {0x847c9b5d7c2e09b7, 0x69956135febada12}, - {0xa59bc234db398c25, 0x43fab9837e699096}, - {0xcf02b2c21207ef2e, 0x94f967e45e03f4bc}, - {0x8161afb94b44f57d, 0x1d1be0eebac278f6}, - {0xa1ba1ba79e1632dc, 0x6462d92a69731733}, - {0xca28a291859bbf93, 0x7d7b8f7503cfdcff}, - {0xfcb2cb35e702af78, 0x5cda735244c3d43f}, - {0x9defbf01b061adab, 0x3a0888136afa64a8}, - {0xc56baec21c7a1916, 0x088aaa1845b8fdd1}, - {0xf6c69a72a3989f5b, 0x8aad549e57273d46}, - {0x9a3c2087a63f6399, 0x36ac54e2f678864c}, - {0xc0cb28a98fcf3c7f, 0x84576a1bb416a7de}, - {0xf0fdf2d3f3c30b9f, 0x656d44a2a11c51d6}, - {0x969eb7c47859e743, 0x9f644ae5a4b1b326}, - {0xbc4665b596706114, 0x873d5d9f0dde1fef}, - {0xeb57ff22fc0c7959, 0xa90cb506d155a7eb}, - {0x9316ff75dd87cbd8, 0x09a7f12442d588f3}, - {0xb7dcbf5354e9bece, 0x0c11ed6d538aeb30}, - {0xe5d3ef282a242e81, 0x8f1668c8a86da5fb}, - {0x8fa475791a569d10, 0xf96e017d694487bd}, - {0xb38d92d760ec4455, 0x37c981dcc395a9ad}, - {0xe070f78d3927556a, 0x85bbe253f47b1418}, - {0x8c469ab843b89562, 0x93956d7478ccec8f}, - {0xaf58416654a6babb, 0x387ac8d1970027b3}, - {0xdb2e51bfe9d0696a, 0x06997b05fcc0319f}, - {0x88fcf317f22241e2, 0x441fece3bdf81f04}, - {0xab3c2fddeeaad25a, 0xd527e81cad7626c4}, - {0xd60b3bd56a5586f1, 0x8a71e223d8d3b075}, - {0x85c7056562757456, 0xf6872d5667844e4a}, - {0xa738c6bebb12d16c, 0xb428f8ac016561dc}, - {0xd106f86e69d785c7, 0xe13336d701beba53}, - {0x82a45b450226b39c, 0xecc0024661173474}, - {0xa34d721642b06084, 0x27f002d7f95d0191}, - {0xcc20ce9bd35c78a5, 0x31ec038df7b441f5}, - {0xff290242c83396ce, 0x7e67047175a15272}, - {0x9f79a169bd203e41, 0x0f0062c6e984d387}, - {0xc75809c42c684dd1, 0x52c07b78a3e60869}, - {0xf92e0c3537826145, 0xa7709a56ccdf8a83}, - {0x9bbcc7a142b17ccb, 0x88a66076400bb692}, - {0xc2abf989935ddbfe, 0x6acff893d00ea436}, - {0xf356f7ebf83552fe, 0x0583f6b8c4124d44}, - {0x98165af37b2153de, 0xc3727a337a8b704b}, - {0xbe1bf1b059e9a8d6, 0x744f18c0592e4c5d}, - {0xeda2ee1c7064130c, 0x1162def06f79df74}, - {0x9485d4d1c63e8be7, 0x8addcb5645ac2ba9}, - {0xb9a74a0637ce2ee1, 0x6d953e2bd7173693}, - {0xe8111c87c5c1ba99, 0xc8fa8db6ccdd0438}, - {0x910ab1d4db9914a0, 0x1d9c9892400a22a3}, - {0xb54d5e4a127f59c8, 0x2503beb6d00cab4c}, - {0xe2a0b5dc971f303a, 0x2e44ae64840fd61e}, - {0x8da471a9de737e24, 0x5ceaecfed289e5d3}, - {0xb10d8e1456105dad, 0x7425a83e872c5f48}, - {0xdd50f1996b947518, 0xd12f124e28f7771a}, - {0x8a5296ffe33cc92f, 0x82bd6b70d99aaa70}, - {0xace73cbfdc0bfb7b, 0x636cc64d1001550c}, - {0xd8210befd30efa5a, 0x3c47f7e05401aa4f}, - {0x8714a775e3e95c78, 0x65acfaec34810a72}, - {0xa8d9d1535ce3b396, 0x7f1839a741a14d0e}, - {0xd31045a8341ca07c, 0x1ede48111209a051}, - {0x83ea2b892091e44d, 0x934aed0aab460433}, - {0xa4e4b66b68b65d60, 0xf81da84d56178540}, - {0xce1de40642e3f4b9, 0x36251260ab9d668f}, - {0x80d2ae83e9ce78f3, 0xc1d72b7c6b42601a}, - {0xa1075a24e4421730, 0xb24cf65b8612f820}, - {0xc94930ae1d529cfc, 0xdee033f26797b628}, - {0xfb9b7cd9a4a7443c, 0x169840ef017da3b2}, - {0x9d412e0806e88aa5, 0x8e1f289560ee864f}, - {0xc491798a08a2ad4e, 0xf1a6f2bab92a27e3}, - {0xf5b5d7ec8acb58a2, 0xae10af696774b1dc}, - {0x9991a6f3d6bf1765, 0xacca6da1e0a8ef2a}, - {0xbff610b0cc6edd3f, 0x17fd090a58d32af4}, - {0xeff394dcff8a948e, 0xddfc4b4cef07f5b1}, - {0x95f83d0a1fb69cd9, 0x4abdaf101564f98f}, - {0xbb764c4ca7a4440f, 0x9d6d1ad41abe37f2}, - {0xea53df5fd18d5513, 0x84c86189216dc5ee}, - {0x92746b9be2f8552c, 0x32fd3cf5b4e49bb5}, - {0xb7118682dbb66a77, 0x3fbc8c33221dc2a2}, - {0xe4d5e82392a40515, 0x0fabaf3feaa5334b}, - {0x8f05b1163ba6832d, 0x29cb4d87f2a7400f}, - {0xb2c71d5bca9023f8, 0x743e20e9ef511013}, - {0xdf78e4b2bd342cf6, 0x914da9246b255417}, - {0x8bab8eefb6409c1a, 0x1ad089b6c2f7548f}, - {0xae9672aba3d0c320, 0xa184ac2473b529b2}, - {0xda3c0f568cc4f3e8, 0xc9e5d72d90a2741f}, - {0x8865899617fb1871, 0x7e2fa67c7a658893}, - {0xaa7eebfb9df9de8d, 0xddbb901b98feeab8}, - {0xd51ea6fa85785631, 0x552a74227f3ea566}, - {0x8533285c936b35de, 0xd53a88958f872760}, - {0xa67ff273b8460356, 0x8a892abaf368f138}, - {0xd01fef10a657842c, 0x2d2b7569b0432d86}, - {0x8213f56a67f6b29b, 0x9c3b29620e29fc74}, - {0xa298f2c501f45f42, 0x8349f3ba91b47b90}, - {0xcb3f2f7642717713, 0x241c70a936219a74}, - {0xfe0efb53d30dd4d7, 0xed238cd383aa0111}, - {0x9ec95d1463e8a506, 0xf4363804324a40ab}, - {0xc67bb4597ce2ce48, 0xb143c6053edcd0d6}, - {0xf81aa16fdc1b81da, 0xdd94b7868e94050b}, - {0x9b10a4e5e9913128, 0xca7cf2b4191c8327}, - {0xc1d4ce1f63f57d72, 0xfd1c2f611f63a3f1}, - {0xf24a01a73cf2dccf, 0xbc633b39673c8ced}, - {0x976e41088617ca01, 0xd5be0503e085d814}, - {0xbd49d14aa79dbc82, 0x4b2d8644d8a74e19}, - {0xec9c459d51852ba2, 0xddf8e7d60ed1219f}, - {0x93e1ab8252f33b45, 0xcabb90e5c942b504}, - {0xb8da1662e7b00a17, 0x3d6a751f3b936244}, - {0xe7109bfba19c0c9d, 0x0cc512670a783ad5}, - {0x906a617d450187e2, 0x27fb2b80668b24c6}, - {0xb484f9dc9641e9da, 0xb1f9f660802dedf7}, - {0xe1a63853bbd26451, 0x5e7873f8a0396974}, - {0x8d07e33455637eb2, 0xdb0b487b6423e1e9}, - {0xb049dc016abc5e5f, 0x91ce1a9a3d2cda63}, - {0xdc5c5301c56b75f7, 0x7641a140cc7810fc}, - {0x89b9b3e11b6329ba, 0xa9e904c87fcb0a9e}, - {0xac2820d9623bf429, 0x546345fa9fbdcd45}, - {0xd732290fbacaf133, 0xa97c177947ad4096}, - {0x867f59a9d4bed6c0, 0x49ed8eabcccc485e}, - {0xa81f301449ee8c70, 0x5c68f256bfff5a75}, - {0xd226fc195c6a2f8c, 0x73832eec6fff3112}, - {0x83585d8fd9c25db7, 0xc831fd53c5ff7eac}, - {0xa42e74f3d032f525, 0xba3e7ca8b77f5e56}, - {0xcd3a1230c43fb26f, 0x28ce1bd2e55f35ec}, - {0x80444b5e7aa7cf85, 0x7980d163cf5b81b4}, - {0xa0555e361951c366, 0xd7e105bcc3326220}, - {0xc86ab5c39fa63440, 0x8dd9472bf3fefaa8}, - {0xfa856334878fc150, 0xb14f98f6f0feb952}, - {0x9c935e00d4b9d8d2, 0x6ed1bf9a569f33d4}, - {0xc3b8358109e84f07, 0x0a862f80ec4700c9}, - {0xf4a642e14c6262c8, 0xcd27bb612758c0fb}, - {0x98e7e9cccfbd7dbd, 0x8038d51cb897789d}, - {0xbf21e44003acdd2c, 0xe0470a63e6bd56c4}, - {0xeeea5d5004981478, 0x1858ccfce06cac75}, - {0x95527a5202df0ccb, 0x0f37801e0c43ebc9}, - {0xbaa718e68396cffd, 0xd30560258f54e6bb}, - {0xe950df20247c83fd, 0x47c6b82ef32a206a}, - {0x91d28b7416cdd27e, 0x4cdc331d57fa5442}, - {0xb6472e511c81471d, 0xe0133fe4adf8e953}, - {0xe3d8f9e563a198e5, 0x58180fddd97723a7}, - {0x8e679c2f5e44ff8f, 0x570f09eaa7ea7649}, - {0xb201833b35d63f73, 0x2cd2cc6551e513db}, - {0xde81e40a034bcf4f, 0xf8077f7ea65e58d2}, - {0x8b112e86420f6191, 0xfb04afaf27faf783}, - {0xadd57a27d29339f6, 0x79c5db9af1f9b564}, - {0xd94ad8b1c7380874, 0x18375281ae7822bd}, - {0x87cec76f1c830548, 0x8f2293910d0b15b6}, - {0xa9c2794ae3a3c69a, 0xb2eb3875504ddb23}, - {0xd433179d9c8cb841, 0x5fa60692a46151ec}, - {0x849feec281d7f328, 0xdbc7c41ba6bcd334}, - {0xa5c7ea73224deff3, 0x12b9b522906c0801}, - {0xcf39e50feae16bef, 0xd768226b34870a01}, - {0x81842f29f2cce375, 0xe6a1158300d46641}, - {0xa1e53af46f801c53, 0x60495ae3c1097fd1}, - {0xca5e89b18b602368, 0x385bb19cb14bdfc5}, - {0xfcf62c1dee382c42, 0x46729e03dd9ed7b6}, - {0x9e19db92b4e31ba9, 0x6c07a2c26a8346d2}, - {0xc5a05277621be293, 0xc7098b7305241886}, - {0xf70867153aa2db38, 0xb8cbee4fc66d1ea8}, - {0x9a65406d44a5c903, 0x737f74f1dc043329}, - {0xc0fe908895cf3b44, 0x505f522e53053ff3}, - {0xf13e34aabb430a15, 0x647726b9e7c68ff0}, - {0x96c6e0eab509e64d, 0x5eca783430dc19f6}, - {0xbc789925624c5fe0, 0xb67d16413d132073}, - {0xeb96bf6ebadf77d8, 0xe41c5bd18c57e890}, - {0x933e37a534cbaae7, 0x8e91b962f7b6f15a}, - {0xb80dc58e81fe95a1, 0x723627bbb5a4adb1}, - {0xe61136f2227e3b09, 0xcec3b1aaa30dd91d}, - {0x8fcac257558ee4e6, 0x213a4f0aa5e8a7b2}, - {0xb3bd72ed2af29e1f, 0xa988e2cd4f62d19e}, - {0xe0accfa875af45a7, 0x93eb1b80a33b8606}, - {0x8c6c01c9498d8b88, 0xbc72f130660533c4}, - {0xaf87023b9bf0ee6a, 0xeb8fad7c7f8680b5}, - {0xdb68c2ca82ed2a05, 0xa67398db9f6820e2}, -#else - {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b}, - {0xce5d73ff402d98e3, 0xfb0a3d212dc81290}, - {0xa6b34ad8c9dfc06f, 0xf42faa48c0ea481f}, - {0x86a8d39ef77164bc, 0xae5dff9c02033198}, - {0xd98ddaee19068c76, 0x3badd624dd9b0958}, - {0xafbd2350644eeacf, 0xe5d1929ef90898fb}, - {0x8df5efabc5979c8f, 0xca8d3ffa1ef463c2}, - {0xe55990879ddcaabd, 0xcc420a6a101d0516}, - {0xb94470938fa89bce, 0xf808e40e8d5b3e6a}, - {0x95a8637627989aad, 0xdde7001379a44aa9}, - {0xf1c90080baf72cb1, 0x5324c68b12dd6339}, - {0xc350000000000000, 0x0000000000000000}, - {0x9dc5ada82b70b59d, 0xf020000000000000}, - {0xfee50b7025c36a08, 0x02f236d04753d5b5}, - {0xcde6fd5e09abcf26, 0xed4c0226b55e6f87}, - {0xa6539930bf6bff45, 0x84db8346b786151d}, - {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3}, - {0xd910f7ff28069da4, 0x1b2ba1518094da05}, - {0xaf58416654a6babb, 0x387ac8d1970027b3}, - {0x8da471a9de737e24, 0x5ceaecfed289e5d3}, - {0xe4d5e82392a40515, 0x0fabaf3feaa5334b}, - {0xb8da1662e7b00a17, 0x3d6a751f3b936244}, - {0x95527a5202df0ccb, 0x0f37801e0c43ebc9}, - {0xf13e34aabb430a15, 0x647726b9e7c68ff0} -#endif - }; - -#if FMT_USE_FULL_CACHE_DRAGONBOX - return pow10_significands[k - float_info::min_k]; -#else - static constexpr const uint64_t powers_of_5_64[] = { - 0x0000000000000001, 0x0000000000000005, 0x0000000000000019, - 0x000000000000007d, 0x0000000000000271, 0x0000000000000c35, - 0x0000000000003d09, 0x000000000001312d, 0x000000000005f5e1, - 0x00000000001dcd65, 0x00000000009502f9, 0x0000000002e90edd, - 0x000000000e8d4a51, 0x0000000048c27395, 0x000000016bcc41e9, - 0x000000071afd498d, 0x0000002386f26fc1, 0x000000b1a2bc2ec5, - 0x000003782dace9d9, 0x00001158e460913d, 0x000056bc75e2d631, - 0x0001b1ae4d6e2ef5, 0x000878678326eac9, 0x002a5a058fc295ed, - 0x00d3c21bcecceda1, 0x0422ca8b0a00a425, 0x14adf4b7320334b9}; - - static const int compression_ratio = 27; - - // Compute base index. - int cache_index = (k - float_info::min_k) / compression_ratio; - int kb = cache_index * compression_ratio + float_info::min_k; - int offset = k - kb; - - // Get base cache. - uint128_fallback base_cache = pow10_significands[cache_index]; - if (offset == 0) return base_cache; - - // Compute the required amount of bit-shift. - int alpha = floor_log2_pow10(kb + offset) - floor_log2_pow10(kb) - offset; - FMT_ASSERT(alpha > 0 && alpha < 64, "shifting error detected"); - - // Try to recover the real cache. - uint64_t pow5 = powers_of_5_64[offset]; - uint128_fallback recovered_cache = umul128(base_cache.high(), pow5); - uint128_fallback middle_low = umul128(base_cache.low(), pow5); - - recovered_cache += middle_low.high(); - - uint64_t high_to_middle = recovered_cache.high() << (64 - alpha); - uint64_t middle_to_low = recovered_cache.low() << (64 - alpha); - - recovered_cache = - uint128_fallback{(recovered_cache.low() >> alpha) | high_to_middle, - ((middle_low.low() >> alpha) | middle_to_low)}; - FMT_ASSERT(recovered_cache.low() + 1 != 0, ""); - return {recovered_cache.high(), recovered_cache.low() + 1}; -#endif - } - - struct compute_mul_result { - carrier_uint result; - bool is_integer; - }; - struct compute_mul_parity_result { - bool parity; - bool is_integer; - }; - - static auto compute_mul(carrier_uint u, - const cache_entry_type& cache) noexcept - -> compute_mul_result { - auto r = umul192_upper128(u, cache); - return {r.high(), r.low() == 0}; - } - - static auto compute_delta(cache_entry_type const& cache, int beta) noexcept - -> uint32_t { - return static_cast(cache.high() >> (64 - 1 - beta)); - } - - static auto compute_mul_parity(carrier_uint two_f, - const cache_entry_type& cache, - int beta) noexcept - -> compute_mul_parity_result { - FMT_ASSERT(beta >= 1, ""); - FMT_ASSERT(beta < 64, ""); - - auto r = umul192_lower128(two_f, cache); - return {((r.high() >> (64 - beta)) & 1) != 0, - ((r.high() << beta) | (r.low() >> (64 - beta))) == 0}; - } - - static auto compute_left_endpoint_for_shorter_interval_case( - const cache_entry_type& cache, int beta) noexcept -> carrier_uint { - return (cache.high() - - (cache.high() >> (num_significand_bits() + 2))) >> - (64 - num_significand_bits() - 1 - beta); - } - - static auto compute_right_endpoint_for_shorter_interval_case( - const cache_entry_type& cache, int beta) noexcept -> carrier_uint { - return (cache.high() + - (cache.high() >> (num_significand_bits() + 1))) >> - (64 - num_significand_bits() - 1 - beta); - } - - static auto compute_round_up_for_shorter_interval_case( - const cache_entry_type& cache, int beta) noexcept -> carrier_uint { - return ((cache.high() >> (64 - num_significand_bits() - 2 - beta)) + - 1) / - 2; - } -}; - -FMT_FUNC auto get_cached_power(int k) noexcept -> uint128_fallback { - return cache_accessor::get_cached_power(k); -} - -// Various integer checks -template -auto is_left_endpoint_integer_shorter_interval(int exponent) noexcept -> bool { - const int case_shorter_interval_left_endpoint_lower_threshold = 2; - const int case_shorter_interval_left_endpoint_upper_threshold = 3; - return exponent >= case_shorter_interval_left_endpoint_lower_threshold && - exponent <= case_shorter_interval_left_endpoint_upper_threshold; -} - -// Remove trailing zeros from n and return the number of zeros removed (float) -FMT_INLINE int remove_trailing_zeros(uint32_t& n, int s = 0) noexcept { - FMT_ASSERT(n != 0, ""); - // Modular inverse of 5 (mod 2^32): (mod_inv_5 * 5) mod 2^32 = 1. - constexpr uint32_t mod_inv_5 = 0xcccccccd; - constexpr uint32_t mod_inv_25 = 0xc28f5c29; // = mod_inv_5 * mod_inv_5 - - while (true) { - auto q = rotr(n * mod_inv_25, 2); - if (q > max_value() / 100) break; - n = q; - s += 2; - } - auto q = rotr(n * mod_inv_5, 1); - if (q <= max_value() / 10) { - n = q; - s |= 1; - } - return s; -} - -// Removes trailing zeros and returns the number of zeros removed (double) -FMT_INLINE int remove_trailing_zeros(uint64_t& n) noexcept { - FMT_ASSERT(n != 0, ""); - - // This magic number is ceil(2^90 / 10^8). - constexpr uint64_t magic_number = 12379400392853802749ull; - auto nm = umul128(n, magic_number); - - // Is n is divisible by 10^8? - if ((nm.high() & ((1ull << (90 - 64)) - 1)) == 0 && nm.low() < magic_number) { - // If yes, work with the quotient... - auto n32 = static_cast(nm.high() >> (90 - 64)); - // ... and use the 32 bit variant of the function - int s = remove_trailing_zeros(n32, 8); - n = n32; - return s; - } - - // If n is not divisible by 10^8, work with n itself. - constexpr uint64_t mod_inv_5 = 0xcccccccccccccccd; - constexpr uint64_t mod_inv_25 = 0x8f5c28f5c28f5c29; // mod_inv_5 * mod_inv_5 - - int s = 0; - while (true) { - auto q = rotr(n * mod_inv_25, 2); - if (q > max_value() / 100) break; - n = q; - s += 2; - } - auto q = rotr(n * mod_inv_5, 1); - if (q <= max_value() / 10) { - n = q; - s |= 1; - } - - return s; -} - -// The main algorithm for shorter interval case -template -FMT_INLINE decimal_fp shorter_interval_case(int exponent) noexcept { - decimal_fp ret_value; - // Compute k and beta - const int minus_k = floor_log10_pow2_minus_log10_4_over_3(exponent); - const int beta = exponent + floor_log2_pow10(-minus_k); - - // Compute xi and zi - using cache_entry_type = typename cache_accessor::cache_entry_type; - const cache_entry_type cache = cache_accessor::get_cached_power(-minus_k); - - auto xi = cache_accessor::compute_left_endpoint_for_shorter_interval_case( - cache, beta); - auto zi = cache_accessor::compute_right_endpoint_for_shorter_interval_case( - cache, beta); - - // If the left endpoint is not an integer, increase it - if (!is_left_endpoint_integer_shorter_interval(exponent)) ++xi; - - // Try bigger divisor - ret_value.significand = zi / 10; - - // If succeed, remove trailing zeros if necessary and return - if (ret_value.significand * 10 >= xi) { - ret_value.exponent = minus_k + 1; - ret_value.exponent += remove_trailing_zeros(ret_value.significand); - return ret_value; - } - - // Otherwise, compute the round-up of y - ret_value.significand = - cache_accessor::compute_round_up_for_shorter_interval_case(cache, - beta); - ret_value.exponent = minus_k; - - // When tie occurs, choose one of them according to the rule - if (exponent >= float_info::shorter_interval_tie_lower_threshold && - exponent <= float_info::shorter_interval_tie_upper_threshold) { - ret_value.significand = ret_value.significand % 2 == 0 - ? ret_value.significand - : ret_value.significand - 1; - } else if (ret_value.significand < xi) { - ++ret_value.significand; - } - return ret_value; -} - -template auto to_decimal(T x) noexcept -> decimal_fp { - // Step 1: integer promotion & Schubfach multiplier calculation. - - using carrier_uint = typename float_info::carrier_uint; - using cache_entry_type = typename cache_accessor::cache_entry_type; - auto br = bit_cast(x); - - // Extract significand bits and exponent bits. - const carrier_uint significand_mask = - (static_cast(1) << num_significand_bits()) - 1; - carrier_uint significand = (br & significand_mask); - int exponent = - static_cast((br & exponent_mask()) >> num_significand_bits()); - - if (exponent != 0) { // Check if normal. - exponent -= exponent_bias() + num_significand_bits(); - - // Shorter interval case; proceed like Schubfach. - // In fact, when exponent == 1 and significand == 0, the interval is - // regular. However, it can be shown that the end-results are anyway same. - if (significand == 0) return shorter_interval_case(exponent); - - significand |= (static_cast(1) << num_significand_bits()); - } else { - // Subnormal case; the interval is always regular. - if (significand == 0) return {0, 0}; - exponent = - std::numeric_limits::min_exponent - num_significand_bits() - 1; - } - - const bool include_left_endpoint = (significand % 2 == 0); - const bool include_right_endpoint = include_left_endpoint; - - // Compute k and beta. - const int minus_k = floor_log10_pow2(exponent) - float_info::kappa; - const cache_entry_type cache = cache_accessor::get_cached_power(-minus_k); - const int beta = exponent + floor_log2_pow10(-minus_k); - - // Compute zi and deltai. - // 10^kappa <= deltai < 10^(kappa + 1) - const uint32_t deltai = cache_accessor::compute_delta(cache, beta); - const carrier_uint two_fc = significand << 1; - - // For the case of binary32, the result of integer check is not correct for - // 29711844 * 2^-82 - // = 6.1442653300000000008655037797566933477355632930994033813476... * 10^-18 - // and 29711844 * 2^-81 - // = 1.2288530660000000001731007559513386695471126586198806762695... * 10^-17, - // and they are the unique counterexamples. However, since 29711844 is even, - // this does not cause any problem for the endpoints calculations; it can only - // cause a problem when we need to perform integer check for the center. - // Fortunately, with these inputs, that branch is never executed, so we are - // fine. - const typename cache_accessor::compute_mul_result z_mul = - cache_accessor::compute_mul((two_fc | 1) << beta, cache); - - // Step 2: Try larger divisor; remove trailing zeros if necessary. - - // Using an upper bound on zi, we might be able to optimize the division - // better than the compiler; we are computing zi / big_divisor here. - decimal_fp ret_value; - ret_value.significand = divide_by_10_to_kappa_plus_1(z_mul.result); - uint32_t r = static_cast(z_mul.result - float_info::big_divisor * - ret_value.significand); - - if (r < deltai) { - // Exclude the right endpoint if necessary. - if (r == 0 && (z_mul.is_integer & !include_right_endpoint)) { - --ret_value.significand; - r = float_info::big_divisor; - goto small_divisor_case_label; - } - } else if (r > deltai) { - goto small_divisor_case_label; - } else { - // r == deltai; compare fractional parts. - const typename cache_accessor::compute_mul_parity_result x_mul = - cache_accessor::compute_mul_parity(two_fc - 1, cache, beta); - - if (!(x_mul.parity | (x_mul.is_integer & include_left_endpoint))) - goto small_divisor_case_label; - } - ret_value.exponent = minus_k + float_info::kappa + 1; - - // We may need to remove trailing zeros. - ret_value.exponent += remove_trailing_zeros(ret_value.significand); - return ret_value; - - // Step 3: Find the significand with the smaller divisor. - -small_divisor_case_label: - ret_value.significand *= 10; - ret_value.exponent = minus_k + float_info::kappa; - - uint32_t dist = r - (deltai / 2) + (float_info::small_divisor / 2); - const bool approx_y_parity = - ((dist ^ (float_info::small_divisor / 2)) & 1) != 0; - - // Is dist divisible by 10^kappa? - const bool divisible_by_small_divisor = - check_divisibility_and_divide_by_pow10::kappa>(dist); - - // Add dist / 10^kappa to the significand. - ret_value.significand += dist; - - if (!divisible_by_small_divisor) return ret_value; - - // Check z^(f) >= epsilon^(f). - // We have either yi == zi - epsiloni or yi == (zi - epsiloni) - 1, - // where yi == zi - epsiloni if and only if z^(f) >= epsilon^(f). - // Since there are only 2 possibilities, we only need to care about the - // parity. Also, zi and r should have the same parity since the divisor - // is an even number. - const auto y_mul = cache_accessor::compute_mul_parity(two_fc, cache, beta); - - // If z^(f) >= epsilon^(f), we might have a tie when z^(f) == epsilon^(f), - // or equivalently, when y is an integer. - if (y_mul.parity != approx_y_parity) - --ret_value.significand; - else if (y_mul.is_integer & (ret_value.significand % 2 != 0)) - --ret_value.significand; - return ret_value; -} -} // namespace dragonbox -} // namespace detail - -template <> struct formatter { - FMT_CONSTEXPR auto parse(format_parse_context& ctx) - -> format_parse_context::iterator { - return ctx.begin(); - } - - auto format(const detail::bigint& n, format_context& ctx) const - -> format_context::iterator { - auto out = ctx.out(); - bool first = true; - for (auto i = n.bigits_.size(); i > 0; --i) { - auto value = n.bigits_[i - 1u]; - if (first) { - out = fmt::format_to(out, FMT_STRING("{:x}"), value); - first = false; - continue; - } - out = fmt::format_to(out, FMT_STRING("{:08x}"), value); - } - if (n.exp_ > 0) - out = fmt::format_to(out, FMT_STRING("p{}"), - n.exp_ * detail::bigint::bigit_bits); - return out; - } -}; - -FMT_FUNC detail::utf8_to_utf16::utf8_to_utf16(string_view s) { - for_each_codepoint(s, [this](uint32_t cp, string_view) { - if (cp == invalid_code_point) FMT_THROW(std::runtime_error("invalid utf8")); - if (cp <= 0xFFFF) { - buffer_.push_back(static_cast(cp)); - } else { - cp -= 0x10000; - buffer_.push_back(static_cast(0xD800 + (cp >> 10))); - buffer_.push_back(static_cast(0xDC00 + (cp & 0x3FF))); - } - return true; - }); - buffer_.push_back(0); -} - -FMT_FUNC void format_system_error(detail::buffer& out, int error_code, - const char* message) noexcept { - FMT_TRY { - auto ec = std::error_code(error_code, std::generic_category()); - detail::write(appender(out), std::system_error(ec, message).what()); - return; - } - FMT_CATCH(...) {} - format_error_code(out, error_code, message); -} - -FMT_FUNC void report_system_error(int error_code, - const char* message) noexcept { - report_error(format_system_error, error_code, message); -} - -FMT_FUNC auto vformat(string_view fmt, format_args args) -> std::string { - // Don't optimize the "{}" case to keep the binary size small and because it - // can be better optimized in fmt::format anyway. - auto buffer = memory_buffer(); - detail::vformat_to(buffer, fmt, args); - return to_string(buffer); -} - -namespace detail { - -template struct span { - T* data; - size_t size; -}; - -#ifdef _WIN32 -inline void flockfile(FILE* f) { _lock_file(f); } -inline void funlockfile(FILE* f) { _unlock_file(f); } -inline int getc_unlocked(FILE* f) { return _fgetc_nolock(f); } -#endif - -// A FILE wrapper. F is FILE defined as a template parameter to make system API -// detection work. -template class file_base { - public: - F* file_; - - public: - file_base(F* file) : file_(file) {} - operator F*() const { return file_; } - - // Reads a code unit from the stream. - auto get() -> int { - int result = getc_unlocked(file_); - if (result == EOF && ferror(file_) != 0) - FMT_THROW(system_error(errno, FMT_STRING("getc failed"))); - return result; - } - - // Puts the code unit back into the stream buffer. - void unget(char c) { - if (ungetc(c, file_) == EOF) - FMT_THROW(system_error(errno, FMT_STRING("ungetc failed"))); - } - - void flush() { fflush(this->file_); } -}; - -// A FILE wrapper for glibc. -template class glibc_file : public file_base { - private: - enum { - line_buffered = 0x200, // _IO_LINE_BUF - unbuffered = 2 // _IO_UNBUFFERED - }; - - public: - using file_base::file_base; - - auto is_buffered() const -> bool { - return (this->file_->_flags & unbuffered) == 0; - } - - void init_buffer() { - if (this->file_->_IO_write_ptr) return; - // Force buffer initialization by placing and removing a char in a buffer. - putc_unlocked(0, this->file_); - --this->file_->_IO_write_ptr; - } - - // Returns the file's read buffer. - auto get_read_buffer() const -> span { - auto ptr = this->file_->_IO_read_ptr; - return {ptr, to_unsigned(this->file_->_IO_read_end - ptr)}; - } - - // Returns the file's write buffer. - auto get_write_buffer() const -> span { - auto ptr = this->file_->_IO_write_ptr; - return {ptr, to_unsigned(this->file_->_IO_buf_end - ptr)}; - } - - void advance_write_buffer(size_t size) { this->file_->_IO_write_ptr += size; } - - bool needs_flush() const { - if ((this->file_->_flags & line_buffered) == 0) return false; - char* end = this->file_->_IO_write_end; - return memchr(end, '\n', to_unsigned(this->file_->_IO_write_ptr - end)); - } - - void flush() { fflush_unlocked(this->file_); } -}; - -// A FILE wrapper for Apple's libc. -template class apple_file : public file_base { - private: - enum { - line_buffered = 1, // __SNBF - unbuffered = 2 // __SLBF - }; - - public: - using file_base::file_base; - - auto is_buffered() const -> bool { - return (this->file_->_flags & unbuffered) == 0; - } - - void init_buffer() { - if (this->file_->_p) return; - // Force buffer initialization by placing and removing a char in a buffer. - putc_unlocked(0, this->file_); - --this->file_->_p; - ++this->file_->_w; - } - - auto get_read_buffer() const -> span { - return {reinterpret_cast(this->file_->_p), - to_unsigned(this->file_->_r)}; - } - - auto get_write_buffer() const -> span { - return {reinterpret_cast(this->file_->_p), - to_unsigned(this->file_->_bf._base + this->file_->_bf._size - - this->file_->_p)}; - } - - void advance_write_buffer(size_t size) { - this->file_->_p += size; - this->file_->_w -= size; - } - - bool needs_flush() const { - if ((this->file_->_flags & line_buffered) == 0) return false; - return memchr(this->file_->_p + this->file_->_w, '\n', - to_unsigned(-this->file_->_w)); - } -}; - -// A fallback FILE wrapper. -template class fallback_file : public file_base { - private: - char next_; // The next unconsumed character in the buffer. - bool has_next_ = false; - - public: - using file_base::file_base; - - auto is_buffered() const -> bool { return false; } - auto needs_flush() const -> bool { return false; } - void init_buffer() {} - - auto get_read_buffer() const -> span { - return {&next_, has_next_ ? 1u : 0u}; - } - - auto get_write_buffer() const -> span { return {nullptr, 0}; } - - void advance_write_buffer(size_t) {} - - auto get() -> int { - has_next_ = false; - return file_base::get(); - } - - void unget(char c) { - file_base::unget(c); - next_ = c; - has_next_ = true; - } -}; - -#ifndef FMT_USE_FALLBACK_FILE -# define FMT_USE_FALLBACK_FILE 1 -#endif - -template -auto get_file(F* f, int) -> apple_file { - return f; -} -template -inline auto get_file(F* f, int) -> glibc_file { - return f; -} - -inline auto get_file(FILE* f, ...) -> fallback_file { return f; } - -using file_ref = decltype(get_file(static_cast(nullptr), 0)); - -class file_print_buffer : public buffer { - private: - file_ref file_; - - static void grow(buffer& base, size_t) { - auto& self = static_cast(base); - self.file_.advance_write_buffer(self.size()); - if (self.file_.get_write_buffer().size == 0) self.file_.flush(); - auto buf = self.file_.get_write_buffer(); - FMT_ASSERT(buf.size > 0, ""); - self.set(buf.data, buf.size); - self.clear(); - } - - public: - explicit file_print_buffer(FILE* f) : buffer(grow, size_t()), file_(f) { - flockfile(f); - file_.init_buffer(); - auto buf = file_.get_write_buffer(); - set(buf.data, buf.size); - } - ~file_print_buffer() { - file_.advance_write_buffer(size()); - bool flush = file_.needs_flush(); - funlockfile(file_); - if (flush) fflush(file_); - } -}; - -#if !defined(_WIN32) || defined(FMT_USE_WRITE_CONSOLE) -FMT_FUNC auto write_console(int, string_view) -> bool { return false; } -#else -using dword = conditional_t; -extern "C" __declspec(dllimport) int __stdcall WriteConsoleW( // - void*, const void*, dword, dword*, void*); - -FMT_FUNC bool write_console(int fd, string_view text) { - auto u16 = utf8_to_utf16(text); - return WriteConsoleW(reinterpret_cast(_get_osfhandle(fd)), u16.c_str(), - static_cast(u16.size()), nullptr, nullptr) != 0; -} -#endif - -#ifdef _WIN32 -// Print assuming legacy (non-Unicode) encoding. -FMT_FUNC void vprint_mojibake(std::FILE* f, string_view fmt, format_args args, - bool newline) { - auto buffer = memory_buffer(); - detail::vformat_to(buffer, fmt, args); - if (newline) buffer.push_back('\n'); - fwrite_fully(buffer.data(), buffer.size(), f); -} -#endif - -FMT_FUNC void print(std::FILE* f, string_view text) { -#if defined(_WIN32) && !defined(FMT_USE_WRITE_CONSOLE) - int fd = _fileno(f); - if (_isatty(fd)) { - std::fflush(f); - if (write_console(fd, text)) return; - } -#endif - fwrite_fully(text.data(), text.size(), f); -} -} // namespace detail - -FMT_FUNC void vprint_buffered(std::FILE* f, string_view fmt, format_args args) { - auto buffer = memory_buffer(); - detail::vformat_to(buffer, fmt, args); - detail::print(f, {buffer.data(), buffer.size()}); -} - -FMT_FUNC void vprint(std::FILE* f, string_view fmt, format_args args) { - if (!detail::file_ref(f).is_buffered()) return vprint_buffered(f, fmt, args); - auto&& buffer = detail::file_print_buffer(f); - return detail::vformat_to(buffer, fmt, args); -} - -FMT_FUNC void vprintln(std::FILE* f, string_view fmt, format_args args) { - auto buffer = memory_buffer(); - detail::vformat_to(buffer, fmt, args); - buffer.push_back('\n'); - detail::print(f, {buffer.data(), buffer.size()}); -} - -FMT_FUNC void vprint(string_view fmt, format_args args) { - vprint(stdout, fmt, args); -} - -namespace detail { - -struct singleton { - unsigned char upper; - unsigned char lower_count; -}; - -inline auto is_printable(uint16_t x, const singleton* singletons, - size_t singletons_size, - const unsigned char* singleton_lowers, - const unsigned char* normal, size_t normal_size) - -> bool { - auto upper = x >> 8; - auto lower_start = 0; - for (size_t i = 0; i < singletons_size; ++i) { - auto s = singletons[i]; - auto lower_end = lower_start + s.lower_count; - if (upper < s.upper) break; - if (upper == s.upper) { - for (auto j = lower_start; j < lower_end; ++j) { - if (singleton_lowers[j] == (x & 0xff)) return false; - } - } - lower_start = lower_end; - } - - auto xsigned = static_cast(x); - auto current = true; - for (size_t i = 0; i < normal_size; ++i) { - auto v = static_cast(normal[i]); - auto len = (v & 0x80) != 0 ? (v & 0x7f) << 8 | normal[++i] : v; - xsigned -= len; - if (xsigned < 0) break; - current = !current; - } - return current; -} - -// This code is generated by support/printable.py. -FMT_FUNC auto is_printable(uint32_t cp) -> bool { - static constexpr singleton singletons0[] = { - {0x00, 1}, {0x03, 5}, {0x05, 6}, {0x06, 3}, {0x07, 6}, {0x08, 8}, - {0x09, 17}, {0x0a, 28}, {0x0b, 25}, {0x0c, 20}, {0x0d, 16}, {0x0e, 13}, - {0x0f, 4}, {0x10, 3}, {0x12, 18}, {0x13, 9}, {0x16, 1}, {0x17, 5}, - {0x18, 2}, {0x19, 3}, {0x1a, 7}, {0x1c, 2}, {0x1d, 1}, {0x1f, 22}, - {0x20, 3}, {0x2b, 3}, {0x2c, 2}, {0x2d, 11}, {0x2e, 1}, {0x30, 3}, - {0x31, 2}, {0x32, 1}, {0xa7, 2}, {0xa9, 2}, {0xaa, 4}, {0xab, 8}, - {0xfa, 2}, {0xfb, 5}, {0xfd, 4}, {0xfe, 3}, {0xff, 9}, - }; - static constexpr unsigned char singletons0_lower[] = { - 0xad, 0x78, 0x79, 0x8b, 0x8d, 0xa2, 0x30, 0x57, 0x58, 0x8b, 0x8c, 0x90, - 0x1c, 0x1d, 0xdd, 0x0e, 0x0f, 0x4b, 0x4c, 0xfb, 0xfc, 0x2e, 0x2f, 0x3f, - 0x5c, 0x5d, 0x5f, 0xb5, 0xe2, 0x84, 0x8d, 0x8e, 0x91, 0x92, 0xa9, 0xb1, - 0xba, 0xbb, 0xc5, 0xc6, 0xc9, 0xca, 0xde, 0xe4, 0xe5, 0xff, 0x00, 0x04, - 0x11, 0x12, 0x29, 0x31, 0x34, 0x37, 0x3a, 0x3b, 0x3d, 0x49, 0x4a, 0x5d, - 0x84, 0x8e, 0x92, 0xa9, 0xb1, 0xb4, 0xba, 0xbb, 0xc6, 0xca, 0xce, 0xcf, - 0xe4, 0xe5, 0x00, 0x04, 0x0d, 0x0e, 0x11, 0x12, 0x29, 0x31, 0x34, 0x3a, - 0x3b, 0x45, 0x46, 0x49, 0x4a, 0x5e, 0x64, 0x65, 0x84, 0x91, 0x9b, 0x9d, - 0xc9, 0xce, 0xcf, 0x0d, 0x11, 0x29, 0x45, 0x49, 0x57, 0x64, 0x65, 0x8d, - 0x91, 0xa9, 0xb4, 0xba, 0xbb, 0xc5, 0xc9, 0xdf, 0xe4, 0xe5, 0xf0, 0x0d, - 0x11, 0x45, 0x49, 0x64, 0x65, 0x80, 0x84, 0xb2, 0xbc, 0xbe, 0xbf, 0xd5, - 0xd7, 0xf0, 0xf1, 0x83, 0x85, 0x8b, 0xa4, 0xa6, 0xbe, 0xbf, 0xc5, 0xc7, - 0xce, 0xcf, 0xda, 0xdb, 0x48, 0x98, 0xbd, 0xcd, 0xc6, 0xce, 0xcf, 0x49, - 0x4e, 0x4f, 0x57, 0x59, 0x5e, 0x5f, 0x89, 0x8e, 0x8f, 0xb1, 0xb6, 0xb7, - 0xbf, 0xc1, 0xc6, 0xc7, 0xd7, 0x11, 0x16, 0x17, 0x5b, 0x5c, 0xf6, 0xf7, - 0xfe, 0xff, 0x80, 0x0d, 0x6d, 0x71, 0xde, 0xdf, 0x0e, 0x0f, 0x1f, 0x6e, - 0x6f, 0x1c, 0x1d, 0x5f, 0x7d, 0x7e, 0xae, 0xaf, 0xbb, 0xbc, 0xfa, 0x16, - 0x17, 0x1e, 0x1f, 0x46, 0x47, 0x4e, 0x4f, 0x58, 0x5a, 0x5c, 0x5e, 0x7e, - 0x7f, 0xb5, 0xc5, 0xd4, 0xd5, 0xdc, 0xf0, 0xf1, 0xf5, 0x72, 0x73, 0x8f, - 0x74, 0x75, 0x96, 0x2f, 0x5f, 0x26, 0x2e, 0x2f, 0xa7, 0xaf, 0xb7, 0xbf, - 0xc7, 0xcf, 0xd7, 0xdf, 0x9a, 0x40, 0x97, 0x98, 0x30, 0x8f, 0x1f, 0xc0, - 0xc1, 0xce, 0xff, 0x4e, 0x4f, 0x5a, 0x5b, 0x07, 0x08, 0x0f, 0x10, 0x27, - 0x2f, 0xee, 0xef, 0x6e, 0x6f, 0x37, 0x3d, 0x3f, 0x42, 0x45, 0x90, 0x91, - 0xfe, 0xff, 0x53, 0x67, 0x75, 0xc8, 0xc9, 0xd0, 0xd1, 0xd8, 0xd9, 0xe7, - 0xfe, 0xff, - }; - static constexpr singleton singletons1[] = { - {0x00, 6}, {0x01, 1}, {0x03, 1}, {0x04, 2}, {0x08, 8}, {0x09, 2}, - {0x0a, 5}, {0x0b, 2}, {0x0e, 4}, {0x10, 1}, {0x11, 2}, {0x12, 5}, - {0x13, 17}, {0x14, 1}, {0x15, 2}, {0x17, 2}, {0x19, 13}, {0x1c, 5}, - {0x1d, 8}, {0x24, 1}, {0x6a, 3}, {0x6b, 2}, {0xbc, 2}, {0xd1, 2}, - {0xd4, 12}, {0xd5, 9}, {0xd6, 2}, {0xd7, 2}, {0xda, 1}, {0xe0, 5}, - {0xe1, 2}, {0xe8, 2}, {0xee, 32}, {0xf0, 4}, {0xf8, 2}, {0xf9, 2}, - {0xfa, 2}, {0xfb, 1}, - }; - static constexpr unsigned char singletons1_lower[] = { - 0x0c, 0x27, 0x3b, 0x3e, 0x4e, 0x4f, 0x8f, 0x9e, 0x9e, 0x9f, 0x06, 0x07, - 0x09, 0x36, 0x3d, 0x3e, 0x56, 0xf3, 0xd0, 0xd1, 0x04, 0x14, 0x18, 0x36, - 0x37, 0x56, 0x57, 0x7f, 0xaa, 0xae, 0xaf, 0xbd, 0x35, 0xe0, 0x12, 0x87, - 0x89, 0x8e, 0x9e, 0x04, 0x0d, 0x0e, 0x11, 0x12, 0x29, 0x31, 0x34, 0x3a, - 0x45, 0x46, 0x49, 0x4a, 0x4e, 0x4f, 0x64, 0x65, 0x5c, 0xb6, 0xb7, 0x1b, - 0x1c, 0x07, 0x08, 0x0a, 0x0b, 0x14, 0x17, 0x36, 0x39, 0x3a, 0xa8, 0xa9, - 0xd8, 0xd9, 0x09, 0x37, 0x90, 0x91, 0xa8, 0x07, 0x0a, 0x3b, 0x3e, 0x66, - 0x69, 0x8f, 0x92, 0x6f, 0x5f, 0xee, 0xef, 0x5a, 0x62, 0x9a, 0x9b, 0x27, - 0x28, 0x55, 0x9d, 0xa0, 0xa1, 0xa3, 0xa4, 0xa7, 0xa8, 0xad, 0xba, 0xbc, - 0xc4, 0x06, 0x0b, 0x0c, 0x15, 0x1d, 0x3a, 0x3f, 0x45, 0x51, 0xa6, 0xa7, - 0xcc, 0xcd, 0xa0, 0x07, 0x19, 0x1a, 0x22, 0x25, 0x3e, 0x3f, 0xc5, 0xc6, - 0x04, 0x20, 0x23, 0x25, 0x26, 0x28, 0x33, 0x38, 0x3a, 0x48, 0x4a, 0x4c, - 0x50, 0x53, 0x55, 0x56, 0x58, 0x5a, 0x5c, 0x5e, 0x60, 0x63, 0x65, 0x66, - 0x6b, 0x73, 0x78, 0x7d, 0x7f, 0x8a, 0xa4, 0xaa, 0xaf, 0xb0, 0xc0, 0xd0, - 0xae, 0xaf, 0x79, 0xcc, 0x6e, 0x6f, 0x93, - }; - static constexpr unsigned char normal0[] = { - 0x00, 0x20, 0x5f, 0x22, 0x82, 0xdf, 0x04, 0x82, 0x44, 0x08, 0x1b, 0x04, - 0x06, 0x11, 0x81, 0xac, 0x0e, 0x80, 0xab, 0x35, 0x28, 0x0b, 0x80, 0xe0, - 0x03, 0x19, 0x08, 0x01, 0x04, 0x2f, 0x04, 0x34, 0x04, 0x07, 0x03, 0x01, - 0x07, 0x06, 0x07, 0x11, 0x0a, 0x50, 0x0f, 0x12, 0x07, 0x55, 0x07, 0x03, - 0x04, 0x1c, 0x0a, 0x09, 0x03, 0x08, 0x03, 0x07, 0x03, 0x02, 0x03, 0x03, - 0x03, 0x0c, 0x04, 0x05, 0x03, 0x0b, 0x06, 0x01, 0x0e, 0x15, 0x05, 0x3a, - 0x03, 0x11, 0x07, 0x06, 0x05, 0x10, 0x07, 0x57, 0x07, 0x02, 0x07, 0x15, - 0x0d, 0x50, 0x04, 0x43, 0x03, 0x2d, 0x03, 0x01, 0x04, 0x11, 0x06, 0x0f, - 0x0c, 0x3a, 0x04, 0x1d, 0x25, 0x5f, 0x20, 0x6d, 0x04, 0x6a, 0x25, 0x80, - 0xc8, 0x05, 0x82, 0xb0, 0x03, 0x1a, 0x06, 0x82, 0xfd, 0x03, 0x59, 0x07, - 0x15, 0x0b, 0x17, 0x09, 0x14, 0x0c, 0x14, 0x0c, 0x6a, 0x06, 0x0a, 0x06, - 0x1a, 0x06, 0x59, 0x07, 0x2b, 0x05, 0x46, 0x0a, 0x2c, 0x04, 0x0c, 0x04, - 0x01, 0x03, 0x31, 0x0b, 0x2c, 0x04, 0x1a, 0x06, 0x0b, 0x03, 0x80, 0xac, - 0x06, 0x0a, 0x06, 0x21, 0x3f, 0x4c, 0x04, 0x2d, 0x03, 0x74, 0x08, 0x3c, - 0x03, 0x0f, 0x03, 0x3c, 0x07, 0x38, 0x08, 0x2b, 0x05, 0x82, 0xff, 0x11, - 0x18, 0x08, 0x2f, 0x11, 0x2d, 0x03, 0x20, 0x10, 0x21, 0x0f, 0x80, 0x8c, - 0x04, 0x82, 0x97, 0x19, 0x0b, 0x15, 0x88, 0x94, 0x05, 0x2f, 0x05, 0x3b, - 0x07, 0x02, 0x0e, 0x18, 0x09, 0x80, 0xb3, 0x2d, 0x74, 0x0c, 0x80, 0xd6, - 0x1a, 0x0c, 0x05, 0x80, 0xff, 0x05, 0x80, 0xdf, 0x0c, 0xee, 0x0d, 0x03, - 0x84, 0x8d, 0x03, 0x37, 0x09, 0x81, 0x5c, 0x14, 0x80, 0xb8, 0x08, 0x80, - 0xcb, 0x2a, 0x38, 0x03, 0x0a, 0x06, 0x38, 0x08, 0x46, 0x08, 0x0c, 0x06, - 0x74, 0x0b, 0x1e, 0x03, 0x5a, 0x04, 0x59, 0x09, 0x80, 0x83, 0x18, 0x1c, - 0x0a, 0x16, 0x09, 0x4c, 0x04, 0x80, 0x8a, 0x06, 0xab, 0xa4, 0x0c, 0x17, - 0x04, 0x31, 0xa1, 0x04, 0x81, 0xda, 0x26, 0x07, 0x0c, 0x05, 0x05, 0x80, - 0xa5, 0x11, 0x81, 0x6d, 0x10, 0x78, 0x28, 0x2a, 0x06, 0x4c, 0x04, 0x80, - 0x8d, 0x04, 0x80, 0xbe, 0x03, 0x1b, 0x03, 0x0f, 0x0d, - }; - static constexpr unsigned char normal1[] = { - 0x5e, 0x22, 0x7b, 0x05, 0x03, 0x04, 0x2d, 0x03, 0x66, 0x03, 0x01, 0x2f, - 0x2e, 0x80, 0x82, 0x1d, 0x03, 0x31, 0x0f, 0x1c, 0x04, 0x24, 0x09, 0x1e, - 0x05, 0x2b, 0x05, 0x44, 0x04, 0x0e, 0x2a, 0x80, 0xaa, 0x06, 0x24, 0x04, - 0x24, 0x04, 0x28, 0x08, 0x34, 0x0b, 0x01, 0x80, 0x90, 0x81, 0x37, 0x09, - 0x16, 0x0a, 0x08, 0x80, 0x98, 0x39, 0x03, 0x63, 0x08, 0x09, 0x30, 0x16, - 0x05, 0x21, 0x03, 0x1b, 0x05, 0x01, 0x40, 0x38, 0x04, 0x4b, 0x05, 0x2f, - 0x04, 0x0a, 0x07, 0x09, 0x07, 0x40, 0x20, 0x27, 0x04, 0x0c, 0x09, 0x36, - 0x03, 0x3a, 0x05, 0x1a, 0x07, 0x04, 0x0c, 0x07, 0x50, 0x49, 0x37, 0x33, - 0x0d, 0x33, 0x07, 0x2e, 0x08, 0x0a, 0x81, 0x26, 0x52, 0x4e, 0x28, 0x08, - 0x2a, 0x56, 0x1c, 0x14, 0x17, 0x09, 0x4e, 0x04, 0x1e, 0x0f, 0x43, 0x0e, - 0x19, 0x07, 0x0a, 0x06, 0x48, 0x08, 0x27, 0x09, 0x75, 0x0b, 0x3f, 0x41, - 0x2a, 0x06, 0x3b, 0x05, 0x0a, 0x06, 0x51, 0x06, 0x01, 0x05, 0x10, 0x03, - 0x05, 0x80, 0x8b, 0x62, 0x1e, 0x48, 0x08, 0x0a, 0x80, 0xa6, 0x5e, 0x22, - 0x45, 0x0b, 0x0a, 0x06, 0x0d, 0x13, 0x39, 0x07, 0x0a, 0x36, 0x2c, 0x04, - 0x10, 0x80, 0xc0, 0x3c, 0x64, 0x53, 0x0c, 0x48, 0x09, 0x0a, 0x46, 0x45, - 0x1b, 0x48, 0x08, 0x53, 0x1d, 0x39, 0x81, 0x07, 0x46, 0x0a, 0x1d, 0x03, - 0x47, 0x49, 0x37, 0x03, 0x0e, 0x08, 0x0a, 0x06, 0x39, 0x07, 0x0a, 0x81, - 0x36, 0x19, 0x80, 0xb7, 0x01, 0x0f, 0x32, 0x0d, 0x83, 0x9b, 0x66, 0x75, - 0x0b, 0x80, 0xc4, 0x8a, 0xbc, 0x84, 0x2f, 0x8f, 0xd1, 0x82, 0x47, 0xa1, - 0xb9, 0x82, 0x39, 0x07, 0x2a, 0x04, 0x02, 0x60, 0x26, 0x0a, 0x46, 0x0a, - 0x28, 0x05, 0x13, 0x82, 0xb0, 0x5b, 0x65, 0x4b, 0x04, 0x39, 0x07, 0x11, - 0x40, 0x05, 0x0b, 0x02, 0x0e, 0x97, 0xf8, 0x08, 0x84, 0xd6, 0x2a, 0x09, - 0xa2, 0xf7, 0x81, 0x1f, 0x31, 0x03, 0x11, 0x04, 0x08, 0x81, 0x8c, 0x89, - 0x04, 0x6b, 0x05, 0x0d, 0x03, 0x09, 0x07, 0x10, 0x93, 0x60, 0x80, 0xf6, - 0x0a, 0x73, 0x08, 0x6e, 0x17, 0x46, 0x80, 0x9a, 0x14, 0x0c, 0x57, 0x09, - 0x19, 0x80, 0x87, 0x81, 0x47, 0x03, 0x85, 0x42, 0x0f, 0x15, 0x85, 0x50, - 0x2b, 0x80, 0xd5, 0x2d, 0x03, 0x1a, 0x04, 0x02, 0x81, 0x70, 0x3a, 0x05, - 0x01, 0x85, 0x00, 0x80, 0xd7, 0x29, 0x4c, 0x04, 0x0a, 0x04, 0x02, 0x83, - 0x11, 0x44, 0x4c, 0x3d, 0x80, 0xc2, 0x3c, 0x06, 0x01, 0x04, 0x55, 0x05, - 0x1b, 0x34, 0x02, 0x81, 0x0e, 0x2c, 0x04, 0x64, 0x0c, 0x56, 0x0a, 0x80, - 0xae, 0x38, 0x1d, 0x0d, 0x2c, 0x04, 0x09, 0x07, 0x02, 0x0e, 0x06, 0x80, - 0x9a, 0x83, 0xd8, 0x08, 0x0d, 0x03, 0x0d, 0x03, 0x74, 0x0c, 0x59, 0x07, - 0x0c, 0x14, 0x0c, 0x04, 0x38, 0x08, 0x0a, 0x06, 0x28, 0x08, 0x22, 0x4e, - 0x81, 0x54, 0x0c, 0x15, 0x03, 0x03, 0x05, 0x07, 0x09, 0x19, 0x07, 0x07, - 0x09, 0x03, 0x0d, 0x07, 0x29, 0x80, 0xcb, 0x25, 0x0a, 0x84, 0x06, - }; - auto lower = static_cast(cp); - if (cp < 0x10000) { - return is_printable(lower, singletons0, - sizeof(singletons0) / sizeof(*singletons0), - singletons0_lower, normal0, sizeof(normal0)); - } - if (cp < 0x20000) { - return is_printable(lower, singletons1, - sizeof(singletons1) / sizeof(*singletons1), - singletons1_lower, normal1, sizeof(normal1)); - } - if (0x2a6de <= cp && cp < 0x2a700) return false; - if (0x2b735 <= cp && cp < 0x2b740) return false; - if (0x2b81e <= cp && cp < 0x2b820) return false; - if (0x2cea2 <= cp && cp < 0x2ceb0) return false; - if (0x2ebe1 <= cp && cp < 0x2f800) return false; - if (0x2fa1e <= cp && cp < 0x30000) return false; - if (0x3134b <= cp && cp < 0xe0100) return false; - if (0xe01f0 <= cp && cp < 0x110000) return false; - return cp < 0x110000; -} - -} // namespace detail - -FMT_END_NAMESPACE - -#endif // FMT_FORMAT_INL_H_ diff --git a/tt_metal/third_party/fmt/fmt/format.h b/tt_metal/third_party/fmt/fmt/format.h deleted file mode 100644 index 7c2a19b4084..00000000000 --- a/tt_metal/third_party/fmt/fmt/format.h +++ /dev/null @@ -1,4419 +0,0 @@ -/* - Formatting library for C++ - - Copyright (c) 2012 - present, Victor Zverovich - - Permission is hereby granted, free of charge, to any person obtaining - a copy of this software and associated documentation files (the - "Software"), to deal in the Software without restriction, including - without limitation the rights to use, copy, modify, merge, publish, - distribute, sublicense, and/or sell copies of the Software, and to - permit persons to whom the Software is furnished to do so, subject to - the following conditions: - - The above copyright notice and this permission notice shall be - included in all copies or substantial portions of the Software. - - THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, - EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND - NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE - LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION - OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION - WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - --- Optional exception to the license --- - - As an exception, if, as a result of your compiling your source code, portions - of this Software are embedded into a machine-executable object form of such - source code, you may redistribute such embedded portions in such object form - without including the above copyright and permission notices. - */ - -#ifndef FMT_FORMAT_H_ -#define FMT_FORMAT_H_ - -#ifndef _LIBCPP_REMOVE_TRANSITIVE_INCLUDES -# define _LIBCPP_REMOVE_TRANSITIVE_INCLUDES -# define FMT_REMOVE_TRANSITIVE_INCLUDES -#endif - -#include "base.h" - -#ifndef FMT_MODULE -# include // std::signbit -# include // uint32_t -# include // std::memcpy -# include // std::initializer_list -# include // std::numeric_limits -# if defined(__GLIBCXX__) && !defined(_GLIBCXX_USE_DUAL_ABI) -// Workaround for pre gcc 5 libstdc++. -# include // std::allocator_traits -# endif -# include // std::runtime_error -# include // std::string -# include // std::system_error - -// Checking FMT_CPLUSPLUS for warning suppression in MSVC. -# if FMT_HAS_INCLUDE() && FMT_CPLUSPLUS > 201703L -# include // std::bit_cast -# endif - -// libc++ supports string_view in pre-c++17. -# if FMT_HAS_INCLUDE() && \ - (FMT_CPLUSPLUS >= 201703L || defined(_LIBCPP_VERSION)) -# include -# define FMT_USE_STRING_VIEW -# endif -#endif // FMT_MODULE - -#if defined __cpp_inline_variables && __cpp_inline_variables >= 201606L -# define FMT_INLINE_VARIABLE inline -#else -# define FMT_INLINE_VARIABLE -#endif - -#ifndef FMT_NO_UNIQUE_ADDRESS -# if FMT_CPLUSPLUS >= 202002L -# if FMT_HAS_CPP_ATTRIBUTE(no_unique_address) -# define FMT_NO_UNIQUE_ADDRESS [[no_unique_address]] -// VS2019 v16.10 and later except clang-cl (https://reviews.llvm.org/D110485). -# elif (FMT_MSC_VERSION >= 1929) && !FMT_CLANG_VERSION -# define FMT_NO_UNIQUE_ADDRESS [[msvc::no_unique_address]] -# endif -# endif -#endif -#ifndef FMT_NO_UNIQUE_ADDRESS -# define FMT_NO_UNIQUE_ADDRESS -#endif - -// Visibility when compiled as a shared library/object. -#if defined(FMT_LIB_EXPORT) || defined(FMT_SHARED) -# define FMT_SO_VISIBILITY(value) FMT_VISIBILITY(value) -#else -# define FMT_SO_VISIBILITY(value) -#endif - -#ifdef __has_builtin -# define FMT_HAS_BUILTIN(x) __has_builtin(x) -#else -# define FMT_HAS_BUILTIN(x) 0 -#endif - -#if FMT_GCC_VERSION || FMT_CLANG_VERSION -# define FMT_NOINLINE __attribute__((noinline)) -#else -# define FMT_NOINLINE -#endif - -#ifndef FMT_THROW -# if FMT_EXCEPTIONS -# if FMT_MSC_VERSION || defined(__NVCC__) -FMT_BEGIN_NAMESPACE -namespace detail { -template inline void do_throw(const Exception& x) { - // Silence unreachable code warnings in MSVC and NVCC because these - // are nearly impossible to fix in a generic code. - volatile bool b = true; - if (b) throw x; -} -} // namespace detail -FMT_END_NAMESPACE -# define FMT_THROW(x) detail::do_throw(x) -# else -# define FMT_THROW(x) throw x -# endif -# else -# define FMT_THROW(x) \ - ::fmt::detail::assert_fail(__FILE__, __LINE__, (x).what()) -# endif -#endif - -#ifndef FMT_MAYBE_UNUSED -# if FMT_HAS_CPP17_ATTRIBUTE(maybe_unused) -# define FMT_MAYBE_UNUSED [[maybe_unused]] -# else -# define FMT_MAYBE_UNUSED -# endif -#endif - -#ifndef FMT_USE_USER_DEFINED_LITERALS -// EDG based compilers (Intel, NVIDIA, Elbrus, etc), GCC and MSVC support UDLs. -// -// GCC before 4.9 requires a space in `operator"" _a` which is invalid in later -// compiler versions. -# if (FMT_HAS_FEATURE(cxx_user_literals) || FMT_GCC_VERSION >= 409 || \ - FMT_MSC_VERSION >= 1900) && \ - (!defined(__EDG_VERSION__) || __EDG_VERSION__ >= /* UDL feature */ 480) -# define FMT_USE_USER_DEFINED_LITERALS 1 -# else -# define FMT_USE_USER_DEFINED_LITERALS 0 -# endif -#endif - -// Defining FMT_REDUCE_INT_INSTANTIATIONS to 1, will reduce the number of -// integer formatter template instantiations to just one by only using the -// largest integer type. This results in a reduction in binary size but will -// cause a decrease in integer formatting performance. -#if !defined(FMT_REDUCE_INT_INSTANTIATIONS) -# define FMT_REDUCE_INT_INSTANTIATIONS 0 -#endif - -// __builtin_clz is broken in clang with Microsoft CodeGen: -// https://github.com/fmtlib/fmt/issues/519. -#if !FMT_MSC_VERSION -# if FMT_HAS_BUILTIN(__builtin_clz) || FMT_GCC_VERSION || FMT_ICC_VERSION -# define FMT_BUILTIN_CLZ(n) __builtin_clz(n) -# endif -# if FMT_HAS_BUILTIN(__builtin_clzll) || FMT_GCC_VERSION || FMT_ICC_VERSION -# define FMT_BUILTIN_CLZLL(n) __builtin_clzll(n) -# endif -#endif - -// __builtin_ctz is broken in Intel Compiler Classic on Windows: -// https://github.com/fmtlib/fmt/issues/2510. -#ifndef __ICL -# if FMT_HAS_BUILTIN(__builtin_ctz) || FMT_GCC_VERSION || FMT_ICC_VERSION || \ - defined(__NVCOMPILER) -# define FMT_BUILTIN_CTZ(n) __builtin_ctz(n) -# endif -# if FMT_HAS_BUILTIN(__builtin_ctzll) || FMT_GCC_VERSION || \ - FMT_ICC_VERSION || defined(__NVCOMPILER) -# define FMT_BUILTIN_CTZLL(n) __builtin_ctzll(n) -# endif -#endif - -#if FMT_MSC_VERSION -# include // _BitScanReverse[64], _BitScanForward[64], _umul128 -#endif - -// Some compilers masquerade as both MSVC and GCC-likes or otherwise support -// __builtin_clz and __builtin_clzll, so only define FMT_BUILTIN_CLZ using the -// MSVC intrinsics if the clz and clzll builtins are not available. -#if FMT_MSC_VERSION && !defined(FMT_BUILTIN_CLZLL) && \ - !defined(FMT_BUILTIN_CTZLL) -FMT_BEGIN_NAMESPACE -namespace detail { -// Avoid Clang with Microsoft CodeGen's -Wunknown-pragmas warning. -# if !defined(__clang__) -# pragma intrinsic(_BitScanForward) -# pragma intrinsic(_BitScanReverse) -# if defined(_WIN64) -# pragma intrinsic(_BitScanForward64) -# pragma intrinsic(_BitScanReverse64) -# endif -# endif - -inline auto clz(uint32_t x) -> int { - unsigned long r = 0; - _BitScanReverse(&r, x); - FMT_ASSERT(x != 0, ""); - // Static analysis complains about using uninitialized data - // "r", but the only way that can happen is if "x" is 0, - // which the callers guarantee to not happen. - FMT_MSC_WARNING(suppress : 6102) - return 31 ^ static_cast(r); -} -# define FMT_BUILTIN_CLZ(n) detail::clz(n) - -inline auto clzll(uint64_t x) -> int { - unsigned long r = 0; -# ifdef _WIN64 - _BitScanReverse64(&r, x); -# else - // Scan the high 32 bits. - if (_BitScanReverse(&r, static_cast(x >> 32))) - return 63 ^ static_cast(r + 32); - // Scan the low 32 bits. - _BitScanReverse(&r, static_cast(x)); -# endif - FMT_ASSERT(x != 0, ""); - FMT_MSC_WARNING(suppress : 6102) // Suppress a bogus static analysis warning. - return 63 ^ static_cast(r); -} -# define FMT_BUILTIN_CLZLL(n) detail::clzll(n) - -inline auto ctz(uint32_t x) -> int { - unsigned long r = 0; - _BitScanForward(&r, x); - FMT_ASSERT(x != 0, ""); - FMT_MSC_WARNING(suppress : 6102) // Suppress a bogus static analysis warning. - return static_cast(r); -} -# define FMT_BUILTIN_CTZ(n) detail::ctz(n) - -inline auto ctzll(uint64_t x) -> int { - unsigned long r = 0; - FMT_ASSERT(x != 0, ""); - FMT_MSC_WARNING(suppress : 6102) // Suppress a bogus static analysis warning. -# ifdef _WIN64 - _BitScanForward64(&r, x); -# else - // Scan the low 32 bits. - if (_BitScanForward(&r, static_cast(x))) return static_cast(r); - // Scan the high 32 bits. - _BitScanForward(&r, static_cast(x >> 32)); - r += 32; -# endif - return static_cast(r); -} -# define FMT_BUILTIN_CTZLL(n) detail::ctzll(n) -} // namespace detail -FMT_END_NAMESPACE -#endif - -FMT_BEGIN_NAMESPACE - -template -struct is_contiguous> - : std::true_type {}; - -namespace detail { - -FMT_CONSTEXPR inline void abort_fuzzing_if(bool condition) { - ignore_unused(condition); -#ifdef FMT_FUZZ - if (condition) throw std::runtime_error("fuzzing limit reached"); -#endif -} - -#if defined(FMT_USE_STRING_VIEW) -template using std_string_view = std::basic_string_view; -#else -template struct std_string_view {}; -#endif - -// Implementation of std::bit_cast for pre-C++20. -template -FMT_CONSTEXPR20 auto bit_cast(const From& from) -> To { -#ifdef __cpp_lib_bit_cast - if (is_constant_evaluated()) return std::bit_cast(from); -#endif - auto to = To(); - // The cast suppresses a bogus -Wclass-memaccess on GCC. - std::memcpy(static_cast(&to), &from, sizeof(to)); - return to; -} - -inline auto is_big_endian() -> bool { -#ifdef _WIN32 - return false; -#elif defined(__BIG_ENDIAN__) - return true; -#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) - return __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__; -#else - struct bytes { - char data[sizeof(int)]; - }; - return bit_cast(1).data[0] == 0; -#endif -} - -class uint128_fallback { - private: - uint64_t lo_, hi_; - - public: - constexpr uint128_fallback(uint64_t hi, uint64_t lo) : lo_(lo), hi_(hi) {} - constexpr uint128_fallback(uint64_t value = 0) : lo_(value), hi_(0) {} - - constexpr auto high() const noexcept -> uint64_t { return hi_; } - constexpr auto low() const noexcept -> uint64_t { return lo_; } - - template ::value)> - constexpr explicit operator T() const { - return static_cast(lo_); - } - - friend constexpr auto operator==(const uint128_fallback& lhs, - const uint128_fallback& rhs) -> bool { - return lhs.hi_ == rhs.hi_ && lhs.lo_ == rhs.lo_; - } - friend constexpr auto operator!=(const uint128_fallback& lhs, - const uint128_fallback& rhs) -> bool { - return !(lhs == rhs); - } - friend constexpr auto operator>(const uint128_fallback& lhs, - const uint128_fallback& rhs) -> bool { - return lhs.hi_ != rhs.hi_ ? lhs.hi_ > rhs.hi_ : lhs.lo_ > rhs.lo_; - } - friend constexpr auto operator|(const uint128_fallback& lhs, - const uint128_fallback& rhs) - -> uint128_fallback { - return {lhs.hi_ | rhs.hi_, lhs.lo_ | rhs.lo_}; - } - friend constexpr auto operator&(const uint128_fallback& lhs, - const uint128_fallback& rhs) - -> uint128_fallback { - return {lhs.hi_ & rhs.hi_, lhs.lo_ & rhs.lo_}; - } - friend constexpr auto operator~(const uint128_fallback& n) - -> uint128_fallback { - return {~n.hi_, ~n.lo_}; - } - friend auto operator+(const uint128_fallback& lhs, - const uint128_fallback& rhs) -> uint128_fallback { - auto result = uint128_fallback(lhs); - result += rhs; - return result; - } - friend auto operator*(const uint128_fallback& lhs, uint32_t rhs) - -> uint128_fallback { - FMT_ASSERT(lhs.hi_ == 0, ""); - uint64_t hi = (lhs.lo_ >> 32) * rhs; - uint64_t lo = (lhs.lo_ & ~uint32_t()) * rhs; - uint64_t new_lo = (hi << 32) + lo; - return {(hi >> 32) + (new_lo < lo ? 1 : 0), new_lo}; - } - friend auto operator-(const uint128_fallback& lhs, uint64_t rhs) - -> uint128_fallback { - return {lhs.hi_ - (lhs.lo_ < rhs ? 1 : 0), lhs.lo_ - rhs}; - } - FMT_CONSTEXPR auto operator>>(int shift) const -> uint128_fallback { - if (shift == 64) return {0, hi_}; - if (shift > 64) return uint128_fallback(0, hi_) >> (shift - 64); - return {hi_ >> shift, (hi_ << (64 - shift)) | (lo_ >> shift)}; - } - FMT_CONSTEXPR auto operator<<(int shift) const -> uint128_fallback { - if (shift == 64) return {lo_, 0}; - if (shift > 64) return uint128_fallback(lo_, 0) << (shift - 64); - return {hi_ << shift | (lo_ >> (64 - shift)), (lo_ << shift)}; - } - FMT_CONSTEXPR auto operator>>=(int shift) -> uint128_fallback& { - return *this = *this >> shift; - } - FMT_CONSTEXPR void operator+=(uint128_fallback n) { - uint64_t new_lo = lo_ + n.lo_; - uint64_t new_hi = hi_ + n.hi_ + (new_lo < lo_ ? 1 : 0); - FMT_ASSERT(new_hi >= hi_, ""); - lo_ = new_lo; - hi_ = new_hi; - } - FMT_CONSTEXPR void operator&=(uint128_fallback n) { - lo_ &= n.lo_; - hi_ &= n.hi_; - } - - FMT_CONSTEXPR20 auto operator+=(uint64_t n) noexcept -> uint128_fallback& { - if (is_constant_evaluated()) { - lo_ += n; - hi_ += (lo_ < n ? 1 : 0); - return *this; - } -#if FMT_HAS_BUILTIN(__builtin_addcll) && !defined(__ibmxl__) - unsigned long long carry; - lo_ = __builtin_addcll(lo_, n, 0, &carry); - hi_ += carry; -#elif FMT_HAS_BUILTIN(__builtin_ia32_addcarryx_u64) && !defined(__ibmxl__) - unsigned long long result; - auto carry = __builtin_ia32_addcarryx_u64(0, lo_, n, &result); - lo_ = result; - hi_ += carry; -#elif defined(_MSC_VER) && defined(_M_X64) - auto carry = _addcarry_u64(0, lo_, n, &lo_); - _addcarry_u64(carry, hi_, 0, &hi_); -#else - lo_ += n; - hi_ += (lo_ < n ? 1 : 0); -#endif - return *this; - } -}; - -using uint128_t = conditional_t; - -#ifdef UINTPTR_MAX -using uintptr_t = ::uintptr_t; -#else -using uintptr_t = uint128_t; -#endif - -// Returns the largest possible value for type T. Same as -// std::numeric_limits::max() but shorter and not affected by the max macro. -template constexpr auto max_value() -> T { - return (std::numeric_limits::max)(); -} -template constexpr auto num_bits() -> int { - return std::numeric_limits::digits; -} -// std::numeric_limits::digits may return 0 for 128-bit ints. -template <> constexpr auto num_bits() -> int { return 128; } -template <> constexpr auto num_bits() -> int { return 128; } -template <> constexpr auto num_bits() -> int { return 128; } - -// A heterogeneous bit_cast used for converting 96-bit long double to uint128_t -// and 128-bit pointers to uint128_fallback. -template sizeof(From))> -inline auto bit_cast(const From& from) -> To { - constexpr auto size = static_cast(sizeof(From) / sizeof(unsigned)); - struct data_t { - unsigned value[static_cast(size)]; - } data = bit_cast(from); - auto result = To(); - if (const_check(is_big_endian())) { - for (int i = 0; i < size; ++i) - result = (result << num_bits()) | data.value[i]; - } else { - for (int i = size - 1; i >= 0; --i) - result = (result << num_bits()) | data.value[i]; - } - return result; -} - -template -FMT_CONSTEXPR20 inline auto countl_zero_fallback(UInt n) -> int { - int lz = 0; - constexpr UInt msb_mask = static_cast(1) << (num_bits() - 1); - for (; (n & msb_mask) == 0; n <<= 1) lz++; - return lz; -} - -FMT_CONSTEXPR20 inline auto countl_zero(uint32_t n) -> int { -#ifdef FMT_BUILTIN_CLZ - if (!is_constant_evaluated()) return FMT_BUILTIN_CLZ(n); -#endif - return countl_zero_fallback(n); -} - -FMT_CONSTEXPR20 inline auto countl_zero(uint64_t n) -> int { -#ifdef FMT_BUILTIN_CLZLL - if (!is_constant_evaluated()) return FMT_BUILTIN_CLZLL(n); -#endif - return countl_zero_fallback(n); -} - -FMT_INLINE void assume(bool condition) { - (void)condition; -#if FMT_HAS_BUILTIN(__builtin_assume) && !FMT_ICC_VERSION - __builtin_assume(condition); -#elif FMT_GCC_VERSION - if (!condition) __builtin_unreachable(); -#endif -} - -// An approximation of iterator_t for pre-C++20 systems. -template -using iterator_t = decltype(std::begin(std::declval())); -template using sentinel_t = decltype(std::end(std::declval())); - -// A workaround for std::string not having mutable data() until C++17. -template -inline auto get_data(std::basic_string& s) -> Char* { - return &s[0]; -} -template -inline auto get_data(Container& c) -> typename Container::value_type* { - return c.data(); -} - -// Attempts to reserve space for n extra characters in the output range. -// Returns a pointer to the reserved range or a reference to it. -template ::value&& - is_contiguous::value)> -#if FMT_CLANG_VERSION >= 307 && !FMT_ICC_VERSION -__attribute__((no_sanitize("undefined"))) -#endif -inline auto -reserve(OutputIt it, size_t n) -> typename OutputIt::value_type* { - auto& c = get_container(it); - size_t size = c.size(); - c.resize(size + n); - return get_data(c) + size; -} - -template -inline auto reserve(basic_appender it, size_t n) -> basic_appender { - buffer& buf = get_container(it); - buf.try_reserve(buf.size() + n); - return it; -} - -template -constexpr auto reserve(Iterator& it, size_t) -> Iterator& { - return it; -} - -template -using reserve_iterator = - remove_reference_t(), 0))>; - -template -constexpr auto to_pointer(OutputIt, size_t) -> T* { - return nullptr; -} -template auto to_pointer(basic_appender it, size_t n) -> T* { - buffer& buf = get_container(it); - auto size = buf.size(); - if (buf.capacity() < size + n) return nullptr; - buf.try_resize(size + n); - return buf.data() + size; -} - -template ::value&& - is_contiguous::value)> -inline auto base_iterator(OutputIt it, - typename OutputIt::container_type::value_type*) - -> OutputIt { - return it; -} - -template -constexpr auto base_iterator(Iterator, Iterator it) -> Iterator { - return it; -} - -// is spectacularly slow to compile in C++20 so use a simple fill_n -// instead (#1998). -template -FMT_CONSTEXPR auto fill_n(OutputIt out, Size count, const T& value) - -> OutputIt { - for (Size i = 0; i < count; ++i) *out++ = value; - return out; -} -template -FMT_CONSTEXPR20 auto fill_n(T* out, Size count, char value) -> T* { - if (is_constant_evaluated()) { - return fill_n(out, count, value); - } - std::memset(out, value, to_unsigned(count)); - return out + count; -} - -template -FMT_CONSTEXPR FMT_NOINLINE auto copy_noinline(InputIt begin, InputIt end, - OutputIt out) -> OutputIt { - return copy(begin, end, out); -} - -// A public domain branchless UTF-8 decoder by Christopher Wellons: -// https://github.com/skeeto/branchless-utf8 -/* Decode the next character, c, from s, reporting errors in e. - * - * Since this is a branchless decoder, four bytes will be read from the - * buffer regardless of the actual length of the next character. This - * means the buffer _must_ have at least three bytes of zero padding - * following the end of the data stream. - * - * Errors are reported in e, which will be non-zero if the parsed - * character was somehow invalid: invalid byte sequence, non-canonical - * encoding, or a surrogate half. - * - * The function returns a pointer to the next character. When an error - * occurs, this pointer will be a guess that depends on the particular - * error, but it will always advance at least one byte. - */ -FMT_CONSTEXPR inline auto utf8_decode(const char* s, uint32_t* c, int* e) - -> const char* { - constexpr const int masks[] = {0x00, 0x7f, 0x1f, 0x0f, 0x07}; - constexpr const uint32_t mins[] = {4194304, 0, 128, 2048, 65536}; - constexpr const int shiftc[] = {0, 18, 12, 6, 0}; - constexpr const int shifte[] = {0, 6, 4, 2, 0}; - - int len = "\1\1\1\1\1\1\1\1\1\1\1\1\1\1\1\1\0\0\0\0\0\0\0\0\2\2\2\2\3\3\4" - [static_cast(*s) >> 3]; - // Compute the pointer to the next character early so that the next - // iteration can start working on the next character. Neither Clang - // nor GCC figure out this reordering on their own. - const char* next = s + len + !len; - - using uchar = unsigned char; - - // Assume a four-byte character and load four bytes. Unused bits are - // shifted out. - *c = uint32_t(uchar(s[0]) & masks[len]) << 18; - *c |= uint32_t(uchar(s[1]) & 0x3f) << 12; - *c |= uint32_t(uchar(s[2]) & 0x3f) << 6; - *c |= uint32_t(uchar(s[3]) & 0x3f) << 0; - *c >>= shiftc[len]; - - // Accumulate the various error conditions. - *e = (*c < mins[len]) << 6; // non-canonical encoding - *e |= ((*c >> 11) == 0x1b) << 7; // surrogate half? - *e |= (*c > 0x10FFFF) << 8; // out of range? - *e |= (uchar(s[1]) & 0xc0) >> 2; - *e |= (uchar(s[2]) & 0xc0) >> 4; - *e |= uchar(s[3]) >> 6; - *e ^= 0x2a; // top two bits of each tail byte correct? - *e >>= shifte[len]; - - return next; -} - -constexpr FMT_INLINE_VARIABLE uint32_t invalid_code_point = ~uint32_t(); - -// Invokes f(cp, sv) for every code point cp in s with sv being the string view -// corresponding to the code point. cp is invalid_code_point on error. -template -FMT_CONSTEXPR void for_each_codepoint(string_view s, F f) { - auto decode = [f](const char* buf_ptr, const char* ptr) { - auto cp = uint32_t(); - auto error = 0; - auto end = utf8_decode(buf_ptr, &cp, &error); - bool result = f(error ? invalid_code_point : cp, - string_view(ptr, error ? 1 : to_unsigned(end - buf_ptr))); - return result ? (error ? buf_ptr + 1 : end) : nullptr; - }; - auto p = s.data(); - const size_t block_size = 4; // utf8_decode always reads blocks of 4 chars. - if (s.size() >= block_size) { - for (auto end = p + s.size() - block_size + 1; p < end;) { - p = decode(p, p); - if (!p) return; - } - } - if (auto num_chars_left = s.data() + s.size() - p) { - char buf[2 * block_size - 1] = {}; - copy(p, p + num_chars_left, buf); - const char* buf_ptr = buf; - do { - auto end = decode(buf_ptr, p); - if (!end) return; - p += end - buf_ptr; - buf_ptr = end; - } while (buf_ptr - buf < num_chars_left); - } -} - -template -inline auto compute_width(basic_string_view s) -> size_t { - return s.size(); -} - -// Computes approximate display width of a UTF-8 string. -FMT_CONSTEXPR inline auto compute_width(string_view s) -> size_t { - size_t num_code_points = 0; - // It is not a lambda for compatibility with C++14. - struct count_code_points { - size_t* count; - FMT_CONSTEXPR auto operator()(uint32_t cp, string_view) const -> bool { - *count += detail::to_unsigned( - 1 + - (cp >= 0x1100 && - (cp <= 0x115f || // Hangul Jamo init. consonants - cp == 0x2329 || // LEFT-POINTING ANGLE BRACKET - cp == 0x232a || // RIGHT-POINTING ANGLE BRACKET - // CJK ... Yi except IDEOGRAPHIC HALF FILL SPACE: - (cp >= 0x2e80 && cp <= 0xa4cf && cp != 0x303f) || - (cp >= 0xac00 && cp <= 0xd7a3) || // Hangul Syllables - (cp >= 0xf900 && cp <= 0xfaff) || // CJK Compatibility Ideographs - (cp >= 0xfe10 && cp <= 0xfe19) || // Vertical Forms - (cp >= 0xfe30 && cp <= 0xfe6f) || // CJK Compatibility Forms - (cp >= 0xff00 && cp <= 0xff60) || // Fullwidth Forms - (cp >= 0xffe0 && cp <= 0xffe6) || // Fullwidth Forms - (cp >= 0x20000 && cp <= 0x2fffd) || // CJK - (cp >= 0x30000 && cp <= 0x3fffd) || - // Miscellaneous Symbols and Pictographs + Emoticons: - (cp >= 0x1f300 && cp <= 0x1f64f) || - // Supplemental Symbols and Pictographs: - (cp >= 0x1f900 && cp <= 0x1f9ff)))); - return true; - } - }; - // We could avoid branches by using utf8_decode directly. - for_each_codepoint(s, count_code_points{&num_code_points}); - return num_code_points; -} - -template -inline auto code_point_index(basic_string_view s, size_t n) -> size_t { - size_t size = s.size(); - return n < size ? n : size; -} - -// Calculates the index of the nth code point in a UTF-8 string. -inline auto code_point_index(string_view s, size_t n) -> size_t { - size_t result = s.size(); - const char* begin = s.begin(); - for_each_codepoint(s, [begin, &n, &result](uint32_t, string_view sv) { - if (n != 0) { - --n; - return true; - } - result = to_unsigned(sv.begin() - begin); - return false; - }); - return result; -} - -template struct is_integral : std::is_integral {}; -template <> struct is_integral : std::true_type {}; -template <> struct is_integral : std::true_type {}; - -template -using is_signed = - std::integral_constant::is_signed || - std::is_same::value>; - -template -using is_integer = - bool_constant::value && !std::is_same::value && - !std::is_same::value && - !std::is_same::value>; - -#ifndef FMT_USE_FLOAT -# define FMT_USE_FLOAT 1 -#endif -#ifndef FMT_USE_DOUBLE -# define FMT_USE_DOUBLE 1 -#endif -#ifndef FMT_USE_LONG_DOUBLE -# define FMT_USE_LONG_DOUBLE 1 -#endif - -#if defined(FMT_USE_FLOAT128) -// Use the provided definition. -#elif FMT_CLANG_VERSION && FMT_HAS_INCLUDE() -# define FMT_USE_FLOAT128 1 -#elif FMT_GCC_VERSION && defined(_GLIBCXX_USE_FLOAT128) && \ - !defined(__STRICT_ANSI__) -# define FMT_USE_FLOAT128 1 -#else -# define FMT_USE_FLOAT128 0 -#endif -#if FMT_USE_FLOAT128 -using float128 = __float128; -#else -using float128 = void; -#endif - -template using is_float128 = std::is_same; - -template -using is_floating_point = - bool_constant::value || is_float128::value>; - -template ::value> -struct is_fast_float : bool_constant::is_iec559 && - sizeof(T) <= sizeof(double)> {}; -template struct is_fast_float : std::false_type {}; - -template -using is_double_double = bool_constant::digits == 106>; - -#ifndef FMT_USE_FULL_CACHE_DRAGONBOX -# define FMT_USE_FULL_CACHE_DRAGONBOX 0 -#endif - -template -struct is_locale : std::false_type {}; -template -struct is_locale> : std::true_type {}; -} // namespace detail - -FMT_BEGIN_EXPORT - -// The number of characters to store in the basic_memory_buffer object itself -// to avoid dynamic memory allocation. -enum { inline_buffer_size = 500 }; - -/** - * A dynamically growing memory buffer for trivially copyable/constructible - * types with the first `SIZE` elements stored in the object itself. Most - * commonly used via the `memory_buffer` alias for `char`. - * - * **Example**: - * - * auto out = fmt::memory_buffer(); - * fmt::format_to(std::back_inserter(out), "The answer is {}.", 42); - * - * This will append "The answer is 42." to `out`. The buffer content can be - * converted to `std::string` with `to_string(out)`. - */ -template > -class basic_memory_buffer : public detail::buffer { - private: - T store_[SIZE]; - - // Don't inherit from Allocator to avoid generating type_info for it. - FMT_NO_UNIQUE_ADDRESS Allocator alloc_; - - // Deallocate memory allocated by the buffer. - FMT_CONSTEXPR20 void deallocate() { - T* data = this->data(); - if (data != store_) alloc_.deallocate(data, this->capacity()); - } - - static FMT_CONSTEXPR20 void grow(detail::buffer& buf, size_t size) { - detail::abort_fuzzing_if(size > 5000); - auto& self = static_cast(buf); - const size_t max_size = - std::allocator_traits::max_size(self.alloc_); - size_t old_capacity = buf.capacity(); - size_t new_capacity = old_capacity + old_capacity / 2; - if (size > new_capacity) - new_capacity = size; - else if (new_capacity > max_size) - new_capacity = size > max_size ? size : max_size; - T* old_data = buf.data(); - T* new_data = self.alloc_.allocate(new_capacity); - // Suppress a bogus -Wstringop-overflow in gcc 13.1 (#3481). - detail::assume(buf.size() <= new_capacity); - // The following code doesn't throw, so the raw pointer above doesn't leak. - memcpy(new_data, old_data, buf.size() * sizeof(T)); - self.set(new_data, new_capacity); - // deallocate must not throw according to the standard, but even if it does, - // the buffer already uses the new storage and will deallocate it in - // destructor. - if (old_data != self.store_) self.alloc_.deallocate(old_data, old_capacity); - } - - public: - using value_type = T; - using const_reference = const T&; - - FMT_CONSTEXPR20 explicit basic_memory_buffer( - const Allocator& alloc = Allocator()) - : detail::buffer(grow), alloc_(alloc) { - this->set(store_, SIZE); - if (detail::is_constant_evaluated()) detail::fill_n(store_, SIZE, T()); - } - FMT_CONSTEXPR20 ~basic_memory_buffer() { deallocate(); } - - private: - // Move data from other to this buffer. - FMT_CONSTEXPR20 void move(basic_memory_buffer& other) { - alloc_ = std::move(other.alloc_); - T* data = other.data(); - size_t size = other.size(), capacity = other.capacity(); - if (data == other.store_) { - this->set(store_, capacity); - detail::copy(other.store_, other.store_ + size, store_); - } else { - this->set(data, capacity); - // Set pointer to the inline array so that delete is not called - // when deallocating. - other.set(other.store_, 0); - other.clear(); - } - this->resize(size); - } - - public: - /// Constructs a `basic_memory_buffer` object moving the content of the other - /// object to it. - FMT_CONSTEXPR20 basic_memory_buffer(basic_memory_buffer&& other) noexcept - : detail::buffer(grow) { - move(other); - } - - /// Moves the content of the other `basic_memory_buffer` object to this one. - auto operator=(basic_memory_buffer&& other) noexcept -> basic_memory_buffer& { - FMT_ASSERT(this != &other, ""); - deallocate(); - move(other); - return *this; - } - - // Returns a copy of the allocator associated with this buffer. - auto get_allocator() const -> Allocator { return alloc_; } - - /// Resizes the buffer to contain `count` elements. If T is a POD type new - /// elements may not be initialized. - FMT_CONSTEXPR20 void resize(size_t count) { this->try_resize(count); } - - /// Increases the buffer capacity to `new_capacity`. - void reserve(size_t new_capacity) { this->try_reserve(new_capacity); } - - using detail::buffer::append; - template - void append(const ContiguousRange& range) { - append(range.data(), range.data() + range.size()); - } -}; - -using memory_buffer = basic_memory_buffer; - -template -struct is_contiguous> : std::true_type { -}; - -FMT_END_EXPORT -namespace detail { -FMT_API auto write_console(int fd, string_view text) -> bool; -FMT_API void print(std::FILE*, string_view); -} // namespace detail - -FMT_BEGIN_EXPORT - -// Suppress a misleading warning in older versions of clang. -#if FMT_CLANG_VERSION -# pragma clang diagnostic ignored "-Wweak-vtables" -#endif - -/// An error reported from a formatting function. -class FMT_SO_VISIBILITY("default") format_error : public std::runtime_error { - public: - using std::runtime_error::runtime_error; -}; - -namespace detail_exported { -#if FMT_USE_NONTYPE_TEMPLATE_ARGS -template struct fixed_string { - constexpr fixed_string(const Char (&str)[N]) { - detail::copy(static_cast(str), - str + N, data); - } - Char data[N] = {}; -}; -#endif - -// Converts a compile-time string to basic_string_view. -template -constexpr auto compile_string_to_view(const Char (&s)[N]) - -> basic_string_view { - // Remove trailing NUL character if needed. Won't be present if this is used - // with a raw character array (i.e. not defined as a string). - return {s, N - (std::char_traits::to_int_type(s[N - 1]) == 0 ? 1 : 0)}; -} -template -constexpr auto compile_string_to_view(basic_string_view s) - -> basic_string_view { - return s; -} -} // namespace detail_exported - -// A generic formatting context with custom output iterator and character -// (code unit) support. Char is the format string code unit type which can be -// different from OutputIt::value_type. -template class generic_context { - private: - OutputIt out_; - basic_format_args args_; - detail::locale_ref loc_; - - public: - using char_type = Char; - using iterator = OutputIt; - using parse_context_type = basic_format_parse_context; - template using formatter_type = formatter; - - constexpr generic_context(OutputIt out, - basic_format_args ctx_args, - detail::locale_ref loc = {}) - : out_(out), args_(ctx_args), loc_(loc) {} - generic_context(generic_context&&) = default; - generic_context(const generic_context&) = delete; - void operator=(const generic_context&) = delete; - - constexpr auto arg(int id) const -> basic_format_arg { - return args_.get(id); - } - auto arg(basic_string_view name) -> basic_format_arg { - return args_.get(name); - } - FMT_CONSTEXPR auto arg_id(basic_string_view name) -> int { - return args_.get_id(name); - } - auto args() const -> const basic_format_args& { - return args_; - } - - FMT_CONSTEXPR auto out() -> iterator { return out_; } - - void advance_to(iterator it) { - if (!detail::is_back_insert_iterator()) out_ = it; - } - - FMT_CONSTEXPR auto locale() -> detail::locale_ref { return loc_; } -}; - -class loc_value { - private: - basic_format_arg value_; - - public: - template ::value)> - loc_value(T value) : value_(detail::make_arg(value)) {} - - template ::value)> - loc_value(T) {} - - template auto visit(Visitor&& vis) -> decltype(vis(0)) { - return value_.visit(vis); - } -}; - -// A locale facet that formats values in UTF-8. -// It is parameterized on the locale to avoid the heavy include. -template class format_facet : public Locale::facet { - private: - std::string separator_; - std::string grouping_; - std::string decimal_point_; - - protected: - virtual auto do_put(appender out, loc_value val, - const format_specs& specs) const -> bool; - - public: - static FMT_API typename Locale::id id; - - explicit format_facet(Locale& loc); - explicit format_facet(string_view sep = "", - std::initializer_list g = {3}, - std::string decimal_point = ".") - : separator_(sep.data(), sep.size()), - grouping_(g.begin(), g.end()), - decimal_point_(decimal_point) {} - - auto put(appender out, loc_value val, const format_specs& specs) const - -> bool { - return do_put(out, val, specs); - } -}; - -FMT_END_EXPORT - -namespace detail { - -// Returns true if value is negative, false otherwise. -// Same as `value < 0` but doesn't produce warnings if T is an unsigned type. -template ::value)> -constexpr auto is_negative(T value) -> bool { - return value < 0; -} -template ::value)> -constexpr auto is_negative(T) -> bool { - return false; -} - -template -FMT_CONSTEXPR auto is_supported_floating_point(T) -> bool { - if (std::is_same()) return FMT_USE_FLOAT; - if (std::is_same()) return FMT_USE_DOUBLE; - if (std::is_same()) return FMT_USE_LONG_DOUBLE; - return true; -} - -// Smallest of uint32_t, uint64_t, uint128_t that is large enough to -// represent all values of an integral type T. -template -using uint32_or_64_or_128_t = - conditional_t() <= 32 && !FMT_REDUCE_INT_INSTANTIATIONS, - uint32_t, - conditional_t() <= 64, uint64_t, uint128_t>>; -template -using uint64_or_128_t = conditional_t() <= 64, uint64_t, uint128_t>; - -#define FMT_POWERS_OF_10(factor) \ - factor * 10, (factor) * 100, (factor) * 1000, (factor) * 10000, \ - (factor) * 100000, (factor) * 1000000, (factor) * 10000000, \ - (factor) * 100000000, (factor) * 1000000000 - -// Converts value in the range [0, 100) to a string. -constexpr auto digits2(size_t value) -> const char* { - // GCC generates slightly better code when value is pointer-size. - return &"0001020304050607080910111213141516171819" - "2021222324252627282930313233343536373839" - "4041424344454647484950515253545556575859" - "6061626364656667686970717273747576777879" - "8081828384858687888990919293949596979899"[value * 2]; -} - -// Sign is a template parameter to workaround a bug in gcc 4.8. -template constexpr auto sign(Sign s) -> Char { -#if !FMT_GCC_VERSION || FMT_GCC_VERSION >= 604 - static_assert(std::is_same::value, ""); -#endif - return static_cast("\0-+ "[s]); -} - -template FMT_CONSTEXPR auto count_digits_fallback(T n) -> int { - int count = 1; - for (;;) { - // Integer division is slow so do it for a group of four digits instead - // of for every digit. The idea comes from the talk by Alexandrescu - // "Three Optimization Tips for C++". See speed-test for a comparison. - if (n < 10) return count; - if (n < 100) return count + 1; - if (n < 1000) return count + 2; - if (n < 10000) return count + 3; - n /= 10000u; - count += 4; - } -} -#if FMT_USE_INT128 -FMT_CONSTEXPR inline auto count_digits(uint128_opt n) -> int { - return count_digits_fallback(n); -} -#endif - -#ifdef FMT_BUILTIN_CLZLL -// It is a separate function rather than a part of count_digits to workaround -// the lack of static constexpr in constexpr functions. -inline auto do_count_digits(uint64_t n) -> int { - // This has comparable performance to the version by Kendall Willets - // (https://github.com/fmtlib/format-benchmark/blob/master/digits10) - // but uses smaller tables. - // Maps bsr(n) to ceil(log10(pow(2, bsr(n) + 1) - 1)). - static constexpr uint8_t bsr2log10[] = { - 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, - 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, - 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 15, 15, - 15, 16, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 19, 20}; - auto t = bsr2log10[FMT_BUILTIN_CLZLL(n | 1) ^ 63]; - static constexpr const uint64_t zero_or_powers_of_10[] = { - 0, 0, FMT_POWERS_OF_10(1U), FMT_POWERS_OF_10(1000000000ULL), - 10000000000000000000ULL}; - return t - (n < zero_or_powers_of_10[t]); -} -#endif - -// Returns the number of decimal digits in n. Leading zeros are not counted -// except for n == 0 in which case count_digits returns 1. -FMT_CONSTEXPR20 inline auto count_digits(uint64_t n) -> int { -#ifdef FMT_BUILTIN_CLZLL - if (!is_constant_evaluated()) return do_count_digits(n); -#endif - return count_digits_fallback(n); -} - -// Counts the number of digits in n. BITS = log2(radix). -template -FMT_CONSTEXPR auto count_digits(UInt n) -> int { -#ifdef FMT_BUILTIN_CLZ - if (!is_constant_evaluated() && num_bits() == 32) - return (FMT_BUILTIN_CLZ(static_cast(n) | 1) ^ 31) / BITS + 1; -#endif - // Lambda avoids unreachable code warnings from NVHPC. - return [](UInt m) { - int num_digits = 0; - do { - ++num_digits; - } while ((m >>= BITS) != 0); - return num_digits; - }(n); -} - -#ifdef FMT_BUILTIN_CLZ -// It is a separate function rather than a part of count_digits to workaround -// the lack of static constexpr in constexpr functions. -FMT_INLINE auto do_count_digits(uint32_t n) -> int { -// An optimization by Kendall Willets from https://bit.ly/3uOIQrB. -// This increments the upper 32 bits (log10(T) - 1) when >= T is added. -# define FMT_INC(T) (((sizeof(#T) - 1ull) << 32) - T) - static constexpr uint64_t table[] = { - FMT_INC(0), FMT_INC(0), FMT_INC(0), // 8 - FMT_INC(10), FMT_INC(10), FMT_INC(10), // 64 - FMT_INC(100), FMT_INC(100), FMT_INC(100), // 512 - FMT_INC(1000), FMT_INC(1000), FMT_INC(1000), // 4096 - FMT_INC(10000), FMT_INC(10000), FMT_INC(10000), // 32k - FMT_INC(100000), FMT_INC(100000), FMT_INC(100000), // 256k - FMT_INC(1000000), FMT_INC(1000000), FMT_INC(1000000), // 2048k - FMT_INC(10000000), FMT_INC(10000000), FMT_INC(10000000), // 16M - FMT_INC(100000000), FMT_INC(100000000), FMT_INC(100000000), // 128M - FMT_INC(1000000000), FMT_INC(1000000000), FMT_INC(1000000000), // 1024M - FMT_INC(1000000000), FMT_INC(1000000000) // 4B - }; - auto inc = table[FMT_BUILTIN_CLZ(n | 1) ^ 31]; - return static_cast((n + inc) >> 32); -} -#endif - -// Optional version of count_digits for better performance on 32-bit platforms. -FMT_CONSTEXPR20 inline auto count_digits(uint32_t n) -> int { -#ifdef FMT_BUILTIN_CLZ - if (!is_constant_evaluated()) { - return do_count_digits(n); - } -#endif - return count_digits_fallback(n); -} - -template constexpr auto digits10() noexcept -> int { - return std::numeric_limits::digits10; -} -template <> constexpr auto digits10() noexcept -> int { return 38; } -template <> constexpr auto digits10() noexcept -> int { return 38; } - -template struct thousands_sep_result { - std::string grouping; - Char thousands_sep; -}; - -template -FMT_API auto thousands_sep_impl(locale_ref loc) -> thousands_sep_result; -template -inline auto thousands_sep(locale_ref loc) -> thousands_sep_result { - auto result = thousands_sep_impl(loc); - return {result.grouping, Char(result.thousands_sep)}; -} -template <> -inline auto thousands_sep(locale_ref loc) -> thousands_sep_result { - return thousands_sep_impl(loc); -} - -template -FMT_API auto decimal_point_impl(locale_ref loc) -> Char; -template inline auto decimal_point(locale_ref loc) -> Char { - return Char(decimal_point_impl(loc)); -} -template <> inline auto decimal_point(locale_ref loc) -> wchar_t { - return decimal_point_impl(loc); -} - -// Compares two characters for equality. -template auto equal2(const Char* lhs, const char* rhs) -> bool { - return lhs[0] == Char(rhs[0]) && lhs[1] == Char(rhs[1]); -} -inline auto equal2(const char* lhs, const char* rhs) -> bool { - return memcmp(lhs, rhs, 2) == 0; -} - -// Copies two characters from src to dst. -template -FMT_CONSTEXPR20 FMT_INLINE void copy2(Char* dst, const char* src) { - if (!is_constant_evaluated() && sizeof(Char) == sizeof(char)) { - memcpy(dst, src, 2); - return; - } - *dst++ = static_cast(*src++); - *dst = static_cast(*src); -} - -template struct format_decimal_result { - Iterator begin; - Iterator end; -}; - -// Formats a decimal unsigned integer value writing into out pointing to a -// buffer of specified size. The caller must ensure that the buffer is large -// enough. -template -FMT_CONSTEXPR20 auto format_decimal(Char* out, UInt value, int size) - -> format_decimal_result { - FMT_ASSERT(size >= count_digits(value), "invalid digit count"); - out += size; - Char* end = out; - while (value >= 100) { - // Integer division is slow so do it for a group of two digits instead - // of for every digit. The idea comes from the talk by Alexandrescu - // "Three Optimization Tips for C++". See speed-test for a comparison. - out -= 2; - copy2(out, digits2(static_cast(value % 100))); - value /= 100; - } - if (value < 10) { - *--out = static_cast('0' + value); - return {out, end}; - } - out -= 2; - copy2(out, digits2(static_cast(value))); - return {out, end}; -} - -template >::value)> -FMT_CONSTEXPR inline auto format_decimal(Iterator out, UInt value, int size) - -> format_decimal_result { - // Buffer is large enough to hold all digits (digits10 + 1). - Char buffer[digits10() + 1] = {}; - auto end = format_decimal(buffer, value, size).end; - return {out, detail::copy_noinline(buffer, end, out)}; -} - -template -FMT_CONSTEXPR auto format_uint(Char* buffer, UInt value, int num_digits, - bool upper = false) -> Char* { - buffer += num_digits; - Char* end = buffer; - do { - const char* digits = upper ? "0123456789ABCDEF" : "0123456789abcdef"; - unsigned digit = static_cast(value & ((1 << BASE_BITS) - 1)); - *--buffer = static_cast(BASE_BITS < 4 ? static_cast('0' + digit) - : digits[digit]); - } while ((value >>= BASE_BITS) != 0); - return end; -} - -template -FMT_CONSTEXPR inline auto format_uint(It out, UInt value, int num_digits, - bool upper = false) -> It { - if (auto ptr = to_pointer(out, to_unsigned(num_digits))) { - format_uint(ptr, value, num_digits, upper); - return out; - } - // Buffer should be large enough to hold all digits (digits / BASE_BITS + 1). - char buffer[num_bits() / BASE_BITS + 1] = {}; - format_uint(buffer, value, num_digits, upper); - return detail::copy_noinline(buffer, buffer + num_digits, out); -} - -// A converter from UTF-8 to UTF-16. -class utf8_to_utf16 { - private: - basic_memory_buffer buffer_; - - public: - FMT_API explicit utf8_to_utf16(string_view s); - operator basic_string_view() const { return {&buffer_[0], size()}; } - auto size() const -> size_t { return buffer_.size() - 1; } - auto c_str() const -> const wchar_t* { return &buffer_[0]; } - auto str() const -> std::wstring { return {&buffer_[0], size()}; } -}; - -enum class to_utf8_error_policy { abort, replace }; - -// A converter from UTF-16/UTF-32 (host endian) to UTF-8. -template class to_utf8 { - private: - Buffer buffer_; - - public: - to_utf8() {} - explicit to_utf8(basic_string_view s, - to_utf8_error_policy policy = to_utf8_error_policy::abort) { - static_assert(sizeof(WChar) == 2 || sizeof(WChar) == 4, - "Expect utf16 or utf32"); - if (!convert(s, policy)) - FMT_THROW(std::runtime_error(sizeof(WChar) == 2 ? "invalid utf16" - : "invalid utf32")); - } - operator string_view() const { return string_view(&buffer_[0], size()); } - auto size() const -> size_t { return buffer_.size() - 1; } - auto c_str() const -> const char* { return &buffer_[0]; } - auto str() const -> std::string { return std::string(&buffer_[0], size()); } - - // Performs conversion returning a bool instead of throwing exception on - // conversion error. This method may still throw in case of memory allocation - // error. - auto convert(basic_string_view s, - to_utf8_error_policy policy = to_utf8_error_policy::abort) - -> bool { - if (!convert(buffer_, s, policy)) return false; - buffer_.push_back(0); - return true; - } - static auto convert(Buffer& buf, basic_string_view s, - to_utf8_error_policy policy = to_utf8_error_policy::abort) - -> bool { - for (auto p = s.begin(); p != s.end(); ++p) { - uint32_t c = static_cast(*p); - if (sizeof(WChar) == 2 && c >= 0xd800 && c <= 0xdfff) { - // Handle a surrogate pair. - ++p; - if (p == s.end() || (c & 0xfc00) != 0xd800 || (*p & 0xfc00) != 0xdc00) { - if (policy == to_utf8_error_policy::abort) return false; - buf.append(string_view("\xEF\xBF\xBD")); - --p; - } else { - c = (c << 10) + static_cast(*p) - 0x35fdc00; - } - } else if (c < 0x80) { - buf.push_back(static_cast(c)); - } else if (c < 0x800) { - buf.push_back(static_cast(0xc0 | (c >> 6))); - buf.push_back(static_cast(0x80 | (c & 0x3f))); - } else if ((c >= 0x800 && c <= 0xd7ff) || (c >= 0xe000 && c <= 0xffff)) { - buf.push_back(static_cast(0xe0 | (c >> 12))); - buf.push_back(static_cast(0x80 | ((c & 0xfff) >> 6))); - buf.push_back(static_cast(0x80 | (c & 0x3f))); - } else if (c >= 0x10000 && c <= 0x10ffff) { - buf.push_back(static_cast(0xf0 | (c >> 18))); - buf.push_back(static_cast(0x80 | ((c & 0x3ffff) >> 12))); - buf.push_back(static_cast(0x80 | ((c & 0xfff) >> 6))); - buf.push_back(static_cast(0x80 | (c & 0x3f))); - } else { - return false; - } - } - return true; - } -}; - -// Computes 128-bit result of multiplication of two 64-bit unsigned integers. -inline auto umul128(uint64_t x, uint64_t y) noexcept -> uint128_fallback { -#if FMT_USE_INT128 - auto p = static_cast(x) * static_cast(y); - return {static_cast(p >> 64), static_cast(p)}; -#elif defined(_MSC_VER) && defined(_M_X64) - auto hi = uint64_t(); - auto lo = _umul128(x, y, &hi); - return {hi, lo}; -#else - const uint64_t mask = static_cast(max_value()); - - uint64_t a = x >> 32; - uint64_t b = x & mask; - uint64_t c = y >> 32; - uint64_t d = y & mask; - - uint64_t ac = a * c; - uint64_t bc = b * c; - uint64_t ad = a * d; - uint64_t bd = b * d; - - uint64_t intermediate = (bd >> 32) + (ad & mask) + (bc & mask); - - return {ac + (intermediate >> 32) + (ad >> 32) + (bc >> 32), - (intermediate << 32) + (bd & mask)}; -#endif -} - -namespace dragonbox { -// Computes floor(log10(pow(2, e))) for e in [-2620, 2620] using the method from -// https://fmt.dev/papers/Dragonbox.pdf#page=28, section 6.1. -inline auto floor_log10_pow2(int e) noexcept -> int { - FMT_ASSERT(e <= 2620 && e >= -2620, "too large exponent"); - static_assert((-1 >> 1) == -1, "right shift is not arithmetic"); - return (e * 315653) >> 20; -} - -inline auto floor_log2_pow10(int e) noexcept -> int { - FMT_ASSERT(e <= 1233 && e >= -1233, "too large exponent"); - return (e * 1741647) >> 19; -} - -// Computes upper 64 bits of multiplication of two 64-bit unsigned integers. -inline auto umul128_upper64(uint64_t x, uint64_t y) noexcept -> uint64_t { -#if FMT_USE_INT128 - auto p = static_cast(x) * static_cast(y); - return static_cast(p >> 64); -#elif defined(_MSC_VER) && defined(_M_X64) - return __umulh(x, y); -#else - return umul128(x, y).high(); -#endif -} - -// Computes upper 128 bits of multiplication of a 64-bit unsigned integer and a -// 128-bit unsigned integer. -inline auto umul192_upper128(uint64_t x, uint128_fallback y) noexcept - -> uint128_fallback { - uint128_fallback r = umul128(x, y.high()); - r += umul128_upper64(x, y.low()); - return r; -} - -FMT_API auto get_cached_power(int k) noexcept -> uint128_fallback; - -// Type-specific information that Dragonbox uses. -template struct float_info; - -template <> struct float_info { - using carrier_uint = uint32_t; - static const int exponent_bits = 8; - static const int kappa = 1; - static const int big_divisor = 100; - static const int small_divisor = 10; - static const int min_k = -31; - static const int max_k = 46; - static const int shorter_interval_tie_lower_threshold = -35; - static const int shorter_interval_tie_upper_threshold = -35; -}; - -template <> struct float_info { - using carrier_uint = uint64_t; - static const int exponent_bits = 11; - static const int kappa = 2; - static const int big_divisor = 1000; - static const int small_divisor = 100; - static const int min_k = -292; - static const int max_k = 341; - static const int shorter_interval_tie_lower_threshold = -77; - static const int shorter_interval_tie_upper_threshold = -77; -}; - -// An 80- or 128-bit floating point number. -template -struct float_info::digits == 64 || - std::numeric_limits::digits == 113 || - is_float128::value>> { - using carrier_uint = detail::uint128_t; - static const int exponent_bits = 15; -}; - -// A double-double floating point number. -template -struct float_info::value>> { - using carrier_uint = detail::uint128_t; -}; - -template struct decimal_fp { - using significand_type = typename float_info::carrier_uint; - significand_type significand; - int exponent; -}; - -template FMT_API auto to_decimal(T x) noexcept -> decimal_fp; -} // namespace dragonbox - -// Returns true iff Float has the implicit bit which is not stored. -template constexpr auto has_implicit_bit() -> bool { - // An 80-bit FP number has a 64-bit significand an no implicit bit. - return std::numeric_limits::digits != 64; -} - -// Returns the number of significand bits stored in Float. The implicit bit is -// not counted since it is not stored. -template constexpr auto num_significand_bits() -> int { - // std::numeric_limits may not support __float128. - return is_float128() ? 112 - : (std::numeric_limits::digits - - (has_implicit_bit() ? 1 : 0)); -} - -template -constexpr auto exponent_mask() -> - typename dragonbox::float_info::carrier_uint { - using float_uint = typename dragonbox::float_info::carrier_uint; - return ((float_uint(1) << dragonbox::float_info::exponent_bits) - 1) - << num_significand_bits(); -} -template constexpr auto exponent_bias() -> int { - // std::numeric_limits may not support __float128. - return is_float128() ? 16383 - : std::numeric_limits::max_exponent - 1; -} - -// Writes the exponent exp in the form "[+-]d{2,3}" to buffer. -template -FMT_CONSTEXPR auto write_exponent(int exp, It it) -> It { - FMT_ASSERT(-10000 < exp && exp < 10000, "exponent out of range"); - if (exp < 0) { - *it++ = static_cast('-'); - exp = -exp; - } else { - *it++ = static_cast('+'); - } - if (exp >= 100) { - const char* top = digits2(to_unsigned(exp / 100)); - if (exp >= 1000) *it++ = static_cast(top[0]); - *it++ = static_cast(top[1]); - exp %= 100; - } - const char* d = digits2(to_unsigned(exp)); - *it++ = static_cast(d[0]); - *it++ = static_cast(d[1]); - return it; -} - -// A floating-point number f * pow(2, e) where F is an unsigned type. -template struct basic_fp { - F f; - int e; - - static constexpr const int num_significand_bits = - static_cast(sizeof(F) * num_bits()); - - constexpr basic_fp() : f(0), e(0) {} - constexpr basic_fp(uint64_t f_val, int e_val) : f(f_val), e(e_val) {} - - // Constructs fp from an IEEE754 floating-point number. - template FMT_CONSTEXPR basic_fp(Float n) { assign(n); } - - // Assigns n to this and return true iff predecessor is closer than successor. - template ::value)> - FMT_CONSTEXPR auto assign(Float n) -> bool { - static_assert(std::numeric_limits::digits <= 113, "unsupported FP"); - // Assume Float is in the format [sign][exponent][significand]. - using carrier_uint = typename dragonbox::float_info::carrier_uint; - const auto num_float_significand_bits = - detail::num_significand_bits(); - const auto implicit_bit = carrier_uint(1) << num_float_significand_bits; - const auto significand_mask = implicit_bit - 1; - auto u = bit_cast(n); - f = static_cast(u & significand_mask); - auto biased_e = static_cast((u & exponent_mask()) >> - num_float_significand_bits); - // The predecessor is closer if n is a normalized power of 2 (f == 0) - // other than the smallest normalized number (biased_e > 1). - auto is_predecessor_closer = f == 0 && biased_e > 1; - if (biased_e == 0) - biased_e = 1; // Subnormals use biased exponent 1 (min exponent). - else if (has_implicit_bit()) - f += static_cast(implicit_bit); - e = biased_e - exponent_bias() - num_float_significand_bits; - if (!has_implicit_bit()) ++e; - return is_predecessor_closer; - } - - template ::value)> - FMT_CONSTEXPR auto assign(Float n) -> bool { - static_assert(std::numeric_limits::is_iec559, "unsupported FP"); - return assign(static_cast(n)); - } -}; - -using fp = basic_fp; - -// Normalizes the value converted from double and multiplied by (1 << SHIFT). -template -FMT_CONSTEXPR auto normalize(basic_fp value) -> basic_fp { - // Handle subnormals. - const auto implicit_bit = F(1) << num_significand_bits(); - const auto shifted_implicit_bit = implicit_bit << SHIFT; - while ((value.f & shifted_implicit_bit) == 0) { - value.f <<= 1; - --value.e; - } - // Subtract 1 to account for hidden bit. - const auto offset = basic_fp::num_significand_bits - - num_significand_bits() - SHIFT - 1; - value.f <<= offset; - value.e -= offset; - return value; -} - -// Computes lhs * rhs / pow(2, 64) rounded to nearest with half-up tie breaking. -FMT_CONSTEXPR inline auto multiply(uint64_t lhs, uint64_t rhs) -> uint64_t { -#if FMT_USE_INT128 - auto product = static_cast<__uint128_t>(lhs) * rhs; - auto f = static_cast(product >> 64); - return (static_cast(product) & (1ULL << 63)) != 0 ? f + 1 : f; -#else - // Multiply 32-bit parts of significands. - uint64_t mask = (1ULL << 32) - 1; - uint64_t a = lhs >> 32, b = lhs & mask; - uint64_t c = rhs >> 32, d = rhs & mask; - uint64_t ac = a * c, bc = b * c, ad = a * d, bd = b * d; - // Compute mid 64-bit of result and round. - uint64_t mid = (bd >> 32) + (ad & mask) + (bc & mask) + (1U << 31); - return ac + (ad >> 32) + (bc >> 32) + (mid >> 32); -#endif -} - -FMT_CONSTEXPR inline auto operator*(fp x, fp y) -> fp { - return {multiply(x.f, y.f), x.e + y.e + 64}; -} - -template () == num_bits()> -using convert_float_result = - conditional_t::value || doublish, double, T>; - -template -constexpr auto convert_float(T value) -> convert_float_result { - return static_cast>(value); -} - -template -FMT_NOINLINE FMT_CONSTEXPR auto fill(OutputIt it, size_t n, const fill_t& fill) - -> OutputIt { - auto fill_size = fill.size(); - if (fill_size == 1) return detail::fill_n(it, n, fill.template get()); - if (const Char* data = fill.template data()) { - for (size_t i = 0; i < n; ++i) it = copy(data, data + fill_size, it); - } - return it; -} - -// Writes the output of f, padded according to format specifications in specs. -// size: output size in code units. -// width: output display width in (terminal) column positions. -template -FMT_CONSTEXPR auto write_padded(OutputIt out, const format_specs& specs, - size_t size, size_t width, F&& f) -> OutputIt { - static_assert(align == align::left || align == align::right, ""); - unsigned spec_width = to_unsigned(specs.width); - size_t padding = spec_width > width ? spec_width - width : 0; - // Shifts are encoded as string literals because static constexpr is not - // supported in constexpr functions. - auto* shifts = align == align::left ? "\x1f\x1f\x00\x01" : "\x00\x1f\x00\x01"; - size_t left_padding = padding >> shifts[specs.align]; - size_t right_padding = padding - left_padding; - auto it = reserve(out, size + padding * specs.fill.size()); - if (left_padding != 0) it = fill(it, left_padding, specs.fill); - it = f(it); - if (right_padding != 0) it = fill(it, right_padding, specs.fill); - return base_iterator(out, it); -} - -template -constexpr auto write_padded(OutputIt out, const format_specs& specs, - size_t size, F&& f) -> OutputIt { - return write_padded(out, specs, size, size, f); -} - -template -FMT_CONSTEXPR auto write_bytes(OutputIt out, string_view bytes, - const format_specs& specs = {}) -> OutputIt { - return write_padded( - out, specs, bytes.size(), [bytes](reserve_iterator it) { - const char* data = bytes.data(); - return copy(data, data + bytes.size(), it); - }); -} - -template -auto write_ptr(OutputIt out, UIntPtr value, const format_specs* specs) - -> OutputIt { - int num_digits = count_digits<4>(value); - auto size = to_unsigned(num_digits) + size_t(2); - auto write = [=](reserve_iterator it) { - *it++ = static_cast('0'); - *it++ = static_cast('x'); - return format_uint<4, Char>(it, value, num_digits); - }; - return specs ? write_padded(out, *specs, size, write) - : base_iterator(out, write(reserve(out, size))); -} - -// Returns true iff the code point cp is printable. -FMT_API auto is_printable(uint32_t cp) -> bool; - -inline auto needs_escape(uint32_t cp) -> bool { - return cp < 0x20 || cp == 0x7f || cp == '"' || cp == '\\' || - !is_printable(cp); -} - -template struct find_escape_result { - const Char* begin; - const Char* end; - uint32_t cp; -}; - -template -auto find_escape(const Char* begin, const Char* end) - -> find_escape_result { - for (; begin != end; ++begin) { - uint32_t cp = static_cast>(*begin); - if (const_check(sizeof(Char) == 1) && cp >= 0x80) continue; - if (needs_escape(cp)) return {begin, begin + 1, cp}; - } - return {begin, nullptr, 0}; -} - -inline auto find_escape(const char* begin, const char* end) - -> find_escape_result { - if (!use_utf8()) return find_escape(begin, end); - auto result = find_escape_result{end, nullptr, 0}; - for_each_codepoint(string_view(begin, to_unsigned(end - begin)), - [&](uint32_t cp, string_view sv) { - if (needs_escape(cp)) { - result = {sv.begin(), sv.end(), cp}; - return false; - } - return true; - }); - return result; -} - -#define FMT_STRING_IMPL(s, base, explicit) \ - [] { \ - /* Use the hidden visibility as a workaround for a GCC bug (#1973). */ \ - /* Use a macro-like name to avoid shadowing warnings. */ \ - struct FMT_VISIBILITY("hidden") FMT_COMPILE_STRING : base { \ - using char_type FMT_MAYBE_UNUSED = fmt::remove_cvref_t; \ - FMT_MAYBE_UNUSED FMT_CONSTEXPR explicit \ - operator fmt::basic_string_view() const { \ - return fmt::detail_exported::compile_string_to_view(s); \ - } \ - }; \ - return FMT_COMPILE_STRING(); \ - }() - -/** - * Constructs a compile-time format string from a string literal `s`. - * - * **Example**: - * - * // A compile-time error because 'd' is an invalid specifier for strings. - * std::string s = fmt::format(FMT_STRING("{:d}"), "foo"); - */ -#define FMT_STRING(s) FMT_STRING_IMPL(s, fmt::detail::compile_string, ) - -template -auto write_codepoint(OutputIt out, char prefix, uint32_t cp) -> OutputIt { - *out++ = static_cast('\\'); - *out++ = static_cast(prefix); - Char buf[width]; - fill_n(buf, width, static_cast('0')); - format_uint<4>(buf, cp, width); - return copy(buf, buf + width, out); -} - -template -auto write_escaped_cp(OutputIt out, const find_escape_result& escape) - -> OutputIt { - auto c = static_cast(escape.cp); - switch (escape.cp) { - case '\n': - *out++ = static_cast('\\'); - c = static_cast('n'); - break; - case '\r': - *out++ = static_cast('\\'); - c = static_cast('r'); - break; - case '\t': - *out++ = static_cast('\\'); - c = static_cast('t'); - break; - case '"': - FMT_FALLTHROUGH; - case '\'': - FMT_FALLTHROUGH; - case '\\': - *out++ = static_cast('\\'); - break; - default: - if (escape.cp < 0x100) return write_codepoint<2, Char>(out, 'x', escape.cp); - if (escape.cp < 0x10000) - return write_codepoint<4, Char>(out, 'u', escape.cp); - if (escape.cp < 0x110000) - return write_codepoint<8, Char>(out, 'U', escape.cp); - for (Char escape_char : basic_string_view( - escape.begin, to_unsigned(escape.end - escape.begin))) { - out = write_codepoint<2, Char>(out, 'x', - static_cast(escape_char) & 0xFF); - } - return out; - } - *out++ = c; - return out; -} - -template -auto write_escaped_string(OutputIt out, basic_string_view str) - -> OutputIt { - *out++ = static_cast('"'); - auto begin = str.begin(), end = str.end(); - do { - auto escape = find_escape(begin, end); - out = copy(begin, escape.begin, out); - begin = escape.end; - if (!begin) break; - out = write_escaped_cp(out, escape); - } while (begin != end); - *out++ = static_cast('"'); - return out; -} - -template -auto write_escaped_char(OutputIt out, Char v) -> OutputIt { - Char v_array[1] = {v}; - *out++ = static_cast('\''); - if ((needs_escape(static_cast(v)) && v != static_cast('"')) || - v == static_cast('\'')) { - out = write_escaped_cp(out, - find_escape_result{v_array, v_array + 1, - static_cast(v)}); - } else { - *out++ = v; - } - *out++ = static_cast('\''); - return out; -} - -template -FMT_CONSTEXPR auto write_char(OutputIt out, Char value, - const format_specs& specs) -> OutputIt { - bool is_debug = specs.type == presentation_type::debug; - return write_padded(out, specs, 1, [=](reserve_iterator it) { - if (is_debug) return write_escaped_char(it, value); - *it++ = value; - return it; - }); -} -template -FMT_CONSTEXPR auto write(OutputIt out, Char value, const format_specs& specs, - locale_ref loc = {}) -> OutputIt { - // char is formatted as unsigned char for consistency across platforms. - using unsigned_type = - conditional_t::value, unsigned char, unsigned>; - return check_char_specs(specs) - ? write_char(out, value, specs) - : write(out, static_cast(value), specs, loc); -} - -// Data for write_int that doesn't depend on output iterator type. It is used to -// avoid template code bloat. -template struct write_int_data { - size_t size; - size_t padding; - - FMT_CONSTEXPR write_int_data(int num_digits, unsigned prefix, - const format_specs& specs) - : size((prefix >> 24) + to_unsigned(num_digits)), padding(0) { - if (specs.align == align::numeric) { - auto width = to_unsigned(specs.width); - if (width > size) { - padding = width - size; - size = width; - } - } else if (specs.precision > num_digits) { - size = (prefix >> 24) + to_unsigned(specs.precision); - padding = to_unsigned(specs.precision - num_digits); - } - } -}; - -// Writes an integer in the format -// -// where are written by write_digits(it). -// prefix contains chars in three lower bytes and the size in the fourth byte. -template -FMT_CONSTEXPR FMT_INLINE auto write_int(OutputIt out, int num_digits, - unsigned prefix, - const format_specs& specs, - W write_digits) -> OutputIt { - // Slightly faster check for specs.width == 0 && specs.precision == -1. - if ((specs.width | (specs.precision + 1)) == 0) { - auto it = reserve(out, to_unsigned(num_digits) + (prefix >> 24)); - if (prefix != 0) { - for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) - *it++ = static_cast(p & 0xff); - } - return base_iterator(out, write_digits(it)); - } - auto data = write_int_data(num_digits, prefix, specs); - return write_padded( - out, specs, data.size, [=](reserve_iterator it) { - for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) - *it++ = static_cast(p & 0xff); - it = detail::fill_n(it, data.padding, static_cast('0')); - return write_digits(it); - }); -} - -template class digit_grouping { - private: - std::string grouping_; - std::basic_string thousands_sep_; - - struct next_state { - std::string::const_iterator group; - int pos; - }; - auto initial_state() const -> next_state { return {grouping_.begin(), 0}; } - - // Returns the next digit group separator position. - auto next(next_state& state) const -> int { - if (thousands_sep_.empty()) return max_value(); - if (state.group == grouping_.end()) return state.pos += grouping_.back(); - if (*state.group <= 0 || *state.group == max_value()) - return max_value(); - state.pos += *state.group++; - return state.pos; - } - - public: - explicit digit_grouping(locale_ref loc, bool localized = true) { - if (!localized) return; - auto sep = thousands_sep(loc); - grouping_ = sep.grouping; - if (sep.thousands_sep) thousands_sep_.assign(1, sep.thousands_sep); - } - digit_grouping(std::string grouping, std::basic_string sep) - : grouping_(std::move(grouping)), thousands_sep_(std::move(sep)) {} - - auto has_separator() const -> bool { return !thousands_sep_.empty(); } - - auto count_separators(int num_digits) const -> int { - int count = 0; - auto state = initial_state(); - while (num_digits > next(state)) ++count; - return count; - } - - // Applies grouping to digits and write the output to out. - template - auto apply(Out out, basic_string_view digits) const -> Out { - auto num_digits = static_cast(digits.size()); - auto separators = basic_memory_buffer(); - separators.push_back(0); - auto state = initial_state(); - while (int i = next(state)) { - if (i >= num_digits) break; - separators.push_back(i); - } - for (int i = 0, sep_index = static_cast(separators.size() - 1); - i < num_digits; ++i) { - if (num_digits - i == separators[sep_index]) { - out = copy(thousands_sep_.data(), - thousands_sep_.data() + thousands_sep_.size(), out); - --sep_index; - } - *out++ = static_cast(digits[to_unsigned(i)]); - } - return out; - } -}; - -FMT_CONSTEXPR inline void prefix_append(unsigned& prefix, unsigned value) { - prefix |= prefix != 0 ? value << 8 : value; - prefix += (1u + (value > 0xff ? 1 : 0)) << 24; -} - -// Writes a decimal integer with digit grouping. -template -auto write_int(OutputIt out, UInt value, unsigned prefix, - const format_specs& specs, const digit_grouping& grouping) - -> OutputIt { - static_assert(std::is_same, UInt>::value, ""); - int num_digits = 0; - auto buffer = memory_buffer(); - switch (specs.type) { - default: - FMT_ASSERT(false, ""); - FMT_FALLTHROUGH; - case presentation_type::none: - case presentation_type::dec: - num_digits = count_digits(value); - format_decimal(appender(buffer), value, num_digits); - break; - case presentation_type::hex: - if (specs.alt) - prefix_append(prefix, unsigned(specs.upper ? 'X' : 'x') << 8 | '0'); - num_digits = count_digits<4>(value); - format_uint<4, char>(appender(buffer), value, num_digits, specs.upper); - break; - case presentation_type::oct: - num_digits = count_digits<3>(value); - // Octal prefix '0' is counted as a digit, so only add it if precision - // is not greater than the number of digits. - if (specs.alt && specs.precision <= num_digits && value != 0) - prefix_append(prefix, '0'); - format_uint<3, char>(appender(buffer), value, num_digits); - break; - case presentation_type::bin: - if (specs.alt) - prefix_append(prefix, unsigned(specs.upper ? 'B' : 'b') << 8 | '0'); - num_digits = count_digits<1>(value); - format_uint<1, char>(appender(buffer), value, num_digits); - break; - case presentation_type::chr: - return write_char(out, static_cast(value), specs); - } - - unsigned size = (prefix != 0 ? prefix >> 24 : 0) + to_unsigned(num_digits) + - to_unsigned(grouping.count_separators(num_digits)); - return write_padded( - out, specs, size, size, [&](reserve_iterator it) { - for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8) - *it++ = static_cast(p & 0xff); - return grouping.apply(it, string_view(buffer.data(), buffer.size())); - }); -} - -// Writes a localized value. -FMT_API auto write_loc(appender out, loc_value value, const format_specs& specs, - locale_ref loc) -> bool; -template -inline auto write_loc(OutputIt, loc_value, const format_specs&, locale_ref) - -> bool { - return false; -} - -template struct write_int_arg { - UInt abs_value; - unsigned prefix; -}; - -template -FMT_CONSTEXPR auto make_write_int_arg(T value, sign_t sign) - -> write_int_arg> { - auto prefix = 0u; - auto abs_value = static_cast>(value); - if (is_negative(value)) { - prefix = 0x01000000 | '-'; - abs_value = 0 - abs_value; - } else { - constexpr const unsigned prefixes[4] = {0, 0, 0x1000000u | '+', - 0x1000000u | ' '}; - prefix = prefixes[sign]; - } - return {abs_value, prefix}; -} - -template struct loc_writer { - basic_appender out; - const format_specs& specs; - std::basic_string sep; - std::string grouping; - std::basic_string decimal_point; - - template ::value)> - auto operator()(T value) -> bool { - auto arg = make_write_int_arg(value, specs.sign); - write_int(out, static_cast>(arg.abs_value), arg.prefix, - specs, digit_grouping(grouping, sep)); - return true; - } - - template ::value)> - auto operator()(T) -> bool { - return false; - } -}; - -template -FMT_CONSTEXPR FMT_INLINE auto write_int(OutputIt out, write_int_arg arg, - const format_specs& specs, locale_ref) - -> OutputIt { - static_assert(std::is_same>::value, ""); - auto abs_value = arg.abs_value; - auto prefix = arg.prefix; - switch (specs.type) { - default: - FMT_ASSERT(false, ""); - FMT_FALLTHROUGH; - case presentation_type::none: - case presentation_type::dec: { - int num_digits = count_digits(abs_value); - return write_int( - out, num_digits, prefix, specs, [=](reserve_iterator it) { - return format_decimal(it, abs_value, num_digits).end; - }); - } - case presentation_type::hex: { - if (specs.alt) - prefix_append(prefix, unsigned(specs.upper ? 'X' : 'x') << 8 | '0'); - int num_digits = count_digits<4>(abs_value); - return write_int( - out, num_digits, prefix, specs, [=](reserve_iterator it) { - return format_uint<4, Char>(it, abs_value, num_digits, specs.upper); - }); - } - case presentation_type::oct: { - int num_digits = count_digits<3>(abs_value); - // Octal prefix '0' is counted as a digit, so only add it if precision - // is not greater than the number of digits. - if (specs.alt && specs.precision <= num_digits && abs_value != 0) - prefix_append(prefix, '0'); - return write_int( - out, num_digits, prefix, specs, [=](reserve_iterator it) { - return format_uint<3, Char>(it, abs_value, num_digits); - }); - } - case presentation_type::bin: { - if (specs.alt) - prefix_append(prefix, unsigned(specs.upper ? 'B' : 'b') << 8 | '0'); - int num_digits = count_digits<1>(abs_value); - return write_int( - out, num_digits, prefix, specs, [=](reserve_iterator it) { - return format_uint<1, Char>(it, abs_value, num_digits); - }); - } - case presentation_type::chr: - return write_char(out, static_cast(abs_value), specs); - } -} -template -FMT_CONSTEXPR FMT_NOINLINE auto write_int_noinline(OutputIt out, - write_int_arg arg, - const format_specs& specs, - locale_ref loc) -> OutputIt { - return write_int(out, arg, specs, loc); -} -template ::value && - !std::is_same::value && - !std::is_same::value)> -FMT_CONSTEXPR FMT_INLINE auto write(basic_appender out, T value, - const format_specs& specs, locale_ref loc) - -> basic_appender { - if (specs.localized && write_loc(out, value, specs, loc)) return out; - return write_int_noinline(out, make_write_int_arg(value, specs.sign), - specs, loc); -} -// An inlined version of write used in format string compilation. -template ::value && - !std::is_same::value && - !std::is_same::value && - !std::is_same>::value)> -FMT_CONSTEXPR FMT_INLINE auto write(OutputIt out, T value, - const format_specs& specs, locale_ref loc) - -> OutputIt { - if (specs.localized && write_loc(out, value, specs, loc)) return out; - return write_int(out, make_write_int_arg(value, specs.sign), specs, - loc); -} - -// An output iterator that counts the number of objects written to it and -// discards them. -class counting_iterator { - private: - size_t count_; - - public: - using iterator_category = std::output_iterator_tag; - using difference_type = std::ptrdiff_t; - using pointer = void; - using reference = void; - FMT_UNCHECKED_ITERATOR(counting_iterator); - - struct value_type { - template FMT_CONSTEXPR void operator=(const T&) {} - }; - - FMT_CONSTEXPR counting_iterator() : count_(0) {} - - FMT_CONSTEXPR auto count() const -> size_t { return count_; } - - FMT_CONSTEXPR auto operator++() -> counting_iterator& { - ++count_; - return *this; - } - FMT_CONSTEXPR auto operator++(int) -> counting_iterator { - auto it = *this; - ++*this; - return it; - } - - FMT_CONSTEXPR friend auto operator+(counting_iterator it, difference_type n) - -> counting_iterator { - it.count_ += static_cast(n); - return it; - } - - FMT_CONSTEXPR auto operator*() const -> value_type { return {}; } -}; - -template -FMT_CONSTEXPR auto write(OutputIt out, basic_string_view s, - const format_specs& specs) -> OutputIt { - auto data = s.data(); - auto size = s.size(); - if (specs.precision >= 0 && to_unsigned(specs.precision) < size) - size = code_point_index(s, to_unsigned(specs.precision)); - bool is_debug = specs.type == presentation_type::debug; - size_t width = 0; - - if (is_debug) size = write_escaped_string(counting_iterator{}, s).count(); - - if (specs.width != 0) { - if (is_debug) - width = size; - else - width = compute_width(basic_string_view(data, size)); - } - return write_padded(out, specs, size, width, - [=](reserve_iterator it) { - if (is_debug) return write_escaped_string(it, s); - return copy(data, data + size, it); - }); -} -template -FMT_CONSTEXPR auto write(OutputIt out, - basic_string_view> s, - const format_specs& specs, locale_ref) -> OutputIt { - return write(out, s, specs); -} -template -FMT_CONSTEXPR auto write(OutputIt out, const Char* s, const format_specs& specs, - locale_ref) -> OutputIt { - if (specs.type == presentation_type::pointer) - return write_ptr(out, bit_cast(s), &specs); - if (!s) report_error("string pointer is null"); - return write(out, basic_string_view(s), specs, {}); -} - -template ::value && - !std::is_same::value && - !std::is_same::value)> -FMT_CONSTEXPR auto write(OutputIt out, T value) -> OutputIt { - auto abs_value = static_cast>(value); - bool negative = is_negative(value); - // Don't do -abs_value since it trips unsigned-integer-overflow sanitizer. - if (negative) abs_value = ~abs_value + 1; - int num_digits = count_digits(abs_value); - auto size = (negative ? 1 : 0) + static_cast(num_digits); - auto it = reserve(out, size); - if (auto ptr = to_pointer(it, size)) { - if (negative) *ptr++ = static_cast('-'); - format_decimal(ptr, abs_value, num_digits); - return out; - } - if (negative) *it++ = static_cast('-'); - it = format_decimal(it, abs_value, num_digits).end; - return base_iterator(out, it); -} - -// DEPRECATED! -template -FMT_CONSTEXPR auto parse_align(const Char* begin, const Char* end, - format_specs& specs) -> const Char* { - FMT_ASSERT(begin != end, ""); - auto align = align::none; - auto p = begin + code_point_length(begin); - if (end - p <= 0) p = begin; - for (;;) { - switch (to_ascii(*p)) { - case '<': - align = align::left; - break; - case '>': - align = align::right; - break; - case '^': - align = align::center; - break; - } - if (align != align::none) { - if (p != begin) { - auto c = *begin; - if (c == '}') return begin; - if (c == '{') { - report_error("invalid fill character '{'"); - return begin; - } - specs.fill = basic_string_view(begin, to_unsigned(p - begin)); - begin = p + 1; - } else { - ++begin; - } - break; - } else if (p == begin) { - break; - } - p = begin; - } - specs.align = align; - return begin; -} - -// A floating-point presentation format. -enum class float_format : unsigned char { - general, // General: exponent notation or fixed point based on magnitude. - exp, // Exponent notation with the default precision of 6, e.g. 1.2e-3. - fixed // Fixed point with the default precision of 6, e.g. 0.0012. -}; - -struct float_specs { - int precision; - float_format format : 8; - sign_t sign : 8; - bool locale : 1; - bool binary32 : 1; - bool showpoint : 1; -}; - -// DEPRECATED! -FMT_CONSTEXPR inline auto parse_float_type_spec(const format_specs& specs) - -> float_specs { - auto result = float_specs(); - result.showpoint = specs.alt; - result.locale = specs.localized; - switch (specs.type) { - default: - FMT_FALLTHROUGH; - case presentation_type::none: - result.format = float_format::general; - break; - case presentation_type::exp: - result.format = float_format::exp; - result.showpoint |= specs.precision != 0; - break; - case presentation_type::fixed: - result.format = float_format::fixed; - result.showpoint |= specs.precision != 0; - break; - case presentation_type::general: - result.format = float_format::general; - break; - } - return result; -} - -template -FMT_CONSTEXPR20 auto write_nonfinite(OutputIt out, bool isnan, - format_specs specs, sign_t sign) - -> OutputIt { - auto str = - isnan ? (specs.upper ? "NAN" : "nan") : (specs.upper ? "INF" : "inf"); - constexpr size_t str_size = 3; - auto size = str_size + (sign ? 1 : 0); - // Replace '0'-padding with space for non-finite values. - const bool is_zero_fill = - specs.fill.size() == 1 && specs.fill.template get() == '0'; - if (is_zero_fill) specs.fill = ' '; - return write_padded(out, specs, size, - [=](reserve_iterator it) { - if (sign) *it++ = detail::sign(sign); - return copy(str, str + str_size, it); - }); -} - -// A decimal floating-point number significand * pow(10, exp). -struct big_decimal_fp { - const char* significand; - int significand_size; - int exponent; -}; - -constexpr auto get_significand_size(const big_decimal_fp& f) -> int { - return f.significand_size; -} -template -inline auto get_significand_size(const dragonbox::decimal_fp& f) -> int { - return count_digits(f.significand); -} - -template -constexpr auto write_significand(OutputIt out, const char* significand, - int significand_size) -> OutputIt { - return copy(significand, significand + significand_size, out); -} -template -inline auto write_significand(OutputIt out, UInt significand, - int significand_size) -> OutputIt { - return format_decimal(out, significand, significand_size).end; -} -template -FMT_CONSTEXPR20 auto write_significand(OutputIt out, T significand, - int significand_size, int exponent, - const Grouping& grouping) -> OutputIt { - if (!grouping.has_separator()) { - out = write_significand(out, significand, significand_size); - return detail::fill_n(out, exponent, static_cast('0')); - } - auto buffer = memory_buffer(); - write_significand(appender(buffer), significand, significand_size); - detail::fill_n(appender(buffer), exponent, '0'); - return grouping.apply(out, string_view(buffer.data(), buffer.size())); -} - -template ::value)> -inline auto write_significand(Char* out, UInt significand, int significand_size, - int integral_size, Char decimal_point) -> Char* { - if (!decimal_point) - return format_decimal(out, significand, significand_size).end; - out += significand_size + 1; - Char* end = out; - int floating_size = significand_size - integral_size; - for (int i = floating_size / 2; i > 0; --i) { - out -= 2; - copy2(out, digits2(static_cast(significand % 100))); - significand /= 100; - } - if (floating_size % 2 != 0) { - *--out = static_cast('0' + significand % 10); - significand /= 10; - } - *--out = decimal_point; - format_decimal(out - integral_size, significand, integral_size); - return end; -} - -template >::value)> -inline auto write_significand(OutputIt out, UInt significand, - int significand_size, int integral_size, - Char decimal_point) -> OutputIt { - // Buffer is large enough to hold digits (digits10 + 1) and a decimal point. - Char buffer[digits10() + 2]; - auto end = write_significand(buffer, significand, significand_size, - integral_size, decimal_point); - return detail::copy_noinline(buffer, end, out); -} - -template -FMT_CONSTEXPR auto write_significand(OutputIt out, const char* significand, - int significand_size, int integral_size, - Char decimal_point) -> OutputIt { - out = detail::copy_noinline(significand, significand + integral_size, - out); - if (!decimal_point) return out; - *out++ = decimal_point; - return detail::copy_noinline(significand + integral_size, - significand + significand_size, out); -} - -template -FMT_CONSTEXPR20 auto write_significand(OutputIt out, T significand, - int significand_size, int integral_size, - Char decimal_point, - const Grouping& grouping) -> OutputIt { - if (!grouping.has_separator()) { - return write_significand(out, significand, significand_size, integral_size, - decimal_point); - } - auto buffer = basic_memory_buffer(); - write_significand(basic_appender(buffer), significand, significand_size, - integral_size, decimal_point); - grouping.apply( - out, basic_string_view(buffer.data(), to_unsigned(integral_size))); - return detail::copy_noinline(buffer.data() + integral_size, - buffer.end(), out); -} - -template > -FMT_CONSTEXPR20 auto do_write_float(OutputIt out, const DecimalFP& f, - const format_specs& specs, - float_specs fspecs, locale_ref loc) - -> OutputIt { - auto significand = f.significand; - int significand_size = get_significand_size(f); - const Char zero = static_cast('0'); - auto sign = fspecs.sign; - size_t size = to_unsigned(significand_size) + (sign ? 1 : 0); - using iterator = reserve_iterator; - - Char decimal_point = - fspecs.locale ? detail::decimal_point(loc) : static_cast('.'); - - int output_exp = f.exponent + significand_size - 1; - auto use_exp_format = [=]() { - if (fspecs.format == float_format::exp) return true; - if (fspecs.format != float_format::general) return false; - // Use the fixed notation if the exponent is in [exp_lower, exp_upper), - // e.g. 0.0001 instead of 1e-04. Otherwise use the exponent notation. - const int exp_lower = -4, exp_upper = 16; - return output_exp < exp_lower || - output_exp >= (fspecs.precision > 0 ? fspecs.precision : exp_upper); - }; - if (use_exp_format()) { - int num_zeros = 0; - if (fspecs.showpoint) { - num_zeros = fspecs.precision - significand_size; - if (num_zeros < 0) num_zeros = 0; - size += to_unsigned(num_zeros); - } else if (significand_size == 1) { - decimal_point = Char(); - } - auto abs_output_exp = output_exp >= 0 ? output_exp : -output_exp; - int exp_digits = 2; - if (abs_output_exp >= 100) exp_digits = abs_output_exp >= 1000 ? 4 : 3; - - size += to_unsigned((decimal_point ? 1 : 0) + 2 + exp_digits); - char exp_char = specs.upper ? 'E' : 'e'; - auto write = [=](iterator it) { - if (sign) *it++ = detail::sign(sign); - // Insert a decimal point after the first digit and add an exponent. - it = write_significand(it, significand, significand_size, 1, - decimal_point); - if (num_zeros > 0) it = detail::fill_n(it, num_zeros, zero); - *it++ = static_cast(exp_char); - return write_exponent(output_exp, it); - }; - return specs.width > 0 - ? write_padded(out, specs, size, write) - : base_iterator(out, write(reserve(out, size))); - } - - int exp = f.exponent + significand_size; - if (f.exponent >= 0) { - // 1234e5 -> 123400000[.0+] - size += to_unsigned(f.exponent); - int num_zeros = fspecs.precision - exp; - abort_fuzzing_if(num_zeros > 5000); - if (fspecs.showpoint) { - ++size; - if (num_zeros <= 0 && fspecs.format != float_format::fixed) num_zeros = 0; - if (num_zeros > 0) size += to_unsigned(num_zeros); - } - auto grouping = Grouping(loc, fspecs.locale); - size += to_unsigned(grouping.count_separators(exp)); - return write_padded(out, specs, size, [&](iterator it) { - if (sign) *it++ = detail::sign(sign); - it = write_significand(it, significand, significand_size, - f.exponent, grouping); - if (!fspecs.showpoint) return it; - *it++ = decimal_point; - return num_zeros > 0 ? detail::fill_n(it, num_zeros, zero) : it; - }); - } else if (exp > 0) { - // 1234e-2 -> 12.34[0+] - int num_zeros = fspecs.showpoint ? fspecs.precision - significand_size : 0; - size += 1 + to_unsigned(num_zeros > 0 ? num_zeros : 0); - auto grouping = Grouping(loc, fspecs.locale); - size += to_unsigned(grouping.count_separators(exp)); - return write_padded(out, specs, size, [&](iterator it) { - if (sign) *it++ = detail::sign(sign); - it = write_significand(it, significand, significand_size, exp, - decimal_point, grouping); - return num_zeros > 0 ? detail::fill_n(it, num_zeros, zero) : it; - }); - } - // 1234e-6 -> 0.001234 - int num_zeros = -exp; - if (significand_size == 0 && fspecs.precision >= 0 && - fspecs.precision < num_zeros) { - num_zeros = fspecs.precision; - } - bool pointy = num_zeros != 0 || significand_size != 0 || fspecs.showpoint; - size += 1 + (pointy ? 1 : 0) + to_unsigned(num_zeros); - return write_padded(out, specs, size, [&](iterator it) { - if (sign) *it++ = detail::sign(sign); - *it++ = zero; - if (!pointy) return it; - *it++ = decimal_point; - it = detail::fill_n(it, num_zeros, zero); - return write_significand(it, significand, significand_size); - }); -} - -template class fallback_digit_grouping { - public: - constexpr fallback_digit_grouping(locale_ref, bool) {} - - constexpr auto has_separator() const -> bool { return false; } - - constexpr auto count_separators(int) const -> int { return 0; } - - template - constexpr auto apply(Out out, basic_string_view) const -> Out { - return out; - } -}; - -template -FMT_CONSTEXPR20 auto write_float(OutputIt out, const DecimalFP& f, - const format_specs& specs, float_specs fspecs, - locale_ref loc) -> OutputIt { - if (is_constant_evaluated()) { - return do_write_float>(out, f, specs, fspecs, - loc); - } else { - return do_write_float(out, f, specs, fspecs, loc); - } -} - -template constexpr auto isnan(T value) -> bool { - return value != value; // std::isnan doesn't support __float128. -} - -template -struct has_isfinite : std::false_type {}; - -template -struct has_isfinite> - : std::true_type {}; - -template ::value&& - has_isfinite::value)> -FMT_CONSTEXPR20 auto isfinite(T value) -> bool { - constexpr T inf = T(std::numeric_limits::infinity()); - if (is_constant_evaluated()) - return !detail::isnan(value) && value < inf && value > -inf; - return std::isfinite(value); -} -template ::value)> -FMT_CONSTEXPR auto isfinite(T value) -> bool { - T inf = T(std::numeric_limits::infinity()); - // std::isfinite doesn't support __float128. - return !detail::isnan(value) && value < inf && value > -inf; -} - -template ::value)> -FMT_INLINE FMT_CONSTEXPR bool signbit(T value) { - if (is_constant_evaluated()) { -#ifdef __cpp_if_constexpr - if constexpr (std::numeric_limits::is_iec559) { - auto bits = detail::bit_cast(static_cast(value)); - return (bits >> (num_bits() - 1)) != 0; - } -#endif - } - return std::signbit(static_cast(value)); -} - -inline FMT_CONSTEXPR20 void adjust_precision(int& precision, int exp10) { - // Adjust fixed precision by exponent because it is relative to decimal - // point. - if (exp10 > 0 && precision > max_value() - exp10) - FMT_THROW(format_error("number is too big")); - precision += exp10; -} - -class bigint { - private: - // A bigint is stored as an array of bigits (big digits), with bigit at index - // 0 being the least significant one. - using bigit = uint32_t; - using double_bigit = uint64_t; - enum { bigits_capacity = 32 }; - basic_memory_buffer bigits_; - int exp_; - - FMT_CONSTEXPR20 auto operator[](int index) const -> bigit { - return bigits_[to_unsigned(index)]; - } - FMT_CONSTEXPR20 auto operator[](int index) -> bigit& { - return bigits_[to_unsigned(index)]; - } - - static constexpr const int bigit_bits = num_bits(); - - friend struct formatter; - - FMT_CONSTEXPR20 void subtract_bigits(int index, bigit other, bigit& borrow) { - auto result = static_cast((*this)[index]) - other - borrow; - (*this)[index] = static_cast(result); - borrow = static_cast(result >> (bigit_bits * 2 - 1)); - } - - FMT_CONSTEXPR20 void remove_leading_zeros() { - int num_bigits = static_cast(bigits_.size()) - 1; - while (num_bigits > 0 && (*this)[num_bigits] == 0) --num_bigits; - bigits_.resize(to_unsigned(num_bigits + 1)); - } - - // Computes *this -= other assuming aligned bigints and *this >= other. - FMT_CONSTEXPR20 void subtract_aligned(const bigint& other) { - FMT_ASSERT(other.exp_ >= exp_, "unaligned bigints"); - FMT_ASSERT(compare(*this, other) >= 0, ""); - bigit borrow = 0; - int i = other.exp_ - exp_; - for (size_t j = 0, n = other.bigits_.size(); j != n; ++i, ++j) - subtract_bigits(i, other.bigits_[j], borrow); - while (borrow > 0) subtract_bigits(i, 0, borrow); - remove_leading_zeros(); - } - - FMT_CONSTEXPR20 void multiply(uint32_t value) { - const double_bigit wide_value = value; - bigit carry = 0; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - double_bigit result = bigits_[i] * wide_value + carry; - bigits_[i] = static_cast(result); - carry = static_cast(result >> bigit_bits); - } - if (carry != 0) bigits_.push_back(carry); - } - - template ::value || - std::is_same::value)> - FMT_CONSTEXPR20 void multiply(UInt value) { - using half_uint = - conditional_t::value, uint64_t, uint32_t>; - const int shift = num_bits() - bigit_bits; - const UInt lower = static_cast(value); - const UInt upper = value >> num_bits(); - UInt carry = 0; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - UInt result = lower * bigits_[i] + static_cast(carry); - carry = (upper * bigits_[i] << shift) + (result >> bigit_bits) + - (carry >> bigit_bits); - bigits_[i] = static_cast(result); - } - while (carry != 0) { - bigits_.push_back(static_cast(carry)); - carry >>= bigit_bits; - } - } - - template ::value || - std::is_same::value)> - FMT_CONSTEXPR20 void assign(UInt n) { - size_t num_bigits = 0; - do { - bigits_[num_bigits++] = static_cast(n); - n >>= bigit_bits; - } while (n != 0); - bigits_.resize(num_bigits); - exp_ = 0; - } - - public: - FMT_CONSTEXPR20 bigint() : exp_(0) {} - explicit bigint(uint64_t n) { assign(n); } - - bigint(const bigint&) = delete; - void operator=(const bigint&) = delete; - - FMT_CONSTEXPR20 void assign(const bigint& other) { - auto size = other.bigits_.size(); - bigits_.resize(size); - auto data = other.bigits_.data(); - copy(data, data + size, bigits_.data()); - exp_ = other.exp_; - } - - template FMT_CONSTEXPR20 void operator=(Int n) { - FMT_ASSERT(n > 0, ""); - assign(uint64_or_128_t(n)); - } - - FMT_CONSTEXPR20 auto num_bigits() const -> int { - return static_cast(bigits_.size()) + exp_; - } - - FMT_NOINLINE FMT_CONSTEXPR20 auto operator<<=(int shift) -> bigint& { - FMT_ASSERT(shift >= 0, ""); - exp_ += shift / bigit_bits; - shift %= bigit_bits; - if (shift == 0) return *this; - bigit carry = 0; - for (size_t i = 0, n = bigits_.size(); i < n; ++i) { - bigit c = bigits_[i] >> (bigit_bits - shift); - bigits_[i] = (bigits_[i] << shift) + carry; - carry = c; - } - if (carry != 0) bigits_.push_back(carry); - return *this; - } - - template - FMT_CONSTEXPR20 auto operator*=(Int value) -> bigint& { - FMT_ASSERT(value > 0, ""); - multiply(uint32_or_64_or_128_t(value)); - return *this; - } - - friend FMT_CONSTEXPR20 auto compare(const bigint& lhs, const bigint& rhs) - -> int { - int num_lhs_bigits = lhs.num_bigits(), num_rhs_bigits = rhs.num_bigits(); - if (num_lhs_bigits != num_rhs_bigits) - return num_lhs_bigits > num_rhs_bigits ? 1 : -1; - int i = static_cast(lhs.bigits_.size()) - 1; - int j = static_cast(rhs.bigits_.size()) - 1; - int end = i - j; - if (end < 0) end = 0; - for (; i >= end; --i, --j) { - bigit lhs_bigit = lhs[i], rhs_bigit = rhs[j]; - if (lhs_bigit != rhs_bigit) return lhs_bigit > rhs_bigit ? 1 : -1; - } - if (i != j) return i > j ? 1 : -1; - return 0; - } - - // Returns compare(lhs1 + lhs2, rhs). - friend FMT_CONSTEXPR20 auto add_compare(const bigint& lhs1, - const bigint& lhs2, const bigint& rhs) - -> int { - auto minimum = [](int a, int b) { return a < b ? a : b; }; - auto maximum = [](int a, int b) { return a > b ? a : b; }; - int max_lhs_bigits = maximum(lhs1.num_bigits(), lhs2.num_bigits()); - int num_rhs_bigits = rhs.num_bigits(); - if (max_lhs_bigits + 1 < num_rhs_bigits) return -1; - if (max_lhs_bigits > num_rhs_bigits) return 1; - auto get_bigit = [](const bigint& n, int i) -> bigit { - return i >= n.exp_ && i < n.num_bigits() ? n[i - n.exp_] : 0; - }; - double_bigit borrow = 0; - int min_exp = minimum(minimum(lhs1.exp_, lhs2.exp_), rhs.exp_); - for (int i = num_rhs_bigits - 1; i >= min_exp; --i) { - double_bigit sum = - static_cast(get_bigit(lhs1, i)) + get_bigit(lhs2, i); - bigit rhs_bigit = get_bigit(rhs, i); - if (sum > rhs_bigit + borrow) return 1; - borrow = rhs_bigit + borrow - sum; - if (borrow > 1) return -1; - borrow <<= bigit_bits; - } - return borrow != 0 ? -1 : 0; - } - - // Assigns pow(10, exp) to this bigint. - FMT_CONSTEXPR20 void assign_pow10(int exp) { - FMT_ASSERT(exp >= 0, ""); - if (exp == 0) return *this = 1; - // Find the top bit. - int bitmask = 1; - while (exp >= bitmask) bitmask <<= 1; - bitmask >>= 1; - // pow(10, exp) = pow(5, exp) * pow(2, exp). First compute pow(5, exp) by - // repeated squaring and multiplication. - *this = 5; - bitmask >>= 1; - while (bitmask != 0) { - square(); - if ((exp & bitmask) != 0) *this *= 5; - bitmask >>= 1; - } - *this <<= exp; // Multiply by pow(2, exp) by shifting. - } - - FMT_CONSTEXPR20 void square() { - int num_bigits = static_cast(bigits_.size()); - int num_result_bigits = 2 * num_bigits; - basic_memory_buffer n(std::move(bigits_)); - bigits_.resize(to_unsigned(num_result_bigits)); - auto sum = uint128_t(); - for (int bigit_index = 0; bigit_index < num_bigits; ++bigit_index) { - // Compute bigit at position bigit_index of the result by adding - // cross-product terms n[i] * n[j] such that i + j == bigit_index. - for (int i = 0, j = bigit_index; j >= 0; ++i, --j) { - // Most terms are multiplied twice which can be optimized in the future. - sum += static_cast(n[i]) * n[j]; - } - (*this)[bigit_index] = static_cast(sum); - sum >>= num_bits(); // Compute the carry. - } - // Do the same for the top half. - for (int bigit_index = num_bigits; bigit_index < num_result_bigits; - ++bigit_index) { - for (int j = num_bigits - 1, i = bigit_index - j; i < num_bigits;) - sum += static_cast(n[i++]) * n[j--]; - (*this)[bigit_index] = static_cast(sum); - sum >>= num_bits(); - } - remove_leading_zeros(); - exp_ *= 2; - } - - // If this bigint has a bigger exponent than other, adds trailing zero to make - // exponents equal. This simplifies some operations such as subtraction. - FMT_CONSTEXPR20 void align(const bigint& other) { - int exp_difference = exp_ - other.exp_; - if (exp_difference <= 0) return; - int num_bigits = static_cast(bigits_.size()); - bigits_.resize(to_unsigned(num_bigits + exp_difference)); - for (int i = num_bigits - 1, j = i + exp_difference; i >= 0; --i, --j) - bigits_[j] = bigits_[i]; - memset(bigits_.data(), 0, to_unsigned(exp_difference) * sizeof(bigit)); - exp_ -= exp_difference; - } - - // Divides this bignum by divisor, assigning the remainder to this and - // returning the quotient. - FMT_CONSTEXPR20 auto divmod_assign(const bigint& divisor) -> int { - FMT_ASSERT(this != &divisor, ""); - if (compare(*this, divisor) < 0) return 0; - FMT_ASSERT(divisor.bigits_[divisor.bigits_.size() - 1u] != 0, ""); - align(divisor); - int quotient = 0; - do { - subtract_aligned(divisor); - ++quotient; - } while (compare(*this, divisor) >= 0); - return quotient; - } -}; - -// format_dragon flags. -enum dragon { - predecessor_closer = 1, - fixup = 2, // Run fixup to correct exp10 which can be off by one. - fixed = 4, -}; - -// Formats a floating-point number using a variation of the Fixed-Precision -// Positive Floating-Point Printout ((FPP)^2) algorithm by Steele & White: -// https://fmt.dev/papers/p372-steele.pdf. -FMT_CONSTEXPR20 inline void format_dragon(basic_fp value, - unsigned flags, int num_digits, - buffer& buf, int& exp10) { - bigint numerator; // 2 * R in (FPP)^2. - bigint denominator; // 2 * S in (FPP)^2. - // lower and upper are differences between value and corresponding boundaries. - bigint lower; // (M^- in (FPP)^2). - bigint upper_store; // upper's value if different from lower. - bigint* upper = nullptr; // (M^+ in (FPP)^2). - // Shift numerator and denominator by an extra bit or two (if lower boundary - // is closer) to make lower and upper integers. This eliminates multiplication - // by 2 during later computations. - bool is_predecessor_closer = (flags & dragon::predecessor_closer) != 0; - int shift = is_predecessor_closer ? 2 : 1; - if (value.e >= 0) { - numerator = value.f; - numerator <<= value.e + shift; - lower = 1; - lower <<= value.e; - if (is_predecessor_closer) { - upper_store = 1; - upper_store <<= value.e + 1; - upper = &upper_store; - } - denominator.assign_pow10(exp10); - denominator <<= shift; - } else if (exp10 < 0) { - numerator.assign_pow10(-exp10); - lower.assign(numerator); - if (is_predecessor_closer) { - upper_store.assign(numerator); - upper_store <<= 1; - upper = &upper_store; - } - numerator *= value.f; - numerator <<= shift; - denominator = 1; - denominator <<= shift - value.e; - } else { - numerator = value.f; - numerator <<= shift; - denominator.assign_pow10(exp10); - denominator <<= shift - value.e; - lower = 1; - if (is_predecessor_closer) { - upper_store = 1ULL << 1; - upper = &upper_store; - } - } - int even = static_cast((value.f & 1) == 0); - if (!upper) upper = &lower; - bool shortest = num_digits < 0; - if ((flags & dragon::fixup) != 0) { - if (add_compare(numerator, *upper, denominator) + even <= 0) { - --exp10; - numerator *= 10; - if (num_digits < 0) { - lower *= 10; - if (upper != &lower) *upper *= 10; - } - } - if ((flags & dragon::fixed) != 0) adjust_precision(num_digits, exp10 + 1); - } - // Invariant: value == (numerator / denominator) * pow(10, exp10). - if (shortest) { - // Generate the shortest representation. - num_digits = 0; - char* data = buf.data(); - for (;;) { - int digit = numerator.divmod_assign(denominator); - bool low = compare(numerator, lower) - even < 0; // numerator <[=] lower. - // numerator + upper >[=] pow10: - bool high = add_compare(numerator, *upper, denominator) + even > 0; - data[num_digits++] = static_cast('0' + digit); - if (low || high) { - if (!low) { - ++data[num_digits - 1]; - } else if (high) { - int result = add_compare(numerator, numerator, denominator); - // Round half to even. - if (result > 0 || (result == 0 && (digit % 2) != 0)) - ++data[num_digits - 1]; - } - buf.try_resize(to_unsigned(num_digits)); - exp10 -= num_digits - 1; - return; - } - numerator *= 10; - lower *= 10; - if (upper != &lower) *upper *= 10; - } - } - // Generate the given number of digits. - exp10 -= num_digits - 1; - if (num_digits <= 0) { - auto digit = '0'; - if (num_digits == 0) { - denominator *= 10; - digit = add_compare(numerator, numerator, denominator) > 0 ? '1' : '0'; - } - buf.push_back(digit); - return; - } - buf.try_resize(to_unsigned(num_digits)); - for (int i = 0; i < num_digits - 1; ++i) { - int digit = numerator.divmod_assign(denominator); - buf[i] = static_cast('0' + digit); - numerator *= 10; - } - int digit = numerator.divmod_assign(denominator); - auto result = add_compare(numerator, numerator, denominator); - if (result > 0 || (result == 0 && (digit % 2) != 0)) { - if (digit == 9) { - const auto overflow = '0' + 10; - buf[num_digits - 1] = overflow; - // Propagate the carry. - for (int i = num_digits - 1; i > 0 && buf[i] == overflow; --i) { - buf[i] = '0'; - ++buf[i - 1]; - } - if (buf[0] == overflow) { - buf[0] = '1'; - if ((flags & dragon::fixed) != 0) - buf.push_back('0'); - else - ++exp10; - } - return; - } - ++digit; - } - buf[num_digits - 1] = static_cast('0' + digit); -} - -// Formats a floating-point number using the hexfloat format. -template ::value)> -FMT_CONSTEXPR20 void format_hexfloat(Float value, format_specs specs, - buffer& buf) { - // float is passed as double to reduce the number of instantiations and to - // simplify implementation. - static_assert(!std::is_same::value, ""); - - using info = dragonbox::float_info; - - // Assume Float is in the format [sign][exponent][significand]. - using carrier_uint = typename info::carrier_uint; - - constexpr auto num_float_significand_bits = - detail::num_significand_bits(); - - basic_fp f(value); - f.e += num_float_significand_bits; - if (!has_implicit_bit()) --f.e; - - constexpr auto num_fraction_bits = - num_float_significand_bits + (has_implicit_bit() ? 1 : 0); - constexpr auto num_xdigits = (num_fraction_bits + 3) / 4; - - constexpr auto leading_shift = ((num_xdigits - 1) * 4); - const auto leading_mask = carrier_uint(0xF) << leading_shift; - const auto leading_xdigit = - static_cast((f.f & leading_mask) >> leading_shift); - if (leading_xdigit > 1) f.e -= (32 - countl_zero(leading_xdigit) - 1); - - int print_xdigits = num_xdigits - 1; - if (specs.precision >= 0 && print_xdigits > specs.precision) { - const int shift = ((print_xdigits - specs.precision - 1) * 4); - const auto mask = carrier_uint(0xF) << shift; - const auto v = static_cast((f.f & mask) >> shift); - - if (v >= 8) { - const auto inc = carrier_uint(1) << (shift + 4); - f.f += inc; - f.f &= ~(inc - 1); - } - - // Check long double overflow - if (!has_implicit_bit()) { - const auto implicit_bit = carrier_uint(1) << num_float_significand_bits; - if ((f.f & implicit_bit) == implicit_bit) { - f.f >>= 4; - f.e += 4; - } - } - - print_xdigits = specs.precision; - } - - char xdigits[num_bits() / 4]; - detail::fill_n(xdigits, sizeof(xdigits), '0'); - format_uint<4>(xdigits, f.f, num_xdigits, specs.upper); - - // Remove zero tail - while (print_xdigits > 0 && xdigits[print_xdigits] == '0') --print_xdigits; - - buf.push_back('0'); - buf.push_back(specs.upper ? 'X' : 'x'); - buf.push_back(xdigits[0]); - if (specs.alt || print_xdigits > 0 || print_xdigits < specs.precision) - buf.push_back('.'); - buf.append(xdigits + 1, xdigits + 1 + print_xdigits); - for (; print_xdigits < specs.precision; ++print_xdigits) buf.push_back('0'); - - buf.push_back(specs.upper ? 'P' : 'p'); - - uint32_t abs_e; - if (f.e < 0) { - buf.push_back('-'); - abs_e = static_cast(-f.e); - } else { - buf.push_back('+'); - abs_e = static_cast(f.e); - } - format_decimal(appender(buf), abs_e, detail::count_digits(abs_e)); -} - -template ::value)> -FMT_CONSTEXPR20 void format_hexfloat(Float value, format_specs specs, - buffer& buf) { - format_hexfloat(static_cast(value), specs, buf); -} - -constexpr auto fractional_part_rounding_thresholds(int index) -> uint32_t { - // For checking rounding thresholds. - // The kth entry is chosen to be the smallest integer such that the - // upper 32-bits of 10^(k+1) times it is strictly bigger than 5 * 10^k. - // It is equal to ceil(2^31 + 2^32/10^(k + 1)). - // These are stored in a string literal because we cannot have static arrays - // in constexpr functions and non-static ones are poorly optimized. - return U"\x9999999a\x828f5c29\x80418938\x80068db9\x8000a7c6\x800010c7" - U"\x800001ae\x8000002b"[index]; -} - -template -FMT_CONSTEXPR20 auto format_float(Float value, int precision, float_specs specs, - buffer& buf) -> int { - // float is passed as double to reduce the number of instantiations. - static_assert(!std::is_same::value, ""); - FMT_ASSERT(value >= 0, "value is negative"); - auto converted_value = convert_float(value); - - const bool fixed = specs.format == float_format::fixed; - if (value <= 0) { // <= instead of == to silence a warning. - if (precision <= 0 || !fixed) { - buf.push_back('0'); - return 0; - } - buf.try_resize(to_unsigned(precision)); - fill_n(buf.data(), precision, '0'); - return -precision; - } - - int exp = 0; - bool use_dragon = true; - unsigned dragon_flags = 0; - if (!is_fast_float() || is_constant_evaluated()) { - const auto inv_log2_10 = 0.3010299956639812; // 1 / log2(10) - using info = dragonbox::float_info; - const auto f = basic_fp(converted_value); - // Compute exp, an approximate power of 10, such that - // 10^(exp - 1) <= value < 10^exp or 10^exp <= value < 10^(exp + 1). - // This is based on log10(value) == log2(value) / log2(10) and approximation - // of log2(value) by e + num_fraction_bits idea from double-conversion. - auto e = (f.e + count_digits<1>(f.f) - 1) * inv_log2_10 - 1e-10; - exp = static_cast(e); - if (e > exp) ++exp; // Compute ceil. - dragon_flags = dragon::fixup; - } else if (precision < 0) { - // Use Dragonbox for the shortest format. - if (specs.binary32) { - auto dec = dragonbox::to_decimal(static_cast(value)); - write(appender(buf), dec.significand); - return dec.exponent; - } - auto dec = dragonbox::to_decimal(static_cast(value)); - write(appender(buf), dec.significand); - return dec.exponent; - } else { - // Extract significand bits and exponent bits. - using info = dragonbox::float_info; - auto br = bit_cast(static_cast(value)); - - const uint64_t significand_mask = - (static_cast(1) << num_significand_bits()) - 1; - uint64_t significand = (br & significand_mask); - int exponent = static_cast((br & exponent_mask()) >> - num_significand_bits()); - - if (exponent != 0) { // Check if normal. - exponent -= exponent_bias() + num_significand_bits(); - significand |= - (static_cast(1) << num_significand_bits()); - significand <<= 1; - } else { - // Normalize subnormal inputs. - FMT_ASSERT(significand != 0, "zeros should not appear here"); - int shift = countl_zero(significand); - FMT_ASSERT(shift >= num_bits() - num_significand_bits(), - ""); - shift -= (num_bits() - num_significand_bits() - 2); - exponent = (std::numeric_limits::min_exponent - - num_significand_bits()) - - shift; - significand <<= shift; - } - - // Compute the first several nonzero decimal significand digits. - // We call the number we get the first segment. - const int k = info::kappa - dragonbox::floor_log10_pow2(exponent); - exp = -k; - const int beta = exponent + dragonbox::floor_log2_pow10(k); - uint64_t first_segment; - bool has_more_segments; - int digits_in_the_first_segment; - { - const auto r = dragonbox::umul192_upper128( - significand << beta, dragonbox::get_cached_power(k)); - first_segment = r.high(); - has_more_segments = r.low() != 0; - - // The first segment can have 18 ~ 19 digits. - if (first_segment >= 1000000000000000000ULL) { - digits_in_the_first_segment = 19; - } else { - // When it is of 18-digits, we align it to 19-digits by adding a bogus - // zero at the end. - digits_in_the_first_segment = 18; - first_segment *= 10; - } - } - - // Compute the actual number of decimal digits to print. - if (fixed) adjust_precision(precision, exp + digits_in_the_first_segment); - - // Use Dragon4 only when there might be not enough digits in the first - // segment. - if (digits_in_the_first_segment > precision) { - use_dragon = false; - - if (precision <= 0) { - exp += digits_in_the_first_segment; - - if (precision < 0) { - // Nothing to do, since all we have are just leading zeros. - buf.try_resize(0); - } else { - // We may need to round-up. - buf.try_resize(1); - if ((first_segment | static_cast(has_more_segments)) > - 5000000000000000000ULL) { - buf[0] = '1'; - } else { - buf[0] = '0'; - } - } - } // precision <= 0 - else { - exp += digits_in_the_first_segment - precision; - - // When precision > 0, we divide the first segment into three - // subsegments, each with 9, 9, and 0 ~ 1 digits so that each fits - // in 32-bits which usually allows faster calculation than in - // 64-bits. Since some compiler (e.g. MSVC) doesn't know how to optimize - // division-by-constant for large 64-bit divisors, we do it here - // manually. The magic number 7922816251426433760 below is equal to - // ceil(2^(64+32) / 10^10). - const uint32_t first_subsegment = static_cast( - dragonbox::umul128_upper64(first_segment, 7922816251426433760ULL) >> - 32); - const uint64_t second_third_subsegments = - first_segment - first_subsegment * 10000000000ULL; - - uint64_t prod; - uint32_t digits; - bool should_round_up; - int number_of_digits_to_print = precision > 9 ? 9 : precision; - - // Print a 9-digits subsegment, either the first or the second. - auto print_subsegment = [&](uint32_t subsegment, char* buffer) { - int number_of_digits_printed = 0; - - // If we want to print an odd number of digits from the subsegment, - if ((number_of_digits_to_print & 1) != 0) { - // Convert to 64-bit fixed-point fractional form with 1-digit - // integer part. The magic number 720575941 is a good enough - // approximation of 2^(32 + 24) / 10^8; see - // https://jk-jeon.github.io/posts/2022/12/fixed-precision-formatting/#fixed-length-case - // for details. - prod = ((subsegment * static_cast(720575941)) >> 24) + 1; - digits = static_cast(prod >> 32); - *buffer = static_cast('0' + digits); - number_of_digits_printed++; - } - // If we want to print an even number of digits from the - // first_subsegment, - else { - // Convert to 64-bit fixed-point fractional form with 2-digits - // integer part. The magic number 450359963 is a good enough - // approximation of 2^(32 + 20) / 10^7; see - // https://jk-jeon.github.io/posts/2022/12/fixed-precision-formatting/#fixed-length-case - // for details. - prod = ((subsegment * static_cast(450359963)) >> 20) + 1; - digits = static_cast(prod >> 32); - copy2(buffer, digits2(digits)); - number_of_digits_printed += 2; - } - - // Print all digit pairs. - while (number_of_digits_printed < number_of_digits_to_print) { - prod = static_cast(prod) * static_cast(100); - digits = static_cast(prod >> 32); - copy2(buffer + number_of_digits_printed, digits2(digits)); - number_of_digits_printed += 2; - } - }; - - // Print first subsegment. - print_subsegment(first_subsegment, buf.data()); - - // Perform rounding if the first subsegment is the last subsegment to - // print. - if (precision <= 9) { - // Rounding inside the subsegment. - // We round-up if: - // - either the fractional part is strictly larger than 1/2, or - // - the fractional part is exactly 1/2 and the last digit is odd. - // We rely on the following observations: - // - If fractional_part >= threshold, then the fractional part is - // strictly larger than 1/2. - // - If the MSB of fractional_part is set, then the fractional part - // must be at least 1/2. - // - When the MSB of fractional_part is set, either - // second_third_subsegments being nonzero or has_more_segments - // being true means there are further digits not printed, so the - // fractional part is strictly larger than 1/2. - if (precision < 9) { - uint32_t fractional_part = static_cast(prod); - should_round_up = - fractional_part >= fractional_part_rounding_thresholds( - 8 - number_of_digits_to_print) || - ((fractional_part >> 31) & - ((digits & 1) | (second_third_subsegments != 0) | - has_more_segments)) != 0; - } - // Rounding at the subsegment boundary. - // In this case, the fractional part is at least 1/2 if and only if - // second_third_subsegments >= 5000000000ULL, and is strictly larger - // than 1/2 if we further have either second_third_subsegments > - // 5000000000ULL or has_more_segments == true. - else { - should_round_up = second_third_subsegments > 5000000000ULL || - (second_third_subsegments == 5000000000ULL && - ((digits & 1) != 0 || has_more_segments)); - } - } - // Otherwise, print the second subsegment. - else { - // Compilers are not aware of how to leverage the maximum value of - // second_third_subsegments to find out a better magic number which - // allows us to eliminate an additional shift. 1844674407370955162 = - // ceil(2^64/10) < ceil(2^64*(10^9/(10^10 - 1))). - const uint32_t second_subsegment = - static_cast(dragonbox::umul128_upper64( - second_third_subsegments, 1844674407370955162ULL)); - const uint32_t third_subsegment = - static_cast(second_third_subsegments) - - second_subsegment * 10; - - number_of_digits_to_print = precision - 9; - print_subsegment(second_subsegment, buf.data() + 9); - - // Rounding inside the subsegment. - if (precision < 18) { - // The condition third_subsegment != 0 implies that the segment was - // of 19 digits, so in this case the third segment should be - // consisting of a genuine digit from the input. - uint32_t fractional_part = static_cast(prod); - should_round_up = - fractional_part >= fractional_part_rounding_thresholds( - 8 - number_of_digits_to_print) || - ((fractional_part >> 31) & - ((digits & 1) | (third_subsegment != 0) | - has_more_segments)) != 0; - } - // Rounding at the subsegment boundary. - else { - // In this case, the segment must be of 19 digits, thus - // the third subsegment should be consisting of a genuine digit from - // the input. - should_round_up = third_subsegment > 5 || - (third_subsegment == 5 && - ((digits & 1) != 0 || has_more_segments)); - } - } - - // Round-up if necessary. - if (should_round_up) { - ++buf[precision - 1]; - for (int i = precision - 1; i > 0 && buf[i] > '9'; --i) { - buf[i] = '0'; - ++buf[i - 1]; - } - if (buf[0] > '9') { - buf[0] = '1'; - if (fixed) - buf[precision++] = '0'; - else - ++exp; - } - } - buf.try_resize(to_unsigned(precision)); - } - } // if (digits_in_the_first_segment > precision) - else { - // Adjust the exponent for its use in Dragon4. - exp += digits_in_the_first_segment - 1; - } - } - if (use_dragon) { - auto f = basic_fp(); - bool is_predecessor_closer = specs.binary32 - ? f.assign(static_cast(value)) - : f.assign(converted_value); - if (is_predecessor_closer) dragon_flags |= dragon::predecessor_closer; - if (fixed) dragon_flags |= dragon::fixed; - // Limit precision to the maximum possible number of significant digits in - // an IEEE754 double because we don't need to generate zeros. - const int max_double_digits = 767; - if (precision > max_double_digits) precision = max_double_digits; - format_dragon(f, dragon_flags, precision, buf, exp); - } - if (!fixed && !specs.showpoint) { - // Remove trailing zeros. - auto num_digits = buf.size(); - while (num_digits > 0 && buf[num_digits - 1] == '0') { - --num_digits; - ++exp; - } - buf.try_resize(num_digits); - } - return exp; -} - -template -FMT_CONSTEXPR20 auto write_float(OutputIt out, T value, format_specs specs, - locale_ref loc) -> OutputIt { - sign_t sign = specs.sign; - if (detail::signbit(value)) { // value < 0 is false for NaN so use signbit. - sign = sign::minus; - value = -value; - } else if (sign == sign::minus) { - sign = sign::none; - } - - if (!detail::isfinite(value)) - return write_nonfinite(out, detail::isnan(value), specs, sign); - - if (specs.align == align::numeric && sign) { - auto it = reserve(out, 1); - *it++ = detail::sign(sign); - out = base_iterator(out, it); - sign = sign::none; - if (specs.width != 0) --specs.width; - } - - memory_buffer buffer; - if (specs.type == presentation_type::hexfloat) { - if (sign) buffer.push_back(detail::sign(sign)); - format_hexfloat(convert_float(value), specs, buffer); - return write_bytes(out, {buffer.data(), buffer.size()}, - specs); - } - - int precision = specs.precision >= 0 || specs.type == presentation_type::none - ? specs.precision - : 6; - if (specs.type == presentation_type::exp) { - if (precision == max_value()) - report_error("number is too big"); - else - ++precision; - } else if (specs.type != presentation_type::fixed && precision == 0) { - precision = 1; - } - float_specs fspecs = parse_float_type_spec(specs); - fspecs.sign = sign; - if (const_check(std::is_same())) fspecs.binary32 = true; - int exp = format_float(convert_float(value), precision, fspecs, buffer); - fspecs.precision = precision; - auto f = big_decimal_fp{buffer.data(), static_cast(buffer.size()), exp}; - return write_float(out, f, specs, fspecs, loc); -} - -template ::value)> -FMT_CONSTEXPR20 auto write(OutputIt out, T value, format_specs specs, - locale_ref loc = {}) -> OutputIt { - if (const_check(!is_supported_floating_point(value))) return out; - return specs.localized && write_loc(out, value, specs, loc) - ? out - : write_float(out, value, specs, loc); -} - -template ::value)> -FMT_CONSTEXPR20 auto write(OutputIt out, T value) -> OutputIt { - if (is_constant_evaluated()) return write(out, value, format_specs()); - if (const_check(!is_supported_floating_point(value))) return out; - - auto sign = sign_t::none; - if (detail::signbit(value)) { - sign = sign::minus; - value = -value; - } - - constexpr auto specs = format_specs(); - using floaty = conditional_t::value, double, T>; - using floaty_uint = typename dragonbox::float_info::carrier_uint; - floaty_uint mask = exponent_mask(); - if ((bit_cast(value) & mask) == mask) - return write_nonfinite(out, std::isnan(value), specs, sign); - - auto fspecs = float_specs(); - fspecs.sign = sign; - auto dec = dragonbox::to_decimal(static_cast(value)); - return write_float(out, dec, specs, fspecs, {}); -} - -template ::value && - !is_fast_float::value)> -inline auto write(OutputIt out, T value) -> OutputIt { - return write(out, value, format_specs()); -} - -template -auto write(OutputIt out, monostate, format_specs = {}, locale_ref = {}) - -> OutputIt { - FMT_ASSERT(false, ""); - return out; -} - -template -FMT_CONSTEXPR auto write(OutputIt out, basic_string_view value) - -> OutputIt { - auto it = reserve(out, value.size()); - it = copy_noinline(value.begin(), value.end(), it); - return base_iterator(out, it); -} - -template ::value)> -constexpr auto write(OutputIt out, const T& value) -> OutputIt { - return write(out, to_string_view(value)); -} - -// FMT_ENABLE_IF() condition separated to workaround an MSVC bug. -template < - typename Char, typename OutputIt, typename T, - bool check = - std::is_enum::value && !std::is_same::value && - mapped_type_constant>::value != - type::custom_type, - FMT_ENABLE_IF(check)> -FMT_CONSTEXPR auto write(OutputIt out, T value) -> OutputIt { - return write(out, static_cast>(value)); -} - -template ::value)> -FMT_CONSTEXPR auto write(OutputIt out, T value, const format_specs& specs = {}, - locale_ref = {}) -> OutputIt { - return specs.type != presentation_type::none && - specs.type != presentation_type::string - ? write(out, value ? 1 : 0, specs, {}) - : write_bytes(out, value ? "true" : "false", specs); -} - -template -FMT_CONSTEXPR auto write(OutputIt out, Char value) -> OutputIt { - auto it = reserve(out, 1); - *it++ = value; - return base_iterator(out, it); -} - -template -FMT_CONSTEXPR20 auto write(OutputIt out, const Char* value) -> OutputIt { - if (value) return write(out, basic_string_view(value)); - report_error("string pointer is null"); - return out; -} - -template ::value)> -auto write(OutputIt out, const T* value, const format_specs& specs = {}, - locale_ref = {}) -> OutputIt { - return write_ptr(out, bit_cast(value), &specs); -} - -// A write overload that handles implicit conversions. -template > -FMT_CONSTEXPR auto write(OutputIt out, const T& value) -> enable_if_t< - std::is_class::value && !has_to_string_view::value && - !is_floating_point::value && !std::is_same::value && - !std::is_same().map( - value))>>::value, - OutputIt> { - return write(out, arg_mapper().map(value)); -} - -template > -FMT_CONSTEXPR auto write(OutputIt out, const T& value) - -> enable_if_t::value == - type::custom_type && - !std::is_fundamental::value, - OutputIt> { - auto formatter = typename Context::template formatter_type(); - auto parse_ctx = typename Context::parse_context_type({}); - formatter.parse(parse_ctx); - auto ctx = Context(out, {}, {}); - return formatter.format(value, ctx); -} - -// An argument visitor that formats the argument and writes it via the output -// iterator. It's a class and not a generic lambda for compatibility with C++11. -template struct default_arg_formatter { - using iterator = basic_appender; - using context = buffered_context; - - iterator out; - basic_format_args args; - locale_ref loc; - - template auto operator()(T value) -> iterator { - return write(out, value); - } - auto operator()(typename basic_format_arg::handle h) -> iterator { - basic_format_parse_context parse_ctx({}); - context format_ctx(out, args, loc); - h.format(parse_ctx, format_ctx); - return format_ctx.out(); - } -}; - -template struct arg_formatter { - using iterator = basic_appender; - using context = buffered_context; - - iterator out; - const format_specs& specs; - locale_ref locale; - - template - FMT_CONSTEXPR FMT_INLINE auto operator()(T value) -> iterator { - return detail::write(out, value, specs, locale); - } - auto operator()(typename basic_format_arg::handle) -> iterator { - // User-defined types are handled separately because they require access - // to the parse context. - return out; - } -}; - -struct width_checker { - template ::value)> - FMT_CONSTEXPR auto operator()(T value) -> unsigned long long { - if (is_negative(value)) report_error("negative width"); - return static_cast(value); - } - - template ::value)> - FMT_CONSTEXPR auto operator()(T) -> unsigned long long { - report_error("width is not integer"); - return 0; - } -}; - -struct precision_checker { - template ::value)> - FMT_CONSTEXPR auto operator()(T value) -> unsigned long long { - if (is_negative(value)) report_error("negative precision"); - return static_cast(value); - } - - template ::value)> - FMT_CONSTEXPR auto operator()(T) -> unsigned long long { - report_error("precision is not integer"); - return 0; - } -}; - -template -FMT_CONSTEXPR auto get_dynamic_spec(FormatArg arg) -> int { - unsigned long long value = arg.visit(Handler()); - if (value > to_unsigned(max_value())) report_error("number is too big"); - return static_cast(value); -} - -template -FMT_CONSTEXPR auto get_arg(Context& ctx, ID id) -> decltype(ctx.arg(id)) { - auto arg = ctx.arg(id); - if (!arg) report_error("argument not found"); - return arg; -} - -template -FMT_CONSTEXPR void handle_dynamic_spec(int& value, - arg_ref ref, - Context& ctx) { - switch (ref.kind) { - case arg_id_kind::none: - break; - case arg_id_kind::index: - value = detail::get_dynamic_spec(get_arg(ctx, ref.val.index)); - break; - case arg_id_kind::name: - value = detail::get_dynamic_spec(get_arg(ctx, ref.val.name)); - break; - } -} - -#if FMT_USE_USER_DEFINED_LITERALS -# if FMT_USE_NONTYPE_TEMPLATE_ARGS -template Str> -struct statically_named_arg : view { - static constexpr auto name = Str.data; - - const T& value; - statically_named_arg(const T& v) : value(v) {} -}; - -template Str> -struct is_named_arg> : std::true_type {}; - -template Str> -struct is_statically_named_arg> - : std::true_type {}; - -template Str> -struct udl_arg { - template auto operator=(T&& value) const { - return statically_named_arg(std::forward(value)); - } -}; -# else -template struct udl_arg { - const Char* str; - - template auto operator=(T&& value) const -> named_arg { - return {str, std::forward(value)}; - } -}; -# endif -#endif // FMT_USE_USER_DEFINED_LITERALS - -template -auto vformat(const Locale& loc, basic_string_view fmt, - typename detail::vformat_args::type args) - -> std::basic_string { - auto buf = basic_memory_buffer(); - detail::vformat_to(buf, fmt, args, detail::locale_ref(loc)); - return {buf.data(), buf.size()}; -} - -using format_func = void (*)(detail::buffer&, int, const char*); - -FMT_API void format_error_code(buffer& out, int error_code, - string_view message) noexcept; - -using fmt::report_error; -FMT_API void report_error(format_func func, int error_code, - const char* message) noexcept; -} // namespace detail - -FMT_BEGIN_EXPORT -FMT_API auto vsystem_error(int error_code, string_view format_str, - format_args args) -> std::system_error; - -/** - * Constructs `std::system_error` with a message formatted with - * `fmt::format(fmt, args...)`. - * `error_code` is a system error code as given by `errno`. - * - * **Example**: - * - * // This throws std::system_error with the description - * // cannot open file 'madeup': No such file or directory - * // or similar (system message may vary). - * const char* filename = "madeup"; - * std::FILE* file = std::fopen(filename, "r"); - * if (!file) - * throw fmt::system_error(errno, "cannot open file '{}'", filename); - */ -template -auto system_error(int error_code, format_string fmt, T&&... args) - -> std::system_error { - return vsystem_error(error_code, fmt, fmt::make_format_args(args...)); -} - -/** - * Formats an error message for an error returned by an operating system or a - * language runtime, for example a file opening error, and writes it to `out`. - * The format is the same as the one used by `std::system_error(ec, message)` - * where `ec` is `std::error_code(error_code, std::generic_category())`. - * It is implementation-defined but normally looks like: - * - * : - * - * where `` is the passed message and `` is the system - * message corresponding to the error code. - * `error_code` is a system error code as given by `errno`. - */ -FMT_API void format_system_error(detail::buffer& out, int error_code, - const char* message) noexcept; - -// Reports a system error without throwing an exception. -// Can be used to report errors from destructors. -FMT_API void report_system_error(int error_code, const char* message) noexcept; - -/// A fast integer formatter. -class format_int { - private: - // Buffer should be large enough to hold all digits (digits10 + 1), - // a sign and a null character. - enum { buffer_size = std::numeric_limits::digits10 + 3 }; - mutable char buffer_[buffer_size]; - char* str_; - - template - FMT_CONSTEXPR20 auto format_unsigned(UInt value) -> char* { - auto n = static_cast>(value); - return detail::format_decimal(buffer_, n, buffer_size - 1).begin; - } - - template - FMT_CONSTEXPR20 auto format_signed(Int value) -> char* { - auto abs_value = static_cast>(value); - bool negative = value < 0; - if (negative) abs_value = 0 - abs_value; - auto begin = format_unsigned(abs_value); - if (negative) *--begin = '-'; - return begin; - } - - public: - explicit FMT_CONSTEXPR20 format_int(int value) : str_(format_signed(value)) {} - explicit FMT_CONSTEXPR20 format_int(long value) - : str_(format_signed(value)) {} - explicit FMT_CONSTEXPR20 format_int(long long value) - : str_(format_signed(value)) {} - explicit FMT_CONSTEXPR20 format_int(unsigned value) - : str_(format_unsigned(value)) {} - explicit FMT_CONSTEXPR20 format_int(unsigned long value) - : str_(format_unsigned(value)) {} - explicit FMT_CONSTEXPR20 format_int(unsigned long long value) - : str_(format_unsigned(value)) {} - - /// Returns the number of characters written to the output buffer. - FMT_CONSTEXPR20 auto size() const -> size_t { - return detail::to_unsigned(buffer_ - str_ + buffer_size - 1); - } - - /// Returns a pointer to the output buffer content. No terminating null - /// character is appended. - FMT_CONSTEXPR20 auto data() const -> const char* { return str_; } - - /// Returns a pointer to the output buffer content with terminating null - /// character appended. - FMT_CONSTEXPR20 auto c_str() const -> const char* { - buffer_[buffer_size - 1] = '\0'; - return str_; - } - - /// Returns the content of the output buffer as an `std::string`. - auto str() const -> std::string { return std::string(str_, size()); } -}; - -template -struct formatter::value>> - : formatter, Char> { - template - auto format(const T& value, FormatContext& ctx) const -> decltype(ctx.out()) { - using base = formatter, Char>; - auto&& val = format_as(value); // Make an lvalue reference for format. - return base::format(val, ctx); - } -}; - -#define FMT_FORMAT_AS(Type, Base) \ - template \ - struct formatter : formatter {} - -FMT_FORMAT_AS(signed char, int); -FMT_FORMAT_AS(unsigned char, unsigned); -FMT_FORMAT_AS(short, int); -FMT_FORMAT_AS(unsigned short, unsigned); -FMT_FORMAT_AS(long, detail::long_type); -FMT_FORMAT_AS(unsigned long, detail::ulong_type); -FMT_FORMAT_AS(Char*, const Char*); -FMT_FORMAT_AS(std::nullptr_t, const void*); -FMT_FORMAT_AS(detail::std_string_view, basic_string_view); -FMT_FORMAT_AS(void*, const void*); - -template -class formatter, Char> - : public formatter, Char> {}; - -template -struct formatter : formatter, Char> {}; - -/** - * Converts `p` to `const void*` for pointer formatting. - * - * **Example**: - * - * auto s = fmt::format("{}", fmt::ptr(p)); - */ -template auto ptr(T p) -> const void* { - static_assert(std::is_pointer::value, ""); - return detail::bit_cast(p); -} - -/** - * Converts `e` to the underlying type. - * - * **Example**: - * - * enum class color { red, green, blue }; - * auto s = fmt::format("{}", fmt::underlying(color::red)); - */ -template -constexpr auto underlying(Enum e) noexcept -> underlying_t { - return static_cast>(e); -} - -namespace enums { -template ::value)> -constexpr auto format_as(Enum e) noexcept -> underlying_t { - return static_cast>(e); -} -} // namespace enums - -class bytes { - private: - string_view data_; - friend struct formatter; - - public: - explicit bytes(string_view data) : data_(data) {} -}; - -template <> struct formatter { - private: - detail::dynamic_format_specs<> specs_; - - public: - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> const char* { - return parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, - detail::type::string_type); - } - - template - auto format(bytes b, FormatContext& ctx) const -> decltype(ctx.out()) { - auto specs = specs_; - detail::handle_dynamic_spec(specs.width, - specs.width_ref, ctx); - detail::handle_dynamic_spec( - specs.precision, specs.precision_ref, ctx); - return detail::write_bytes(ctx.out(), b.data_, specs); - } -}; - -// group_digits_view is not derived from view because it copies the argument. -template struct group_digits_view { - T value; -}; - -/** - * Returns a view that formats an integer value using ',' as a - * locale-independent thousands separator. - * - * **Example**: - * - * fmt::print("{}", fmt::group_digits(12345)); - * // Output: "12,345" - */ -template auto group_digits(T value) -> group_digits_view { - return {value}; -} - -template struct formatter> : formatter { - private: - detail::dynamic_format_specs<> specs_; - - public: - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> const char* { - return parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, - detail::type::int_type); - } - - template - auto format(group_digits_view t, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto specs = specs_; - detail::handle_dynamic_spec(specs.width, - specs.width_ref, ctx); - detail::handle_dynamic_spec( - specs.precision, specs.precision_ref, ctx); - auto arg = detail::make_write_int_arg(t.value, specs.sign); - return detail::write_int( - ctx.out(), static_cast>(arg.abs_value), - arg.prefix, specs, detail::digit_grouping("\3", ",")); - } -}; - -template struct nested_view { - const formatter* fmt; - const T* value; -}; - -template -struct formatter, Char> { - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - return ctx.begin(); - } - template - auto format(nested_view view, FormatContext& ctx) const - -> decltype(ctx.out()) { - return view.fmt->format(*view.value, ctx); - } -}; - -template struct nested_formatter { - private: - int width_; - detail::fill_t fill_; - align_t align_ : 4; - formatter formatter_; - - public: - constexpr nested_formatter() : width_(0), align_(align_t::none) {} - - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - auto specs = detail::dynamic_format_specs(); - auto it = parse_format_specs(ctx.begin(), ctx.end(), specs, ctx, - detail::type::none_type); - width_ = specs.width; - fill_ = specs.fill; - align_ = specs.align; - ctx.advance_to(it); - return formatter_.parse(ctx); - } - - template - auto write_padded(FormatContext& ctx, F write) const -> decltype(ctx.out()) { - if (width_ == 0) return write(ctx.out()); - auto buf = basic_memory_buffer(); - write(basic_appender(buf)); - auto specs = format_specs(); - specs.width = width_; - specs.fill = fill_; - specs.align = align_; - return detail::write( - ctx.out(), basic_string_view(buf.data(), buf.size()), specs); - } - - auto nested(const T& value) const -> nested_view { - return nested_view{&formatter_, &value}; - } -}; - -/** - * Converts `value` to `std::string` using the default format for type `T`. - * - * **Example**: - * - * std::string answer = fmt::to_string(42); - */ -template ::value && - !detail::has_format_as::value)> -inline auto to_string(const T& value) -> std::string { - auto buffer = memory_buffer(); - detail::write(appender(buffer), value); - return {buffer.data(), buffer.size()}; -} - -template ::value)> -FMT_NODISCARD inline auto to_string(T value) -> std::string { - // The buffer should be large enough to store the number including the sign - // or "false" for bool. - constexpr int max_size = detail::digits10() + 2; - char buffer[max_size > 5 ? static_cast(max_size) : 5]; - char* begin = buffer; - return std::string(begin, detail::write(begin, value)); -} - -template -FMT_NODISCARD auto to_string(const basic_memory_buffer& buf) - -> std::basic_string { - auto size = buf.size(); - detail::assume(size < std::basic_string().max_size()); - return std::basic_string(buf.data(), size); -} - -template ::value && - detail::has_format_as::value)> -inline auto to_string(const T& value) -> std::string { - return to_string(format_as(value)); -} - -FMT_END_EXPORT - -namespace detail { - -template -void vformat_to(buffer& buf, basic_string_view fmt, - typename vformat_args::type args, locale_ref loc) { - auto out = basic_appender(buf); - if (fmt.size() == 2 && equal2(fmt.data(), "{}")) { - auto arg = args.get(0); - if (!arg) report_error("argument not found"); - arg.visit(default_arg_formatter{out, args, loc}); - return; - } - - struct format_handler { - basic_format_parse_context parse_context; - buffered_context context; - - format_handler(basic_appender p_out, basic_string_view str, - basic_format_args> p_args, - locale_ref p_loc) - : parse_context(str), context(p_out, p_args, p_loc) {} - - void on_text(const Char* begin, const Char* end) { - auto text = basic_string_view(begin, to_unsigned(end - begin)); - context.advance_to(write(context.out(), text)); - } - - FMT_CONSTEXPR auto on_arg_id() -> int { - return parse_context.next_arg_id(); - } - FMT_CONSTEXPR auto on_arg_id(int id) -> int { - parse_context.check_arg_id(id); - return id; - } - FMT_CONSTEXPR auto on_arg_id(basic_string_view id) -> int { - parse_context.check_arg_id(id); - int arg_id = context.arg_id(id); - if (arg_id < 0) report_error("argument not found"); - return arg_id; - } - - FMT_INLINE void on_replacement_field(int id, const Char*) { - auto arg = get_arg(context, id); - context.advance_to(arg.visit(default_arg_formatter{ - context.out(), context.args(), context.locale()})); - } - - auto on_format_specs(int id, const Char* begin, const Char* end) - -> const Char* { - auto arg = get_arg(context, id); - // Not using a visitor for custom types gives better codegen. - if (arg.format_custom(begin, parse_context, context)) - return parse_context.begin(); - auto specs = detail::dynamic_format_specs(); - begin = parse_format_specs(begin, end, specs, parse_context, arg.type()); - detail::handle_dynamic_spec( - specs.width, specs.width_ref, context); - detail::handle_dynamic_spec( - specs.precision, specs.precision_ref, context); - if (begin == end || *begin != '}') - report_error("missing '}' in format string"); - context.advance_to(arg.visit( - arg_formatter{context.out(), specs, context.locale()})); - return begin; - } - - FMT_NORETURN void on_error(const char* message) { report_error(message); } - }; - detail::parse_format_string(fmt, format_handler(out, fmt, args, loc)); -} - -FMT_BEGIN_EXPORT - -#ifndef FMT_HEADER_ONLY -extern template FMT_API void vformat_to(buffer&, string_view, - typename vformat_args<>::type, - locale_ref); -extern template FMT_API auto thousands_sep_impl(locale_ref) - -> thousands_sep_result; -extern template FMT_API auto thousands_sep_impl(locale_ref) - -> thousands_sep_result; -extern template FMT_API auto decimal_point_impl(locale_ref) -> char; -extern template FMT_API auto decimal_point_impl(locale_ref) -> wchar_t; -#endif // FMT_HEADER_ONLY - -FMT_END_EXPORT - -template -template -FMT_CONSTEXPR FMT_INLINE auto native_formatter::format( - const T& val, FormatContext& ctx) const -> decltype(ctx.out()) { - if (specs_.width_ref.kind == arg_id_kind::none && - specs_.precision_ref.kind == arg_id_kind::none) { - return write(ctx.out(), val, specs_, ctx.locale()); - } - auto specs = specs_; - handle_dynamic_spec(specs.width, specs.width_ref, ctx); - handle_dynamic_spec(specs.precision, specs.precision_ref, - ctx); - return write(ctx.out(), val, specs, ctx.locale()); -} - -} // namespace detail - -FMT_BEGIN_EXPORT - -template -struct formatter - : detail::native_formatter {}; - -#if FMT_USE_USER_DEFINED_LITERALS -inline namespace literals { -/** - * User-defined literal equivalent of `fmt::arg`. - * - * **Example**: - * - * using namespace fmt::literals; - * fmt::print("The answer is {answer}.", "answer"_a=42); - */ -# if FMT_USE_NONTYPE_TEMPLATE_ARGS -template constexpr auto operator""_a() { - using char_t = remove_cvref_t; - return detail::udl_arg(); -} -# else -constexpr auto operator""_a(const char* s, size_t) -> detail::udl_arg { - return {s}; -} -# endif -} // namespace literals -#endif // FMT_USE_USER_DEFINED_LITERALS - -FMT_API auto vformat(string_view fmt, format_args args) -> std::string; - -/** - * Formats `args` according to specifications in `fmt` and returns the result - * as a string. - * - * **Example**: - * - * #include - * std::string message = fmt::format("The answer is {}.", 42); - */ -template -FMT_NODISCARD FMT_INLINE auto format(format_string fmt, T&&... args) - -> std::string { - return vformat(fmt, fmt::make_format_args(args...)); -} - -template ::value)> -inline auto vformat(const Locale& loc, string_view fmt, format_args args) - -> std::string { - return detail::vformat(loc, fmt, args); -} - -template ::value)> -inline auto format(const Locale& loc, format_string fmt, T&&... args) - -> std::string { - return fmt::vformat(loc, string_view(fmt), fmt::make_format_args(args...)); -} - -template ::value&& - detail::is_locale::value)> -auto vformat_to(OutputIt out, const Locale& loc, string_view fmt, - format_args args) -> OutputIt { - using detail::get_buffer; - auto&& buf = get_buffer(out); - detail::vformat_to(buf, fmt, args, detail::locale_ref(loc)); - return detail::get_iterator(buf, out); -} - -template ::value&& - detail::is_locale::value)> -FMT_INLINE auto format_to(OutputIt out, const Locale& loc, - format_string fmt, T&&... args) -> OutputIt { - return vformat_to(out, loc, fmt, fmt::make_format_args(args...)); -} - -template ::value)> -FMT_NODISCARD FMT_INLINE auto formatted_size(const Locale& loc, - format_string fmt, - T&&... args) -> size_t { - auto buf = detail::counting_buffer<>(); - detail::vformat_to(buf, fmt, fmt::make_format_args(args...), - detail::locale_ref(loc)); - return buf.count(); -} - -FMT_END_EXPORT - -FMT_END_NAMESPACE - -#ifdef FMT_HEADER_ONLY -# define FMT_FUNC inline -# include "format-inl.h" -#else -# define FMT_FUNC -#endif - -// Restore _LIBCPP_REMOVE_TRANSITIVE_INCLUDES. -#ifdef FMT_REMOVE_TRANSITIVE_INCLUDES -# undef _LIBCPP_REMOVE_TRANSITIVE_INCLUDES -#endif - -#endif // FMT_FORMAT_H_ diff --git a/tt_metal/third_party/fmt/fmt/os.h b/tt_metal/third_party/fmt/fmt/os.h deleted file mode 100644 index 5c85ea08ff4..00000000000 --- a/tt_metal/third_party/fmt/fmt/os.h +++ /dev/null @@ -1,439 +0,0 @@ -// Formatting library for C++ - optional OS-specific functionality -// -// Copyright (c) 2012 - present, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_OS_H_ -#define FMT_OS_H_ - -#include "format.h" - -#ifndef FMT_MODULE -# include -# include -# include -# include // std::system_error - -# if FMT_HAS_INCLUDE() -# include // LC_NUMERIC_MASK on macOS -# endif -#endif // FMT_MODULE - -#ifndef FMT_USE_FCNTL -// UWP doesn't provide _pipe. -# if FMT_HAS_INCLUDE("winapifamily.h") -# include -# endif -# if (FMT_HAS_INCLUDE() || defined(__APPLE__) || \ - defined(__linux__)) && \ - (!defined(WINAPI_FAMILY) || \ - (WINAPI_FAMILY == WINAPI_FAMILY_DESKTOP_APP)) -# include // for O_RDONLY -# define FMT_USE_FCNTL 1 -# else -# define FMT_USE_FCNTL 0 -# endif -#endif - -#ifndef FMT_POSIX -# if defined(_WIN32) && !defined(__MINGW32__) -// Fix warnings about deprecated symbols. -# define FMT_POSIX(call) _##call -# else -# define FMT_POSIX(call) call -# endif -#endif - -// Calls to system functions are wrapped in FMT_SYSTEM for testability. -#ifdef FMT_SYSTEM -# define FMT_HAS_SYSTEM -# define FMT_POSIX_CALL(call) FMT_SYSTEM(call) -#else -# define FMT_SYSTEM(call) ::call -# ifdef _WIN32 -// Fix warnings about deprecated symbols. -# define FMT_POSIX_CALL(call) ::_##call -# else -# define FMT_POSIX_CALL(call) ::call -# endif -#endif - -// Retries the expression while it evaluates to error_result and errno -// equals to EINTR. -#ifndef _WIN32 -# define FMT_RETRY_VAL(result, expression, error_result) \ - do { \ - (result) = (expression); \ - } while ((result) == (error_result) && errno == EINTR) -#else -# define FMT_RETRY_VAL(result, expression, error_result) result = (expression) -#endif - -#define FMT_RETRY(result, expression) FMT_RETRY_VAL(result, expression, -1) - -FMT_BEGIN_NAMESPACE -FMT_BEGIN_EXPORT - -/** - * A reference to a null-terminated string. It can be constructed from a C - * string or `std::string`. - * - * You can use one of the following type aliases for common character types: - * - * +---------------+-----------------------------+ - * | Type | Definition | - * +===============+=============================+ - * | cstring_view | basic_cstring_view | - * +---------------+-----------------------------+ - * | wcstring_view | basic_cstring_view | - * +---------------+-----------------------------+ - * - * This class is most useful as a parameter type for functions that wrap C APIs. - */ -template class basic_cstring_view { - private: - const Char* data_; - - public: - /// Constructs a string reference object from a C string. - basic_cstring_view(const Char* s) : data_(s) {} - - /// Constructs a string reference from an `std::string` object. - basic_cstring_view(const std::basic_string& s) : data_(s.c_str()) {} - - /// Returns the pointer to a C string. - auto c_str() const -> const Char* { return data_; } -}; - -using cstring_view = basic_cstring_view; -using wcstring_view = basic_cstring_view; - -#ifdef _WIN32 -FMT_API const std::error_category& system_category() noexcept; - -namespace detail { -FMT_API void format_windows_error(buffer& out, int error_code, - const char* message) noexcept; -} - -FMT_API std::system_error vwindows_error(int error_code, string_view format_str, - format_args args); - -/** - * Constructs a `std::system_error` object with the description of the form - * - * : - * - * where `` is the formatted message and `` is the - * system message corresponding to the error code. - * `error_code` is a Windows error code as given by `GetLastError`. - * If `error_code` is not a valid error code such as -1, the system message - * will look like "error -1". - * - * **Example**: - * - * // This throws a system_error with the description - * // cannot open file 'madeup': The system cannot find the file - * specified. - * // or similar (system message may vary). - * const char *filename = "madeup"; - * LPOFSTRUCT of = LPOFSTRUCT(); - * HFILE file = OpenFile(filename, &of, OF_READ); - * if (file == HFILE_ERROR) { - * throw fmt::windows_error(GetLastError(), - * "cannot open file '{}'", filename); - * } - */ -template -std::system_error windows_error(int error_code, string_view message, - const Args&... args) { - return vwindows_error(error_code, message, fmt::make_format_args(args...)); -} - -// Reports a Windows error without throwing an exception. -// Can be used to report errors from destructors. -FMT_API void report_windows_error(int error_code, const char* message) noexcept; -#else -inline auto system_category() noexcept -> const std::error_category& { - return std::system_category(); -} -#endif // _WIN32 - -// std::system is not available on some platforms such as iOS (#2248). -#ifdef __OSX__ -template > -void say(const S& format_str, Args&&... args) { - std::system(format("say \"{}\"", format(format_str, args...)).c_str()); -} -#endif - -// A buffered file. -class buffered_file { - private: - FILE* file_; - - friend class file; - - explicit buffered_file(FILE* f) : file_(f) {} - - public: - buffered_file(const buffered_file&) = delete; - void operator=(const buffered_file&) = delete; - - // Constructs a buffered_file object which doesn't represent any file. - buffered_file() noexcept : file_(nullptr) {} - - // Destroys the object closing the file it represents if any. - FMT_API ~buffered_file() noexcept; - - public: - buffered_file(buffered_file&& other) noexcept : file_(other.file_) { - other.file_ = nullptr; - } - - auto operator=(buffered_file&& other) -> buffered_file& { - close(); - file_ = other.file_; - other.file_ = nullptr; - return *this; - } - - // Opens a file. - FMT_API buffered_file(cstring_view filename, cstring_view mode); - - // Closes the file. - FMT_API void close(); - - // Returns the pointer to a FILE object representing this file. - auto get() const noexcept -> FILE* { return file_; } - - FMT_API auto descriptor() const -> int; - - template - inline void print(string_view fmt, const T&... args) { - const auto& vargs = fmt::make_format_args(args...); - detail::is_locking() ? fmt::vprint_buffered(file_, fmt, vargs) - : fmt::vprint(file_, fmt, vargs); - } -}; - -#if FMT_USE_FCNTL - -// A file. Closed file is represented by a file object with descriptor -1. -// Methods that are not declared with noexcept may throw -// fmt::system_error in case of failure. Note that some errors such as -// closing the file multiple times will cause a crash on Windows rather -// than an exception. You can get standard behavior by overriding the -// invalid parameter handler with _set_invalid_parameter_handler. -class FMT_API file { - private: - int fd_; // File descriptor. - - // Constructs a file object with a given descriptor. - explicit file(int fd) : fd_(fd) {} - - friend struct pipe; - - public: - // Possible values for the oflag argument to the constructor. - enum { - RDONLY = FMT_POSIX(O_RDONLY), // Open for reading only. - WRONLY = FMT_POSIX(O_WRONLY), // Open for writing only. - RDWR = FMT_POSIX(O_RDWR), // Open for reading and writing. - CREATE = FMT_POSIX(O_CREAT), // Create if the file doesn't exist. - APPEND = FMT_POSIX(O_APPEND), // Open in append mode. - TRUNC = FMT_POSIX(O_TRUNC) // Truncate the content of the file. - }; - - // Constructs a file object which doesn't represent any file. - file() noexcept : fd_(-1) {} - - // Opens a file and constructs a file object representing this file. - file(cstring_view path, int oflag); - - public: - file(const file&) = delete; - void operator=(const file&) = delete; - - file(file&& other) noexcept : fd_(other.fd_) { other.fd_ = -1; } - - // Move assignment is not noexcept because close may throw. - auto operator=(file&& other) -> file& { - close(); - fd_ = other.fd_; - other.fd_ = -1; - return *this; - } - - // Destroys the object closing the file it represents if any. - ~file() noexcept; - - // Returns the file descriptor. - auto descriptor() const noexcept -> int { return fd_; } - - // Closes the file. - void close(); - - // Returns the file size. The size has signed type for consistency with - // stat::st_size. - auto size() const -> long long; - - // Attempts to read count bytes from the file into the specified buffer. - auto read(void* buffer, size_t count) -> size_t; - - // Attempts to write count bytes from the specified buffer to the file. - auto write(const void* buffer, size_t count) -> size_t; - - // Duplicates a file descriptor with the dup function and returns - // the duplicate as a file object. - static auto dup(int fd) -> file; - - // Makes fd be the copy of this file descriptor, closing fd first if - // necessary. - void dup2(int fd); - - // Makes fd be the copy of this file descriptor, closing fd first if - // necessary. - void dup2(int fd, std::error_code& ec) noexcept; - - // Creates a buffered_file object associated with this file and detaches - // this file object from the file. - auto fdopen(const char* mode) -> buffered_file; - -# if defined(_WIN32) && !defined(__MINGW32__) - // Opens a file and constructs a file object representing this file by - // wcstring_view filename. Windows only. - static file open_windows_file(wcstring_view path, int oflag); -# endif -}; - -struct FMT_API pipe { - file read_end; - file write_end; - - // Creates a pipe setting up read_end and write_end file objects for reading - // and writing respectively. - pipe(); -}; - -// Returns the memory page size. -auto getpagesize() -> long; - -namespace detail { - -struct buffer_size { - buffer_size() = default; - size_t value = 0; - auto operator=(size_t val) const -> buffer_size { - auto bs = buffer_size(); - bs.value = val; - return bs; - } -}; - -struct ostream_params { - int oflag = file::WRONLY | file::CREATE | file::TRUNC; - size_t buffer_size = BUFSIZ > 32768 ? BUFSIZ : 32768; - - ostream_params() {} - - template - ostream_params(T... params, int new_oflag) : ostream_params(params...) { - oflag = new_oflag; - } - - template - ostream_params(T... params, detail::buffer_size bs) - : ostream_params(params...) { - this->buffer_size = bs.value; - } - -// Intel has a bug that results in failure to deduce a constructor -// for empty parameter packs. -# if defined(__INTEL_COMPILER) && __INTEL_COMPILER < 2000 - ostream_params(int new_oflag) : oflag(new_oflag) {} - ostream_params(detail::buffer_size bs) : buffer_size(bs.value) {} -# endif -}; - -class file_buffer final : public buffer { - private: - file file_; - - FMT_API static void grow(buffer& buf, size_t); - - public: - FMT_API file_buffer(cstring_view path, const ostream_params& params); - FMT_API file_buffer(file_buffer&& other) noexcept; - FMT_API ~file_buffer(); - - void flush() { - if (size() == 0) return; - file_.write(data(), size() * sizeof(data()[0])); - clear(); - } - - void close() { - flush(); - file_.close(); - } -}; - -} // namespace detail - -constexpr auto buffer_size = detail::buffer_size(); - -/// A fast output stream for writing from a single thread. Writing from -/// multiple threads without external synchronization may result in a data race. -class FMT_API ostream { - private: - FMT_MSC_WARNING(suppress : 4251) - detail::file_buffer buffer_; - - ostream(cstring_view path, const detail::ostream_params& params) - : buffer_(path, params) {} - - public: - ostream(ostream&& other) : buffer_(std::move(other.buffer_)) {} - - ~ostream(); - - void flush() { buffer_.flush(); } - - template - friend auto output_file(cstring_view path, T... params) -> ostream; - - void close() { buffer_.close(); } - - /// Formats `args` according to specifications in `fmt` and writes the - /// output to the file. - template void print(format_string fmt, T&&... args) { - vformat_to(appender(buffer_), fmt, fmt::make_format_args(args...)); - } -}; - -/** - * Opens a file for writing. Supported parameters passed in `params`: - * - * - ``: Flags passed to [open]( - * https://pubs.opengroup.org/onlinepubs/007904875/functions/open.html) - * (`file::WRONLY | file::CREATE | file::TRUNC` by default) - * - `buffer_size=`: Output buffer size - * - * **Example**: - * - * auto out = fmt::output_file("guide.txt"); - * out.print("Don't {}", "Panic"); - */ -template -inline auto output_file(cstring_view path, T... params) -> ostream { - return {path, detail::ostream_params(params...)}; -} -#endif // FMT_USE_FCNTL - -FMT_END_EXPORT -FMT_END_NAMESPACE - -#endif // FMT_OS_H_ diff --git a/tt_metal/third_party/fmt/fmt/ostream.h b/tt_metal/third_party/fmt/fmt/ostream.h deleted file mode 100644 index 98faef659f5..00000000000 --- a/tt_metal/third_party/fmt/fmt/ostream.h +++ /dev/null @@ -1,211 +0,0 @@ -// Formatting library for C++ - std::ostream support -// -// Copyright (c) 2012 - present, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_OSTREAM_H_ -#define FMT_OSTREAM_H_ - -#ifndef FMT_MODULE -# include // std::filebuf -#endif - -#ifdef _WIN32 -# ifdef __GLIBCXX__ -# include -# include -# endif -# include -#endif - -#include "chrono.h" // formatbuf - -FMT_BEGIN_NAMESPACE -namespace detail { - -// Generate a unique explicit instantion in every translation unit using a tag -// type in an anonymous namespace. -namespace { -struct file_access_tag {}; -} // namespace -template -class file_access { - friend auto get_file(BufType& obj) -> FILE* { return obj.*FileMemberPtr; } -}; - -#if FMT_MSC_VERSION -template class file_access; -auto get_file(std::filebuf&) -> FILE*; -#endif - -inline auto write_ostream_unicode(std::ostream& os, fmt::string_view data) - -> bool { - FILE* f = nullptr; -#if FMT_MSC_VERSION && FMT_USE_RTTI - if (auto* buf = dynamic_cast(os.rdbuf())) - f = get_file(*buf); - else - return false; -#elif defined(_WIN32) && defined(__GLIBCXX__) && FMT_USE_RTTI - auto* rdbuf = os.rdbuf(); - if (auto* sfbuf = dynamic_cast<__gnu_cxx::stdio_sync_filebuf*>(rdbuf)) - f = sfbuf->file(); - else if (auto* fbuf = dynamic_cast<__gnu_cxx::stdio_filebuf*>(rdbuf)) - f = fbuf->file(); - else - return false; -#else - ignore_unused(os, data, f); -#endif -#ifdef _WIN32 - if (f) { - int fd = _fileno(f); - if (_isatty(fd)) { - os.flush(); - return write_console(fd, data); - } - } -#endif - return false; -} -inline auto write_ostream_unicode(std::wostream&, - fmt::basic_string_view) -> bool { - return false; -} - -// Write the content of buf to os. -// It is a separate function rather than a part of vprint to simplify testing. -template -void write_buffer(std::basic_ostream& os, buffer& buf) { - const Char* buf_data = buf.data(); - using unsigned_streamsize = std::make_unsigned::type; - unsigned_streamsize size = buf.size(); - unsigned_streamsize max_size = to_unsigned(max_value()); - do { - unsigned_streamsize n = size <= max_size ? size : max_size; - os.write(buf_data, static_cast(n)); - buf_data += n; - size -= n; - } while (size != 0); -} - -template -void format_value(buffer& buf, const T& value) { - auto&& format_buf = formatbuf>(buf); - auto&& output = std::basic_ostream(&format_buf); -#if !defined(FMT_STATIC_THOUSANDS_SEPARATOR) - output.imbue(std::locale::classic()); // The default is always unlocalized. -#endif - output << value; - output.exceptions(std::ios_base::failbit | std::ios_base::badbit); -} - -template struct streamed_view { - const T& value; -}; - -} // namespace detail - -// Formats an object of type T that has an overloaded ostream operator<<. -template -struct basic_ostream_formatter : formatter, Char> { - void set_debug_format() = delete; - - template - auto format(const T& value, Context& ctx) const -> decltype(ctx.out()) { - auto buffer = basic_memory_buffer(); - detail::format_value(buffer, value); - return formatter, Char>::format( - {buffer.data(), buffer.size()}, ctx); - } -}; - -using ostream_formatter = basic_ostream_formatter; - -template -struct formatter, Char> - : basic_ostream_formatter { - template - auto format(detail::streamed_view view, Context& ctx) const - -> decltype(ctx.out()) { - return basic_ostream_formatter::format(view.value, ctx); - } -}; - -/** - * Returns a view that formats `value` via an ostream `operator<<`. - * - * **Example**: - * - * fmt::print("Current thread id: {}\n", - * fmt::streamed(std::this_thread::get_id())); - */ -template -constexpr auto streamed(const T& value) -> detail::streamed_view { - return {value}; -} - -namespace detail { - -inline void vprint_directly(std::ostream& os, string_view format_str, - format_args args) { - auto buffer = memory_buffer(); - detail::vformat_to(buffer, format_str, args); - detail::write_buffer(os, buffer); -} - -} // namespace detail - -FMT_EXPORT template -void vprint(std::basic_ostream& os, - basic_string_view> format_str, - typename detail::vformat_args::type args) { - auto buffer = basic_memory_buffer(); - detail::vformat_to(buffer, format_str, args); - if (detail::write_ostream_unicode(os, {buffer.data(), buffer.size()})) return; - detail::write_buffer(os, buffer); -} - -/** - * Prints formatted data to the stream `os`. - * - * **Example**: - * - * fmt::print(cerr, "Don't {}!", "panic"); - */ -FMT_EXPORT template -void print(std::ostream& os, format_string fmt, T&&... args) { - const auto& vargs = fmt::make_format_args(args...); - if (detail::use_utf8()) - vprint(os, fmt, vargs); - else - detail::vprint_directly(os, fmt, vargs); -} - -FMT_EXPORT -template -void print(std::wostream& os, - basic_format_string...> fmt, - Args&&... args) { - vprint(os, fmt, fmt::make_format_args>(args...)); -} - -FMT_EXPORT template -void println(std::ostream& os, format_string fmt, T&&... args) { - fmt::print(os, "{}\n", fmt::format(fmt, std::forward(args)...)); -} - -FMT_EXPORT -template -void println(std::wostream& os, - basic_format_string...> fmt, - Args&&... args) { - print(os, L"{}\n", fmt::format(fmt, std::forward(args)...)); -} - -FMT_END_NAMESPACE - -#endif // FMT_OSTREAM_H_ diff --git a/tt_metal/third_party/fmt/fmt/printf.h b/tt_metal/third_party/fmt/fmt/printf.h deleted file mode 100644 index 072cc6b309d..00000000000 --- a/tt_metal/third_party/fmt/fmt/printf.h +++ /dev/null @@ -1,656 +0,0 @@ -// Formatting library for C++ - legacy printf implementation -// -// Copyright (c) 2012 - 2016, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_PRINTF_H_ -#define FMT_PRINTF_H_ - -#ifndef FMT_MODULE -# include // std::max -# include // std::numeric_limits -#endif - -#include "format.h" - -FMT_BEGIN_NAMESPACE -FMT_BEGIN_EXPORT - -template struct printf_formatter { - printf_formatter() = delete; -}; - -template class basic_printf_context { - private: - basic_appender out_; - basic_format_args args_; - - static_assert(std::is_same::value || - std::is_same::value, - "Unsupported code unit type."); - - public: - using char_type = Char; - using parse_context_type = basic_format_parse_context; - template using formatter_type = printf_formatter; - - /// Constructs a `printf_context` object. References to the arguments are - /// stored in the context object so make sure they have appropriate lifetimes. - basic_printf_context(basic_appender out, - basic_format_args args) - : out_(out), args_(args) {} - - auto out() -> basic_appender { return out_; } - void advance_to(basic_appender) {} - - auto locale() -> detail::locale_ref { return {}; } - - auto arg(int id) const -> basic_format_arg { - return args_.get(id); - } -}; - -namespace detail { - -// Checks if a value fits in int - used to avoid warnings about comparing -// signed and unsigned integers. -template struct int_checker { - template static auto fits_in_int(T value) -> bool { - unsigned max = to_unsigned(max_value()); - return value <= max; - } - static auto fits_in_int(bool) -> bool { return true; } -}; - -template <> struct int_checker { - template static auto fits_in_int(T value) -> bool { - return value >= (std::numeric_limits::min)() && - value <= max_value(); - } - static auto fits_in_int(int) -> bool { return true; } -}; - -struct printf_precision_handler { - template ::value)> - auto operator()(T value) -> int { - if (!int_checker::is_signed>::fits_in_int(value)) - report_error("number is too big"); - return (std::max)(static_cast(value), 0); - } - - template ::value)> - auto operator()(T) -> int { - report_error("precision is not integer"); - return 0; - } -}; - -// An argument visitor that returns true iff arg is a zero integer. -struct is_zero_int { - template ::value)> - auto operator()(T value) -> bool { - return value == 0; - } - - template ::value)> - auto operator()(T) -> bool { - return false; - } -}; - -template struct make_unsigned_or_bool : std::make_unsigned {}; - -template <> struct make_unsigned_or_bool { - using type = bool; -}; - -template class arg_converter { - private: - using char_type = typename Context::char_type; - - basic_format_arg& arg_; - char_type type_; - - public: - arg_converter(basic_format_arg& arg, char_type type) - : arg_(arg), type_(type) {} - - void operator()(bool value) { - if (type_ != 's') operator()(value); - } - - template ::value)> - void operator()(U value) { - bool is_signed = type_ == 'd' || type_ == 'i'; - using target_type = conditional_t::value, U, T>; - if (const_check(sizeof(target_type) <= sizeof(int))) { - // Extra casts are used to silence warnings. - if (is_signed) { - auto n = static_cast(static_cast(value)); - arg_ = detail::make_arg(n); - } else { - using unsigned_type = typename make_unsigned_or_bool::type; - auto n = static_cast(static_cast(value)); - arg_ = detail::make_arg(n); - } - } else { - if (is_signed) { - // glibc's printf doesn't sign extend arguments of smaller types: - // std::printf("%lld", -42); // prints "4294967254" - // but we don't have to do the same because it's a UB. - auto n = static_cast(value); - arg_ = detail::make_arg(n); - } else { - auto n = static_cast::type>(value); - arg_ = detail::make_arg(n); - } - } - } - - template ::value)> - void operator()(U) {} // No conversion needed for non-integral types. -}; - -// Converts an integer argument to T for printf, if T is an integral type. -// If T is void, the argument is converted to corresponding signed or unsigned -// type depending on the type specifier: 'd' and 'i' - signed, other - -// unsigned). -template -void convert_arg(basic_format_arg& arg, Char type) { - arg.visit(arg_converter(arg, type)); -} - -// Converts an integer argument to char for printf. -template class char_converter { - private: - basic_format_arg& arg_; - - public: - explicit char_converter(basic_format_arg& arg) : arg_(arg) {} - - template ::value)> - void operator()(T value) { - auto c = static_cast(value); - arg_ = detail::make_arg(c); - } - - template ::value)> - void operator()(T) {} // No conversion needed for non-integral types. -}; - -// An argument visitor that return a pointer to a C string if argument is a -// string or null otherwise. -template struct get_cstring { - template auto operator()(T) -> const Char* { return nullptr; } - auto operator()(const Char* s) -> const Char* { return s; } -}; - -// Checks if an argument is a valid printf width specifier and sets -// left alignment if it is negative. -class printf_width_handler { - private: - format_specs& specs_; - - public: - explicit printf_width_handler(format_specs& specs) : specs_(specs) {} - - template ::value)> - auto operator()(T value) -> unsigned { - auto width = static_cast>(value); - if (detail::is_negative(value)) { - specs_.align = align::left; - width = 0 - width; - } - unsigned int_max = to_unsigned(max_value()); - if (width > int_max) report_error("number is too big"); - return static_cast(width); - } - - template ::value)> - auto operator()(T) -> unsigned { - report_error("width is not integer"); - return 0; - } -}; - -// Workaround for a bug with the XL compiler when initializing -// printf_arg_formatter's base class. -template -auto make_arg_formatter(basic_appender iter, format_specs& s) - -> arg_formatter { - return {iter, s, locale_ref()}; -} - -// The `printf` argument formatter. -template -class printf_arg_formatter : public arg_formatter { - private: - using base = arg_formatter; - using context_type = basic_printf_context; - - context_type& context_; - - void write_null_pointer(bool is_string = false) { - auto s = this->specs; - s.type = presentation_type::none; - write_bytes(this->out, is_string ? "(null)" : "(nil)", s); - } - - public: - printf_arg_formatter(basic_appender iter, format_specs& s, - context_type& ctx) - : base(make_arg_formatter(iter, s)), context_(ctx) {} - - void operator()(monostate value) { base::operator()(value); } - - template ::value)> - void operator()(T value) { - // MSVC2013 fails to compile separate overloads for bool and Char so use - // std::is_same instead. - if (!std::is_same::value) { - base::operator()(value); - return; - } - format_specs s = this->specs; - if (s.type != presentation_type::none && s.type != presentation_type::chr) { - return (*this)(static_cast(value)); - } - s.sign = sign::none; - s.alt = false; - s.fill = ' '; // Ignore '0' flag for char types. - // align::numeric needs to be overwritten here since the '0' flag is - // ignored for non-numeric types - if (s.align == align::none || s.align == align::numeric) - s.align = align::right; - write(this->out, static_cast(value), s); - } - - template ::value)> - void operator()(T value) { - base::operator()(value); - } - - void operator()(const char* value) { - if (value) - base::operator()(value); - else - write_null_pointer(this->specs.type != presentation_type::pointer); - } - - void operator()(const wchar_t* value) { - if (value) - base::operator()(value); - else - write_null_pointer(this->specs.type != presentation_type::pointer); - } - - void operator()(basic_string_view value) { base::operator()(value); } - - void operator()(const void* value) { - if (value) - base::operator()(value); - else - write_null_pointer(); - } - - void operator()(typename basic_format_arg::handle handle) { - auto parse_ctx = basic_format_parse_context({}); - handle.format(parse_ctx, context_); - } -}; - -template -void parse_flags(format_specs& specs, const Char*& it, const Char* end) { - for (; it != end; ++it) { - switch (*it) { - case '-': - specs.align = align::left; - break; - case '+': - specs.sign = sign::plus; - break; - case '0': - specs.fill = '0'; - break; - case ' ': - if (specs.sign != sign::plus) specs.sign = sign::space; - break; - case '#': - specs.alt = true; - break; - default: - return; - } - } -} - -template -auto parse_header(const Char*& it, const Char* end, format_specs& specs, - GetArg get_arg) -> int { - int arg_index = -1; - Char c = *it; - if (c >= '0' && c <= '9') { - // Parse an argument index (if followed by '$') or a width possibly - // preceded with '0' flag(s). - int value = parse_nonnegative_int(it, end, -1); - if (it != end && *it == '$') { // value is an argument index - ++it; - arg_index = value != -1 ? value : max_value(); - } else { - if (c == '0') specs.fill = '0'; - if (value != 0) { - // Nonzero value means that we parsed width and don't need to - // parse it or flags again, so return now. - if (value == -1) report_error("number is too big"); - specs.width = value; - return arg_index; - } - } - } - parse_flags(specs, it, end); - // Parse width. - if (it != end) { - if (*it >= '0' && *it <= '9') { - specs.width = parse_nonnegative_int(it, end, -1); - if (specs.width == -1) report_error("number is too big"); - } else if (*it == '*') { - ++it; - specs.width = static_cast( - get_arg(-1).visit(detail::printf_width_handler(specs))); - } - } - return arg_index; -} - -inline auto parse_printf_presentation_type(char c, type t, bool& upper) - -> presentation_type { - using pt = presentation_type; - constexpr auto integral_set = sint_set | uint_set | bool_set | char_set; - switch (c) { - case 'd': - return in(t, integral_set) ? pt::dec : pt::none; - case 'o': - return in(t, integral_set) ? pt::oct : pt::none; - case 'X': - upper = true; - FMT_FALLTHROUGH; - case 'x': - return in(t, integral_set) ? pt::hex : pt::none; - case 'E': - upper = true; - FMT_FALLTHROUGH; - case 'e': - return in(t, float_set) ? pt::exp : pt::none; - case 'F': - upper = true; - FMT_FALLTHROUGH; - case 'f': - return in(t, float_set) ? pt::fixed : pt::none; - case 'G': - upper = true; - FMT_FALLTHROUGH; - case 'g': - return in(t, float_set) ? pt::general : pt::none; - case 'A': - upper = true; - FMT_FALLTHROUGH; - case 'a': - return in(t, float_set) ? pt::hexfloat : pt::none; - case 'c': - return in(t, integral_set) ? pt::chr : pt::none; - case 's': - return in(t, string_set | cstring_set) ? pt::string : pt::none; - case 'p': - return in(t, pointer_set | cstring_set) ? pt::pointer : pt::none; - default: - return pt::none; - } -} - -template -void vprintf(buffer& buf, basic_string_view format, - basic_format_args args) { - using iterator = basic_appender; - auto out = iterator(buf); - auto context = basic_printf_context(out, args); - auto parse_ctx = basic_format_parse_context(format); - - // Returns the argument with specified index or, if arg_index is -1, the next - // argument. - auto get_arg = [&](int arg_index) { - if (arg_index < 0) - arg_index = parse_ctx.next_arg_id(); - else - parse_ctx.check_arg_id(--arg_index); - return detail::get_arg(context, arg_index); - }; - - const Char* start = parse_ctx.begin(); - const Char* end = parse_ctx.end(); - auto it = start; - while (it != end) { - if (!find(it, end, '%', it)) { - it = end; // find leaves it == nullptr if it doesn't find '%'. - break; - } - Char c = *it++; - if (it != end && *it == c) { - write(out, basic_string_view(start, to_unsigned(it - start))); - start = ++it; - continue; - } - write(out, basic_string_view(start, to_unsigned(it - 1 - start))); - - auto specs = format_specs(); - specs.align = align::right; - - // Parse argument index, flags and width. - int arg_index = parse_header(it, end, specs, get_arg); - if (arg_index == 0) report_error("argument not found"); - - // Parse precision. - if (it != end && *it == '.') { - ++it; - c = it != end ? *it : 0; - if ('0' <= c && c <= '9') { - specs.precision = parse_nonnegative_int(it, end, 0); - } else if (c == '*') { - ++it; - specs.precision = - static_cast(get_arg(-1).visit(printf_precision_handler())); - } else { - specs.precision = 0; - } - } - - auto arg = get_arg(arg_index); - // For d, i, o, u, x, and X conversion specifiers, if a precision is - // specified, the '0' flag is ignored - if (specs.precision >= 0 && arg.is_integral()) { - // Ignore '0' for non-numeric types or if '-' present. - specs.fill = ' '; - } - if (specs.precision >= 0 && arg.type() == type::cstring_type) { - auto str = arg.visit(get_cstring()); - auto str_end = str + specs.precision; - auto nul = std::find(str, str_end, Char()); - auto sv = basic_string_view( - str, to_unsigned(nul != str_end ? nul - str : specs.precision)); - arg = make_arg>(sv); - } - if (specs.alt && arg.visit(is_zero_int())) specs.alt = false; - if (specs.fill.template get() == '0') { - if (arg.is_arithmetic() && specs.align != align::left) - specs.align = align::numeric; - else - specs.fill = ' '; // Ignore '0' flag for non-numeric types or if '-' - // flag is also present. - } - - // Parse length and convert the argument to the required type. - c = it != end ? *it++ : 0; - Char t = it != end ? *it : 0; - switch (c) { - case 'h': - if (t == 'h') { - ++it; - t = it != end ? *it : 0; - convert_arg(arg, t); - } else { - convert_arg(arg, t); - } - break; - case 'l': - if (t == 'l') { - ++it; - t = it != end ? *it : 0; - convert_arg(arg, t); - } else { - convert_arg(arg, t); - } - break; - case 'j': - convert_arg(arg, t); - break; - case 'z': - convert_arg(arg, t); - break; - case 't': - convert_arg(arg, t); - break; - case 'L': - // printf produces garbage when 'L' is omitted for long double, no - // need to do the same. - break; - default: - --it; - convert_arg(arg, c); - } - - // Parse type. - if (it == end) report_error("invalid format string"); - char type = static_cast(*it++); - if (arg.is_integral()) { - // Normalize type. - switch (type) { - case 'i': - case 'u': - type = 'd'; - break; - case 'c': - arg.visit(char_converter>(arg)); - break; - } - } - bool upper = false; - specs.type = parse_printf_presentation_type(type, arg.type(), upper); - if (specs.type == presentation_type::none) - report_error("invalid format specifier"); - specs.upper = upper; - - start = it; - - // Format argument. - arg.visit(printf_arg_formatter(out, specs, context)); - } - write(out, basic_string_view(start, to_unsigned(it - start))); -} -} // namespace detail - -using printf_context = basic_printf_context; -using wprintf_context = basic_printf_context; - -using printf_args = basic_format_args; -using wprintf_args = basic_format_args; - -/// Constructs an `format_arg_store` object that contains references to -/// arguments and can be implicitly converted to `printf_args`. -template -inline auto make_printf_args(T&... args) - -> decltype(fmt::make_format_args>(args...)) { - return fmt::make_format_args>(args...); -} - -template struct vprintf_args { - using type = basic_format_args>; -}; - -template -inline auto vsprintf(basic_string_view fmt, - typename vprintf_args::type args) - -> std::basic_string { - auto buf = basic_memory_buffer(); - detail::vprintf(buf, fmt, args); - return to_string(buf); -} - -/** - * Formats `args` according to specifications in `fmt` and returns the result - * as as string. - * - * **Example**: - * - * std::string message = fmt::sprintf("The answer is %d", 42); - */ -template > -inline auto sprintf(const S& fmt, const T&... args) -> std::basic_string { - return vsprintf(detail::to_string_view(fmt), - fmt::make_format_args>(args...)); -} - -template -inline auto vfprintf(std::FILE* f, basic_string_view fmt, - typename vprintf_args::type args) -> int { - auto buf = basic_memory_buffer(); - detail::vprintf(buf, fmt, args); - size_t size = buf.size(); - return std::fwrite(buf.data(), sizeof(Char), size, f) < size - ? -1 - : static_cast(size); -} - -/** - * Formats `args` according to specifications in `fmt` and writes the output - * to `f`. - * - * **Example**: - * - * fmt::fprintf(stderr, "Don't %s!", "panic"); - */ -template > -inline auto fprintf(std::FILE* f, const S& fmt, const T&... args) -> int { - return vfprintf(f, detail::to_string_view(fmt), - make_printf_args(args...)); -} - -template -FMT_DEPRECATED inline auto vprintf(basic_string_view fmt, - typename vprintf_args::type args) - -> int { - return vfprintf(stdout, fmt, args); -} - -/** - * Formats `args` according to specifications in `fmt` and writes the output - * to `stdout`. - * - * **Example**: - * - * fmt::printf("Elapsed time: %.2f seconds", 1.23); - */ -template -inline auto printf(string_view fmt, const T&... args) -> int { - return vfprintf(stdout, fmt, make_printf_args(args...)); -} -template -FMT_DEPRECATED inline auto printf(basic_string_view fmt, - const T&... args) -> int { - return vfprintf(stdout, fmt, make_printf_args(args...)); -} - -FMT_END_EXPORT -FMT_END_NAMESPACE - -#endif // FMT_PRINTF_H_ diff --git a/tt_metal/third_party/fmt/fmt/ranges.h b/tt_metal/third_party/fmt/fmt/ranges.h deleted file mode 100644 index f387903cf63..00000000000 --- a/tt_metal/third_party/fmt/fmt/ranges.h +++ /dev/null @@ -1,882 +0,0 @@ -// Formatting library for C++ - range and tuple support -// -// Copyright (c) 2012 - present, Victor Zverovich and {fmt} contributors -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_RANGES_H_ -#define FMT_RANGES_H_ - -#ifndef FMT_MODULE -# include -# include -# include -# include -# include -# include -#endif - -#include "format.h" - -FMT_BEGIN_NAMESPACE - -FMT_EXPORT -enum class range_format { disabled, map, set, sequence, string, debug_string }; - -namespace detail { - -template class is_map { - template static auto check(U*) -> typename U::mapped_type; - template static void check(...); - - public: - static constexpr const bool value = - !std::is_void(nullptr))>::value; -}; - -template class is_set { - template static auto check(U*) -> typename U::key_type; - template static void check(...); - - public: - static constexpr const bool value = - !std::is_void(nullptr))>::value && !is_map::value; -}; - -template struct conditional_helper {}; - -template struct is_range_ : std::false_type {}; - -#if !FMT_MSC_VERSION || FMT_MSC_VERSION > 1800 - -# define FMT_DECLTYPE_RETURN(val) \ - ->decltype(val) { return val; } \ - static_assert( \ - true, "") // This makes it so that a semicolon is required after the - // macro, which helps clang-format handle the formatting. - -// C array overload -template -auto range_begin(const T (&arr)[N]) -> const T* { - return arr; -} -template -auto range_end(const T (&arr)[N]) -> const T* { - return arr + N; -} - -template -struct has_member_fn_begin_end_t : std::false_type {}; - -template -struct has_member_fn_begin_end_t().begin()), - decltype(std::declval().end())>> - : std::true_type {}; - -// Member function overloads. -template -auto range_begin(T&& rng) FMT_DECLTYPE_RETURN(static_cast(rng).begin()); -template -auto range_end(T&& rng) FMT_DECLTYPE_RETURN(static_cast(rng).end()); - -// ADL overloads. Only participate in overload resolution if member functions -// are not found. -template -auto range_begin(T&& rng) - -> enable_if_t::value, - decltype(begin(static_cast(rng)))> { - return begin(static_cast(rng)); -} -template -auto range_end(T&& rng) -> enable_if_t::value, - decltype(end(static_cast(rng)))> { - return end(static_cast(rng)); -} - -template -struct has_const_begin_end : std::false_type {}; -template -struct has_mutable_begin_end : std::false_type {}; - -template -struct has_const_begin_end< - T, void_t&>())), - decltype(detail::range_end( - std::declval&>()))>> - : std::true_type {}; - -template -struct has_mutable_begin_end< - T, void_t())), - decltype(detail::range_end(std::declval())), - // the extra int here is because older versions of MSVC don't - // SFINAE properly unless there are distinct types - int>> : std::true_type {}; - -template -struct is_range_ - : std::integral_constant::value || - has_mutable_begin_end::value)> {}; -# undef FMT_DECLTYPE_RETURN -#endif - -// tuple_size and tuple_element check. -template class is_tuple_like_ { - template - static auto check(U* p) -> decltype(std::tuple_size::value, int()); - template static void check(...); - - public: - static constexpr const bool value = - !std::is_void(nullptr))>::value; -}; - -// Check for integer_sequence -#if defined(__cpp_lib_integer_sequence) || FMT_MSC_VERSION >= 1900 -template -using integer_sequence = std::integer_sequence; -template using index_sequence = std::index_sequence; -template using make_index_sequence = std::make_index_sequence; -#else -template struct integer_sequence { - using value_type = T; - - static FMT_CONSTEXPR auto size() -> size_t { return sizeof...(N); } -}; - -template using index_sequence = integer_sequence; - -template -struct make_integer_sequence : make_integer_sequence {}; -template -struct make_integer_sequence : integer_sequence {}; - -template -using make_index_sequence = make_integer_sequence; -#endif - -template -using tuple_index_sequence = make_index_sequence::value>; - -template ::value> -class is_tuple_formattable_ { - public: - static constexpr const bool value = false; -}; -template class is_tuple_formattable_ { - template - static auto all_true(index_sequence, - integer_sequence= 0)...>) -> std::true_type; - static auto all_true(...) -> std::false_type; - - template - static auto check(index_sequence) -> decltype(all_true( - index_sequence{}, - integer_sequence::type, - C>::value)...>{})); - - public: - static constexpr const bool value = - decltype(check(tuple_index_sequence{}))::value; -}; - -template -FMT_CONSTEXPR void for_each(index_sequence, Tuple&& t, F&& f) { - using std::get; - // Using a free function get(Tuple) now. - const int unused[] = {0, ((void)f(get(t)), 0)...}; - ignore_unused(unused); -} - -template -FMT_CONSTEXPR void for_each(Tuple&& t, F&& f) { - for_each(tuple_index_sequence>(), - std::forward(t), std::forward(f)); -} - -template -void for_each2(index_sequence, Tuple1&& t1, Tuple2&& t2, F&& f) { - using std::get; - const int unused[] = {0, ((void)f(get(t1), get(t2)), 0)...}; - ignore_unused(unused); -} - -template -void for_each2(Tuple1&& t1, Tuple2&& t2, F&& f) { - for_each2(tuple_index_sequence>(), - std::forward(t1), std::forward(t2), - std::forward(f)); -} - -namespace tuple { -// Workaround a bug in MSVC 2019 (v140). -template -using result_t = std::tuple, Char>...>; - -using std::get; -template -auto get_formatters(index_sequence) - -> result_t(std::declval()))...>; -} // namespace tuple - -#if FMT_MSC_VERSION && FMT_MSC_VERSION < 1920 -// Older MSVC doesn't get the reference type correctly for arrays. -template struct range_reference_type_impl { - using type = decltype(*detail::range_begin(std::declval())); -}; - -template struct range_reference_type_impl { - using type = T&; -}; - -template -using range_reference_type = typename range_reference_type_impl::type; -#else -template -using range_reference_type = - decltype(*detail::range_begin(std::declval())); -#endif - -// We don't use the Range's value_type for anything, but we do need the Range's -// reference type, with cv-ref stripped. -template -using uncvref_type = remove_cvref_t>; - -template -FMT_CONSTEXPR auto maybe_set_debug_format(Formatter& f, bool set) - -> decltype(f.set_debug_format(set)) { - f.set_debug_format(set); -} -template -FMT_CONSTEXPR void maybe_set_debug_format(Formatter&, ...) {} - -template -struct range_format_kind_ - : std::integral_constant, T>::value - ? range_format::disabled - : is_map::value ? range_format::map - : is_set::value ? range_format::set - : range_format::sequence> {}; - -template -using range_format_constant = std::integral_constant; - -// These are not generic lambdas for compatibility with C++11. -template struct parse_empty_specs { - template FMT_CONSTEXPR void operator()(Formatter& f) { - f.parse(ctx); - detail::maybe_set_debug_format(f, true); - } - ParseContext& ctx; -}; -template struct format_tuple_element { - using char_type = typename FormatContext::char_type; - - template - void operator()(const formatter& f, const T& v) { - if (i > 0) ctx.advance_to(detail::copy(separator, ctx.out())); - ctx.advance_to(f.format(v, ctx)); - ++i; - } - - int i; - FormatContext& ctx; - basic_string_view separator; -}; - -} // namespace detail - -template struct is_tuple_like { - static constexpr const bool value = - detail::is_tuple_like_::value && !detail::is_range_::value; -}; - -template struct is_tuple_formattable { - static constexpr const bool value = - detail::is_tuple_formattable_::value; -}; - -template -struct formatter::value && - fmt::is_tuple_formattable::value>> { - private: - decltype(detail::tuple::get_formatters( - detail::tuple_index_sequence())) formatters_; - - basic_string_view separator_ = detail::string_literal{}; - basic_string_view opening_bracket_ = - detail::string_literal{}; - basic_string_view closing_bracket_ = - detail::string_literal{}; - - public: - FMT_CONSTEXPR formatter() {} - - FMT_CONSTEXPR void set_separator(basic_string_view sep) { - separator_ = sep; - } - - FMT_CONSTEXPR void set_brackets(basic_string_view open, - basic_string_view close) { - opening_bracket_ = open; - closing_bracket_ = close; - } - - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - auto it = ctx.begin(); - if (it != ctx.end() && *it != '}') report_error("invalid format specifier"); - detail::for_each(formatters_, detail::parse_empty_specs{ctx}); - return it; - } - - template - auto format(const Tuple& value, FormatContext& ctx) const - -> decltype(ctx.out()) { - ctx.advance_to(detail::copy(opening_bracket_, ctx.out())); - detail::for_each2( - formatters_, value, - detail::format_tuple_element{0, ctx, separator_}); - return detail::copy(closing_bracket_, ctx.out()); - } -}; - -template struct is_range { - static constexpr const bool value = - detail::is_range_::value && !detail::has_to_string_view::value; -}; - -namespace detail { -template struct range_mapper { - using mapper = arg_mapper; - - template , Context>::value)> - static auto map(T&& value) -> T&& { - return static_cast(value); - } - template , Context>::value)> - static auto map(T&& value) - -> decltype(mapper().map(static_cast(value))) { - return mapper().map(static_cast(value)); - } -}; - -template -using range_formatter_type = - formatter>{} - .map(std::declval()))>, - Char>; - -template -using maybe_const_range = - conditional_t::value, const R, R>; - -// Workaround a bug in MSVC 2015 and earlier. -#if !FMT_MSC_VERSION || FMT_MSC_VERSION >= 1910 -template -struct is_formattable_delayed - : is_formattable>, Char> {}; -#endif -} // namespace detail - -template struct conjunction : std::true_type {}; -template struct conjunction

: P {}; -template -struct conjunction - : conditional_t, P1> {}; - -template -struct range_formatter; - -template -struct range_formatter< - T, Char, - enable_if_t>, - is_formattable>::value>> { - private: - detail::range_formatter_type underlying_; - basic_string_view separator_ = detail::string_literal{}; - basic_string_view opening_bracket_ = - detail::string_literal{}; - basic_string_view closing_bracket_ = - detail::string_literal{}; - bool is_debug = false; - - template ::value)> - auto write_debug_string(Output& out, It it, Sentinel end) const -> Output { - auto buf = basic_memory_buffer(); - for (; it != end; ++it) buf.push_back(*it); - auto specs = format_specs(); - specs.type = presentation_type::debug; - return detail::write( - out, basic_string_view(buf.data(), buf.size()), specs); - } - - template ::value)> - auto write_debug_string(Output& out, It, Sentinel) const -> Output { - return out; - } - - public: - FMT_CONSTEXPR range_formatter() {} - - FMT_CONSTEXPR auto underlying() -> detail::range_formatter_type& { - return underlying_; - } - - FMT_CONSTEXPR void set_separator(basic_string_view sep) { - separator_ = sep; - } - - FMT_CONSTEXPR void set_brackets(basic_string_view open, - basic_string_view close) { - opening_bracket_ = open; - closing_bracket_ = close; - } - - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - auto it = ctx.begin(); - auto end = ctx.end(); - detail::maybe_set_debug_format(underlying_, true); - if (it == end) return underlying_.parse(ctx); - - switch (detail::to_ascii(*it)) { - case 'n': - set_brackets({}, {}); - ++it; - break; - case '?': - is_debug = true; - set_brackets({}, {}); - ++it; - if (it == end || *it != 's') report_error("invalid format specifier"); - FMT_FALLTHROUGH; - case 's': - if (!std::is_same::value) - report_error("invalid format specifier"); - if (!is_debug) { - set_brackets(detail::string_literal{}, - detail::string_literal{}); - set_separator({}); - detail::maybe_set_debug_format(underlying_, false); - } - ++it; - return it; - } - - if (it != end && *it != '}') { - if (*it != ':') report_error("invalid format specifier"); - detail::maybe_set_debug_format(underlying_, false); - ++it; - } - - ctx.advance_to(it); - return underlying_.parse(ctx); - } - - template - auto format(R&& range, FormatContext& ctx) const -> decltype(ctx.out()) { - auto mapper = detail::range_mapper>(); - auto out = ctx.out(); - auto it = detail::range_begin(range); - auto end = detail::range_end(range); - if (is_debug) return write_debug_string(out, it, end); - - out = detail::copy(opening_bracket_, out); - int i = 0; - for (; it != end; ++it) { - if (i > 0) out = detail::copy(separator_, out); - ctx.advance_to(out); - auto&& item = *it; // Need an lvalue - out = underlying_.format(mapper.map(item), ctx); - ++i; - } - out = detail::copy(closing_bracket_, out); - return out; - } -}; - -FMT_EXPORT -template -struct range_format_kind - : conditional_t< - is_range::value, detail::range_format_kind_, - std::integral_constant> {}; - -template -struct formatter< - R, Char, - enable_if_t::value != range_format::disabled && - range_format_kind::value != range_format::map && - range_format_kind::value != range_format::string && - range_format_kind::value != range_format::debug_string> -// Workaround a bug in MSVC 2015 and earlier. -#if !FMT_MSC_VERSION || FMT_MSC_VERSION >= 1910 - , - detail::is_formattable_delayed -#endif - >::value>> { - private: - using range_type = detail::maybe_const_range; - range_formatter, Char> range_formatter_; - - public: - using nonlocking = void; - - FMT_CONSTEXPR formatter() { - if (detail::const_check(range_format_kind::value != - range_format::set)) - return; - range_formatter_.set_brackets(detail::string_literal{}, - detail::string_literal{}); - } - - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - return range_formatter_.parse(ctx); - } - - template - auto format(range_type& range, FormatContext& ctx) const - -> decltype(ctx.out()) { - return range_formatter_.format(range, ctx); - } -}; - -// A map formatter. -template -struct formatter< - R, Char, - enable_if_t::value == range_format::map>> { - private: - using map_type = detail::maybe_const_range; - using element_type = detail::uncvref_type; - - decltype(detail::tuple::get_formatters( - detail::tuple_index_sequence())) formatters_; - bool no_delimiters_ = false; - - public: - FMT_CONSTEXPR formatter() {} - - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - auto it = ctx.begin(); - auto end = ctx.end(); - if (it != end) { - if (detail::to_ascii(*it) == 'n') { - no_delimiters_ = true; - ++it; - } - if (it != end && *it != '}') { - if (*it != ':') report_error("invalid format specifier"); - ++it; - } - ctx.advance_to(it); - } - detail::for_each(formatters_, detail::parse_empty_specs{ctx}); - return it; - } - - template - auto format(map_type& map, FormatContext& ctx) const -> decltype(ctx.out()) { - auto out = ctx.out(); - basic_string_view open = detail::string_literal{}; - if (!no_delimiters_) out = detail::copy(open, out); - int i = 0; - auto mapper = detail::range_mapper>(); - basic_string_view sep = detail::string_literal{}; - for (auto&& value : map) { - if (i > 0) out = detail::copy(sep, out); - ctx.advance_to(out); - detail::for_each2(formatters_, mapper.map(value), - detail::format_tuple_element{ - 0, ctx, detail::string_literal{}}); - ++i; - } - basic_string_view close = detail::string_literal{}; - if (!no_delimiters_) out = detail::copy(close, out); - return out; - } -}; - -// A (debug_)string formatter. -template -struct formatter< - R, Char, - enable_if_t::value == range_format::string || - range_format_kind::value == - range_format::debug_string>> { - private: - using range_type = detail::maybe_const_range; - using string_type = - conditional_t, - decltype(detail::range_begin(std::declval())), - decltype(detail::range_end(std::declval()))>::value, - detail::std_string_view, std::basic_string>; - - formatter underlying_; - - public: - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - return underlying_.parse(ctx); - } - - template - auto format(range_type& range, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto out = ctx.out(); - if (detail::const_check(range_format_kind::value == - range_format::debug_string)) - *out++ = '"'; - out = underlying_.format( - string_type{detail::range_begin(range), detail::range_end(range)}, ctx); - if (detail::const_check(range_format_kind::value == - range_format::debug_string)) - *out++ = '"'; - return out; - } -}; - -template -struct join_view : detail::view { - It begin; - Sentinel end; - basic_string_view sep; - - join_view(It b, Sentinel e, basic_string_view s) - : begin(std::move(b)), end(e), sep(s) {} -}; - -template -struct formatter, Char> { - private: - using value_type = -#ifdef __cpp_lib_ranges - std::iter_value_t; -#else - typename std::iterator_traits::value_type; -#endif - formatter, Char> value_formatter_; - - using view_ref = conditional_t::value, - const join_view&, - join_view&&>; - - public: - using nonlocking = void; - - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> const Char* { - return value_formatter_.parse(ctx); - } - - template - auto format(view_ref& value, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto it = std::forward(value).begin; - auto out = ctx.out(); - if (it == value.end) return out; - out = value_formatter_.format(*it, ctx); - ++it; - while (it != value.end) { - out = detail::copy(value.sep.begin(), value.sep.end(), out); - ctx.advance_to(out); - out = value_formatter_.format(*it, ctx); - ++it; - } - return out; - } -}; - -/// Returns a view that formats the iterator range `[begin, end)` with elements -/// separated by `sep`. -template -auto join(It begin, Sentinel end, string_view sep) -> join_view { - return {std::move(begin), end, sep}; -} - -/** - * Returns a view that formats `range` with elements separated by `sep`. - * - * **Example**: - * - * auto v = std::vector{1, 2, 3}; - * fmt::print("{}", fmt::join(v, ", ")); - * // Output: 1, 2, 3 - * - * `fmt::join` applies passed format specifiers to the range elements: - * - * fmt::print("{:02}", fmt::join(v, ", ")); - * // Output: 01, 02, 03 - */ -template -auto join(Range&& r, string_view sep) - -> join_view { - return {detail::range_begin(r), detail::range_end(r), sep}; -} - -template struct tuple_join_view : detail::view { - const std::tuple& tuple; - basic_string_view sep; - - tuple_join_view(const std::tuple& t, basic_string_view s) - : tuple(t), sep{s} {} -}; - -// Define FMT_TUPLE_JOIN_SPECIFIERS to enable experimental format specifiers -// support in tuple_join. It is disabled by default because of issues with -// the dynamic width and precision. -#ifndef FMT_TUPLE_JOIN_SPECIFIERS -# define FMT_TUPLE_JOIN_SPECIFIERS 0 -#endif - -template -struct formatter, Char> { - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - return do_parse(ctx, std::integral_constant()); - } - - template - auto format(const tuple_join_view& value, - FormatContext& ctx) const -> typename FormatContext::iterator { - return do_format(value, ctx, - std::integral_constant()); - } - - private: - std::tuple::type, Char>...> formatters_; - - template - FMT_CONSTEXPR auto do_parse(ParseContext& ctx, - std::integral_constant) - -> decltype(ctx.begin()) { - return ctx.begin(); - } - - template - FMT_CONSTEXPR auto do_parse(ParseContext& ctx, - std::integral_constant) - -> decltype(ctx.begin()) { - auto end = ctx.begin(); -#if FMT_TUPLE_JOIN_SPECIFIERS - end = std::get(formatters_).parse(ctx); - if (N > 1) { - auto end1 = do_parse(ctx, std::integral_constant()); - if (end != end1) - report_error("incompatible format specs for tuple elements"); - } -#endif - return end; - } - - template - auto do_format(const tuple_join_view&, FormatContext& ctx, - std::integral_constant) const -> - typename FormatContext::iterator { - return ctx.out(); - } - - template - auto do_format(const tuple_join_view& value, FormatContext& ctx, - std::integral_constant) const -> - typename FormatContext::iterator { - auto out = std::get(formatters_) - .format(std::get(value.tuple), ctx); - if (N <= 1) return out; - out = detail::copy(value.sep, out); - ctx.advance_to(out); - return do_format(value, ctx, std::integral_constant()); - } -}; - -namespace detail { -// Check if T has an interface like a container adaptor (e.g. std::stack, -// std::queue, std::priority_queue). -template class is_container_adaptor_like { - template static auto check(U* p) -> typename U::container_type; - template static void check(...); - - public: - static constexpr const bool value = - !std::is_void(nullptr))>::value; -}; - -template struct all { - const Container& c; - auto begin() const -> typename Container::const_iterator { return c.begin(); } - auto end() const -> typename Container::const_iterator { return c.end(); } -}; -} // namespace detail - -template -struct formatter< - T, Char, - enable_if_t, - bool_constant::value == - range_format::disabled>>::value>> - : formatter, Char> { - using all = detail::all; - template - auto format(const T& t, FormatContext& ctx) const -> decltype(ctx.out()) { - struct getter : T { - static auto get(const T& t) -> all { - return {t.*(&getter::c)}; // Access c through the derived class. - } - }; - return formatter::format(getter::get(t), ctx); - } -}; - -FMT_BEGIN_EXPORT - -/** - * Returns an object that formats `std::tuple` with elements separated by `sep`. - * - * **Example**: - * - * auto t = std::tuple{1, 'a'}; - * fmt::print("{}", fmt::join(t, ", ")); - * // Output: 1, a - */ -template -FMT_CONSTEXPR auto join(const std::tuple& tuple, string_view sep) - -> tuple_join_view { - return {tuple, sep}; -} - -/** - * Returns an object that formats `std::initializer_list` with elements - * separated by `sep`. - * - * **Example**: - * - * fmt::print("{}", fmt::join({1, 2, 3}, ", ")); - * // Output: "1, 2, 3" - */ -template -auto join(std::initializer_list list, string_view sep) - -> join_view { - return join(std::begin(list), std::end(list), sep); -} - -FMT_END_EXPORT -FMT_END_NAMESPACE - -#endif // FMT_RANGES_H_ diff --git a/tt_metal/third_party/fmt/fmt/std.h b/tt_metal/third_party/fmt/fmt/std.h deleted file mode 100644 index fb43940bc06..00000000000 --- a/tt_metal/third_party/fmt/fmt/std.h +++ /dev/null @@ -1,699 +0,0 @@ -// Formatting library for C++ - formatters for standard library types -// -// Copyright (c) 2012 - present, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_STD_H_ -#define FMT_STD_H_ - -#include "format.h" -#include "ostream.h" - -#ifndef FMT_MODULE -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -// Check FMT_CPLUSPLUS to suppress a bogus warning in MSVC. -# if FMT_CPLUSPLUS >= 201703L -# if FMT_HAS_INCLUDE() -# include -# endif -# if FMT_HAS_INCLUDE() -# include -# endif -# if FMT_HAS_INCLUDE() -# include -# endif -# endif -// Use > instead of >= in the version check because may be -// available after C++17 but before C++20 is marked as implemented. -# if FMT_CPLUSPLUS > 201703L && FMT_HAS_INCLUDE() -# include -# endif -# if FMT_CPLUSPLUS > 202002L && FMT_HAS_INCLUDE() -# include -# endif -#endif // FMT_MODULE - -#if FMT_HAS_INCLUDE() -# include -#endif - -// GCC 4 does not support FMT_HAS_INCLUDE. -#if FMT_HAS_INCLUDE() || defined(__GLIBCXX__) -# include -// Android NDK with gabi++ library on some architectures does not implement -// abi::__cxa_demangle(). -# ifndef __GABIXX_CXXABI_H__ -# define FMT_HAS_ABI_CXA_DEMANGLE -# endif -#endif - -// For older Xcode versions, __cpp_lib_xxx flags are inaccurately defined. -#ifndef FMT_CPP_LIB_FILESYSTEM -# ifdef __cpp_lib_filesystem -# define FMT_CPP_LIB_FILESYSTEM __cpp_lib_filesystem -# else -# define FMT_CPP_LIB_FILESYSTEM 0 -# endif -#endif - -#ifndef FMT_CPP_LIB_VARIANT -# ifdef __cpp_lib_variant -# define FMT_CPP_LIB_VARIANT __cpp_lib_variant -# else -# define FMT_CPP_LIB_VARIANT 0 -# endif -#endif - -#if FMT_CPP_LIB_FILESYSTEM -FMT_BEGIN_NAMESPACE - -namespace detail { - -template -auto get_path_string(const std::filesystem::path& p, - const std::basic_string& native) { - if constexpr (std::is_same_v && std::is_same_v) - return to_utf8(native, to_utf8_error_policy::replace); - else - return p.string(); -} - -template -void write_escaped_path(basic_memory_buffer& quoted, - const std::filesystem::path& p, - const std::basic_string& native) { - if constexpr (std::is_same_v && - std::is_same_v) { - auto buf = basic_memory_buffer(); - write_escaped_string(std::back_inserter(buf), native); - bool valid = to_utf8::convert(quoted, {buf.data(), buf.size()}); - FMT_ASSERT(valid, "invalid utf16"); - } else if constexpr (std::is_same_v) { - write_escaped_string( - std::back_inserter(quoted), native); - } else { - write_escaped_string(std::back_inserter(quoted), p.string()); - } -} - -} // namespace detail - -FMT_EXPORT -template struct formatter { - private: - format_specs specs_; - detail::arg_ref width_ref_; - bool debug_ = false; - char path_type_ = 0; - - public: - FMT_CONSTEXPR void set_debug_format(bool set = true) { debug_ = set; } - - template FMT_CONSTEXPR auto parse(ParseContext& ctx) { - auto it = ctx.begin(), end = ctx.end(); - if (it == end) return it; - - it = detail::parse_align(it, end, specs_); - if (it == end) return it; - - it = detail::parse_dynamic_spec(it, end, specs_.width, width_ref_, ctx); - if (it != end && *it == '?') { - debug_ = true; - ++it; - } - if (it != end && (*it == 'g')) path_type_ = detail::to_ascii(*it++); - return it; - } - - template - auto format(const std::filesystem::path& p, FormatContext& ctx) const { - auto specs = specs_; - auto path_string = - !path_type_ ? p.native() - : p.generic_string(); - - detail::handle_dynamic_spec(specs.width, width_ref_, - ctx); - if (!debug_) { - auto s = detail::get_path_string(p, path_string); - return detail::write(ctx.out(), basic_string_view(s), specs); - } - auto quoted = basic_memory_buffer(); - detail::write_escaped_path(quoted, p, path_string); - return detail::write(ctx.out(), - basic_string_view(quoted.data(), quoted.size()), - specs); - } -}; - -class path : public std::filesystem::path { - public: - auto display_string() const -> std::string { - const std::filesystem::path& base = *this; - return fmt::format(FMT_STRING("{}"), base); - } - auto system_string() const -> std::string { return string(); } - - auto generic_display_string() const -> std::string { - const std::filesystem::path& base = *this; - return fmt::format(FMT_STRING("{:g}"), base); - } - auto generic_system_string() const -> std::string { return generic_string(); } -}; - -FMT_END_NAMESPACE -#endif // FMT_CPP_LIB_FILESYSTEM - -FMT_BEGIN_NAMESPACE -FMT_EXPORT -template -struct formatter, Char> : nested_formatter { - private: - // Functor because C++11 doesn't support generic lambdas. - struct writer { - const std::bitset& bs; - - template - FMT_CONSTEXPR auto operator()(OutputIt out) -> OutputIt { - for (auto pos = N; pos > 0; --pos) { - out = detail::write(out, bs[pos - 1] ? Char('1') : Char('0')); - } - - return out; - } - }; - - public: - template - auto format(const std::bitset& bs, FormatContext& ctx) const - -> decltype(ctx.out()) { - return write_padded(ctx, writer{bs}); - } -}; - -FMT_EXPORT -template -struct formatter : basic_ostream_formatter {}; -FMT_END_NAMESPACE - -#ifdef __cpp_lib_optional -FMT_BEGIN_NAMESPACE -FMT_EXPORT -template -struct formatter, Char, - std::enable_if_t::value>> { - private: - formatter underlying_; - static constexpr basic_string_view optional = - detail::string_literal{}; - static constexpr basic_string_view none = - detail::string_literal{}; - - template - FMT_CONSTEXPR static auto maybe_set_debug_format(U& u, bool set) - -> decltype(u.set_debug_format(set)) { - u.set_debug_format(set); - } - - template - FMT_CONSTEXPR static void maybe_set_debug_format(U&, ...) {} - - public: - template FMT_CONSTEXPR auto parse(ParseContext& ctx) { - maybe_set_debug_format(underlying_, true); - return underlying_.parse(ctx); - } - - template - auto format(const std::optional& opt, FormatContext& ctx) const - -> decltype(ctx.out()) { - if (!opt) return detail::write(ctx.out(), none); - - auto out = ctx.out(); - out = detail::write(out, optional); - ctx.advance_to(out); - out = underlying_.format(*opt, ctx); - return detail::write(out, ')'); - } -}; -FMT_END_NAMESPACE -#endif // __cpp_lib_optional - -#if defined(__cpp_lib_expected) || FMT_CPP_LIB_VARIANT - -FMT_BEGIN_NAMESPACE -namespace detail { - -template -auto write_escaped_alternative(OutputIt out, const T& v) -> OutputIt { - if constexpr (has_to_string_view::value) - return write_escaped_string(out, detail::to_string_view(v)); - if constexpr (std::is_same_v) return write_escaped_char(out, v); - return write(out, v); -} - -} // namespace detail - -FMT_END_NAMESPACE -#endif - -#ifdef __cpp_lib_expected -FMT_BEGIN_NAMESPACE - -FMT_EXPORT -template -struct formatter, Char, - std::enable_if_t::value && - is_formattable::value>> { - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - return ctx.begin(); - } - - template - auto format(const std::expected& value, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto out = ctx.out(); - - if (value.has_value()) { - out = detail::write(out, "expected("); - out = detail::write_escaped_alternative(out, *value); - } else { - out = detail::write(out, "unexpected("); - out = detail::write_escaped_alternative(out, value.error()); - } - *out++ = ')'; - return out; - } -}; -FMT_END_NAMESPACE -#endif // __cpp_lib_expected - -#ifdef __cpp_lib_source_location -FMT_BEGIN_NAMESPACE -FMT_EXPORT -template <> struct formatter { - template FMT_CONSTEXPR auto parse(ParseContext& ctx) { - return ctx.begin(); - } - - template - auto format(const std::source_location& loc, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto out = ctx.out(); - out = detail::write(out, loc.file_name()); - out = detail::write(out, ':'); - out = detail::write(out, loc.line()); - out = detail::write(out, ':'); - out = detail::write(out, loc.column()); - out = detail::write(out, ": "); - out = detail::write(out, loc.function_name()); - return out; - } -}; -FMT_END_NAMESPACE -#endif - -#if FMT_CPP_LIB_VARIANT -FMT_BEGIN_NAMESPACE -namespace detail { - -template -using variant_index_sequence = - std::make_index_sequence::value>; - -template struct is_variant_like_ : std::false_type {}; -template -struct is_variant_like_> : std::true_type {}; - -// formattable element check. -template class is_variant_formattable_ { - template - static std::conjunction< - is_formattable, C>...> - check(std::index_sequence); - - public: - static constexpr const bool value = - decltype(check(variant_index_sequence{}))::value; -}; - -} // namespace detail - -template struct is_variant_like { - static constexpr const bool value = detail::is_variant_like_::value; -}; - -template struct is_variant_formattable { - static constexpr const bool value = - detail::is_variant_formattable_::value; -}; - -FMT_EXPORT -template struct formatter { - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - return ctx.begin(); - } - - template - auto format(const std::monostate&, FormatContext& ctx) const - -> decltype(ctx.out()) { - return detail::write(ctx.out(), "monostate"); - } -}; - -FMT_EXPORT -template -struct formatter< - Variant, Char, - std::enable_if_t, is_variant_formattable>>> { - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - return ctx.begin(); - } - - template - auto format(const Variant& value, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto out = ctx.out(); - - out = detail::write(out, "variant("); - FMT_TRY { - std::visit( - [&](const auto& v) { - out = detail::write_escaped_alternative(out, v); - }, - value); - } - FMT_CATCH(const std::bad_variant_access&) { - detail::write(out, "valueless by exception"); - } - *out++ = ')'; - return out; - } -}; -FMT_END_NAMESPACE -#endif // FMT_CPP_LIB_VARIANT - -FMT_BEGIN_NAMESPACE -FMT_EXPORT -template struct formatter { - template - FMT_CONSTEXPR auto parse(ParseContext& ctx) -> decltype(ctx.begin()) { - return ctx.begin(); - } - - template - FMT_CONSTEXPR auto format(const std::error_code& ec, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto out = ctx.out(); - out = detail::write_bytes(out, ec.category().name(), format_specs()); - out = detail::write(out, Char(':')); - out = detail::write(out, ec.value()); - return out; - } -}; - -#if FMT_USE_RTTI -namespace detail { - -template -auto write_demangled_name(OutputIt out, const std::type_info& ti) -> OutputIt { -# ifdef FMT_HAS_ABI_CXA_DEMANGLE - int status = 0; - std::size_t size = 0; - std::unique_ptr demangled_name_ptr( - abi::__cxa_demangle(ti.name(), nullptr, &size, &status), &std::free); - - string_view demangled_name_view; - if (demangled_name_ptr) { - demangled_name_view = demangled_name_ptr.get(); - - // Normalization of stdlib inline namespace names. - // libc++ inline namespaces. - // std::__1::* -> std::* - // std::__1::__fs::* -> std::* - // libstdc++ inline namespaces. - // std::__cxx11::* -> std::* - // std::filesystem::__cxx11::* -> std::filesystem::* - if (demangled_name_view.starts_with("std::")) { - char* begin = demangled_name_ptr.get(); - char* to = begin + 5; // std:: - for (char *from = to, *end = begin + demangled_name_view.size(); - from < end;) { - // This is safe, because demangled_name is NUL-terminated. - if (from[0] == '_' && from[1] == '_') { - char* next = from + 1; - while (next < end && *next != ':') next++; - if (next[0] == ':' && next[1] == ':') { - from = next + 2; - continue; - } - } - *to++ = *from++; - } - demangled_name_view = {begin, detail::to_unsigned(to - begin)}; - } - } else { - demangled_name_view = string_view(ti.name()); - } - return detail::write_bytes(out, demangled_name_view); -# elif FMT_MSC_VERSION - const string_view demangled_name(ti.name()); - for (std::size_t i = 0; i < demangled_name.size(); ++i) { - auto sub = demangled_name; - sub.remove_prefix(i); - if (sub.starts_with("enum ")) { - i += 4; - continue; - } - if (sub.starts_with("class ") || sub.starts_with("union ")) { - i += 5; - continue; - } - if (sub.starts_with("struct ")) { - i += 6; - continue; - } - if (*sub.begin() != ' ') *out++ = *sub.begin(); - } - return out; -# else - return detail::write_bytes(out, string_view(ti.name())); -# endif -} - -} // namespace detail - -FMT_EXPORT -template -struct formatter { - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - return ctx.begin(); - } - - template - auto format(const std::type_info& ti, Context& ctx) const - -> decltype(ctx.out()) { - return detail::write_demangled_name(ctx.out(), ti); - } -}; -#endif - -FMT_EXPORT -template -struct formatter< - T, Char, // DEPRECATED! Mixing code unit types. - typename std::enable_if::value>::type> { - private: - bool with_typename_ = false; - - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - auto it = ctx.begin(); - auto end = ctx.end(); - if (it == end || *it == '}') return it; - if (*it == 't') { - ++it; - with_typename_ = FMT_USE_RTTI != 0; - } - return it; - } - - template - auto format(const std::exception& ex, Context& ctx) const - -> decltype(ctx.out()) { - auto out = ctx.out(); -#if FMT_USE_RTTI - if (with_typename_) { - out = detail::write_demangled_name(out, typeid(ex)); - *out++ = ':'; - *out++ = ' '; - } -#endif - return detail::write_bytes(out, string_view(ex.what())); - } -}; - -namespace detail { - -template -struct has_flip : std::false_type {}; - -template -struct has_flip().flip())>> - : std::true_type {}; - -template struct is_bit_reference_like { - static constexpr const bool value = - std::is_convertible::value && - std::is_nothrow_assignable::value && has_flip::value; -}; - -#ifdef _LIBCPP_VERSION - -// Workaround for libc++ incompatibility with C++ standard. -// According to the Standard, `bitset::operator[] const` returns bool. -template -struct is_bit_reference_like> { - static constexpr const bool value = true; -}; - -#endif - -} // namespace detail - -// We can't use std::vector::reference and -// std::bitset::reference because the compiler can't deduce Allocator and N -// in partial specialization. -FMT_EXPORT -template -struct formatter::value>> - : formatter { - template - FMT_CONSTEXPR auto format(const BitRef& v, FormatContext& ctx) const - -> decltype(ctx.out()) { - return formatter::format(v, ctx); - } -}; - -template -auto ptr(const std::unique_ptr& p) -> const void* { - return p.get(); -} -template auto ptr(const std::shared_ptr& p) -> const void* { - return p.get(); -} - -FMT_EXPORT -template -struct formatter, Char, - enable_if_t::value>> - : formatter { - template - auto format(const std::atomic& v, FormatContext& ctx) const - -> decltype(ctx.out()) { - return formatter::format(v.load(), ctx); - } -}; - -#ifdef __cpp_lib_atomic_flag_test -FMT_EXPORT -template -struct formatter : formatter { - template - auto format(const std::atomic_flag& v, FormatContext& ctx) const - -> decltype(ctx.out()) { - return formatter::format(v.test(), ctx); - } -}; -#endif // __cpp_lib_atomic_flag_test - -FMT_EXPORT -template struct formatter, Char> { - private: - detail::dynamic_format_specs specs_; - - template - FMT_CONSTEXPR auto do_format(const std::complex& c, - detail::dynamic_format_specs& specs, - FormatContext& ctx, OutputIt out) const - -> OutputIt { - if (c.real() != 0) { - *out++ = Char('('); - out = detail::write(out, c.real(), specs, ctx.locale()); - specs.sign = sign::plus; - out = detail::write(out, c.imag(), specs, ctx.locale()); - if (!detail::isfinite(c.imag())) *out++ = Char(' '); - *out++ = Char('i'); - *out++ = Char(')'); - return out; - } - out = detail::write(out, c.imag(), specs, ctx.locale()); - if (!detail::isfinite(c.imag())) *out++ = Char(' '); - *out++ = Char('i'); - return out; - } - - public: - FMT_CONSTEXPR auto parse(basic_format_parse_context& ctx) - -> decltype(ctx.begin()) { - if (ctx.begin() == ctx.end() || *ctx.begin() == '}') return ctx.begin(); - return parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, - detail::type_constant::value); - } - - template - auto format(const std::complex& c, FormatContext& ctx) const - -> decltype(ctx.out()) { - auto specs = specs_; - if (specs.width_ref.kind != detail::arg_id_kind::none || - specs.precision_ref.kind != detail::arg_id_kind::none) { - detail::handle_dynamic_spec(specs.width, - specs.width_ref, ctx); - detail::handle_dynamic_spec( - specs.precision, specs.precision_ref, ctx); - } - - if (specs.width == 0) return do_format(c, specs, ctx, ctx.out()); - auto buf = basic_memory_buffer(); - - auto outer_specs = format_specs(); - outer_specs.width = specs.width; - outer_specs.fill = specs.fill; - outer_specs.align = specs.align; - - specs.width = 0; - specs.fill = {}; - specs.align = align::none; - - do_format(c, specs, ctx, basic_appender(buf)); - return detail::write(ctx.out(), - basic_string_view(buf.data(), buf.size()), - outer_specs); - } -}; - -FMT_END_NAMESPACE -#endif // FMT_STD_H_ diff --git a/tt_metal/third_party/fmt/fmt/xchar.h b/tt_metal/third_party/fmt/fmt/xchar.h deleted file mode 100644 index b1f39ed2220..00000000000 --- a/tt_metal/third_party/fmt/fmt/xchar.h +++ /dev/null @@ -1,322 +0,0 @@ -// Formatting library for C++ - optional wchar_t and exotic character support -// -// Copyright (c) 2012 - present, Victor Zverovich -// All rights reserved. -// -// For the license information refer to format.h. - -#ifndef FMT_XCHAR_H_ -#define FMT_XCHAR_H_ - -#include "color.h" -#include "format.h" -#include "ranges.h" - -#ifndef FMT_MODULE -# include -# if !defined(FMT_STATIC_THOUSANDS_SEPARATOR) -# include -# endif -#endif - -FMT_BEGIN_NAMESPACE -namespace detail { - -template -using is_exotic_char = bool_constant::value>; - -template struct format_string_char {}; - -template -struct format_string_char< - S, void_t())))>> { - using type = char_t; -}; - -template -struct format_string_char::value>> { - using type = typename S::char_type; -}; - -template -using format_string_char_t = typename format_string_char::type; - -inline auto write_loc(basic_appender out, loc_value value, - const format_specs& specs, locale_ref loc) -> bool { -#ifndef FMT_STATIC_THOUSANDS_SEPARATOR - auto& numpunct = - std::use_facet>(loc.get()); - auto separator = std::wstring(); - auto grouping = numpunct.grouping(); - if (!grouping.empty()) separator = std::wstring(1, numpunct.thousands_sep()); - return value.visit(loc_writer{out, specs, separator, grouping, {}}); -#endif - return false; -} -} // namespace detail - -FMT_BEGIN_EXPORT - -using wstring_view = basic_string_view; -using wformat_parse_context = basic_format_parse_context; -using wformat_context = buffered_context; -using wformat_args = basic_format_args; -using wmemory_buffer = basic_memory_buffer; - -#if FMT_GCC_VERSION && FMT_GCC_VERSION < 409 -// Workaround broken conversion on older gcc. -template using wformat_string = wstring_view; -inline auto runtime(wstring_view s) -> wstring_view { return s; } -#else -template -using wformat_string = basic_format_string...>; -inline auto runtime(wstring_view s) -> runtime_format_string { - return {{s}}; -} -#endif - -template <> struct is_char : std::true_type {}; -template <> struct is_char : std::true_type {}; -template <> struct is_char : std::true_type {}; - -#ifdef __cpp_char8_t -template <> -struct is_char : bool_constant {}; -#endif - -template -constexpr auto make_wformat_args(T&... args) - -> decltype(fmt::make_format_args(args...)) { - return fmt::make_format_args(args...); -} - -inline namespace literals { -#if FMT_USE_USER_DEFINED_LITERALS && !FMT_USE_NONTYPE_TEMPLATE_ARGS -constexpr auto operator""_a(const wchar_t* s, size_t) - -> detail::udl_arg { - return {s}; -} -#endif -} // namespace literals - -template -auto join(It begin, Sentinel end, wstring_view sep) - -> join_view { - return {begin, end, sep}; -} - -template -auto join(Range&& range, wstring_view sep) - -> join_view, detail::sentinel_t, - wchar_t> { - return join(std::begin(range), std::end(range), sep); -} - -template -auto join(std::initializer_list list, wstring_view sep) - -> join_view { - return join(std::begin(list), std::end(list), sep); -} - -template -auto join(const std::tuple& tuple, basic_string_view sep) - -> tuple_join_view { - return {tuple, sep}; -} - -template ::value)> -auto vformat(basic_string_view format_str, - typename detail::vformat_args::type args) - -> std::basic_string { - auto buf = basic_memory_buffer(); - detail::vformat_to(buf, format_str, args); - return to_string(buf); -} - -template -auto format(wformat_string fmt, T&&... args) -> std::wstring { - return vformat(fmt::wstring_view(fmt), fmt::make_wformat_args(args...)); -} - -template -auto format_to(OutputIt out, wformat_string fmt, T&&... args) - -> OutputIt { - return vformat_to(out, fmt::wstring_view(fmt), - fmt::make_wformat_args(args...)); -} - -// Pass char_t as a default template parameter instead of using -// std::basic_string> to reduce the symbol size. -template , - FMT_ENABLE_IF(!std::is_same::value && - !std::is_same::value)> -auto format(const S& format_str, T&&... args) -> std::basic_string { - return vformat(detail::to_string_view(format_str), - fmt::make_format_args>(args...)); -} - -template , - FMT_ENABLE_IF(detail::is_locale::value&& - detail::is_exotic_char::value)> -inline auto vformat(const Locale& loc, const S& format_str, - typename detail::vformat_args::type args) - -> std::basic_string { - return detail::vformat(loc, detail::to_string_view(format_str), args); -} - -template , - FMT_ENABLE_IF(detail::is_locale::value&& - detail::is_exotic_char::value)> -inline auto format(const Locale& loc, const S& format_str, T&&... args) - -> std::basic_string { - return detail::vformat( - loc, detail::to_string_view(format_str), - fmt::make_format_args>(args...)); -} - -template , - FMT_ENABLE_IF(detail::is_output_iterator::value&& - detail::is_exotic_char::value)> -auto vformat_to(OutputIt out, const S& format_str, - typename detail::vformat_args::type args) -> OutputIt { - auto&& buf = detail::get_buffer(out); - detail::vformat_to(buf, detail::to_string_view(format_str), args); - return detail::get_iterator(buf, out); -} - -template , - FMT_ENABLE_IF(detail::is_output_iterator::value && - !std::is_same::value && - !std::is_same::value)> -inline auto format_to(OutputIt out, const S& fmt, T&&... args) -> OutputIt { - return vformat_to(out, detail::to_string_view(fmt), - fmt::make_format_args>(args...)); -} - -template , - FMT_ENABLE_IF(detail::is_output_iterator::value&& - detail::is_locale::value&& - detail::is_exotic_char::value)> -inline auto vformat_to(OutputIt out, const Locale& loc, const S& format_str, - typename detail::vformat_args::type args) - -> OutputIt { - auto&& buf = detail::get_buffer(out); - vformat_to(buf, detail::to_string_view(format_str), args, - detail::locale_ref(loc)); - return detail::get_iterator(buf, out); -} - -template , - bool enable = detail::is_output_iterator::value && - detail::is_locale::value && - detail::is_exotic_char::value> -inline auto format_to(OutputIt out, const Locale& loc, const S& format_str, - T&&... args) -> - typename std::enable_if::type { - return vformat_to(out, loc, detail::to_string_view(format_str), - fmt::make_format_args>(args...)); -} - -template ::value&& - detail::is_exotic_char::value)> -inline auto vformat_to_n(OutputIt out, size_t n, - basic_string_view format_str, - typename detail::vformat_args::type args) - -> format_to_n_result { - using traits = detail::fixed_buffer_traits; - auto buf = detail::iterator_buffer(out, n); - detail::vformat_to(buf, format_str, args); - return {buf.out(), buf.count()}; -} - -template , - FMT_ENABLE_IF(detail::is_output_iterator::value&& - detail::is_exotic_char::value)> -inline auto format_to_n(OutputIt out, size_t n, const S& fmt, T&&... args) - -> format_to_n_result { - return vformat_to_n(out, n, fmt::basic_string_view(fmt), - fmt::make_format_args>(args...)); -} - -template , - FMT_ENABLE_IF(detail::is_exotic_char::value)> -inline auto formatted_size(const S& fmt, T&&... args) -> size_t { - auto buf = detail::counting_buffer(); - detail::vformat_to(buf, detail::to_string_view(fmt), - fmt::make_format_args>(args...)); - return buf.count(); -} - -inline void vprint(std::FILE* f, wstring_view fmt, wformat_args args) { - auto buf = wmemory_buffer(); - detail::vformat_to(buf, fmt, args); - buf.push_back(L'\0'); - if (std::fputws(buf.data(), f) == -1) - FMT_THROW(system_error(errno, FMT_STRING("cannot write to file"))); -} - -inline void vprint(wstring_view fmt, wformat_args args) { - vprint(stdout, fmt, args); -} - -template -void print(std::FILE* f, wformat_string fmt, T&&... args) { - return vprint(f, wstring_view(fmt), fmt::make_wformat_args(args...)); -} - -template void print(wformat_string fmt, T&&... args) { - return vprint(wstring_view(fmt), fmt::make_wformat_args(args...)); -} - -template -void println(std::FILE* f, wformat_string fmt, T&&... args) { - return print(f, L"{}\n", fmt::format(fmt, std::forward(args)...)); -} - -template void println(wformat_string fmt, T&&... args) { - return print(L"{}\n", fmt::format(fmt, std::forward(args)...)); -} - -inline auto vformat(const text_style& ts, wstring_view fmt, wformat_args args) - -> std::wstring { - auto buf = wmemory_buffer(); - detail::vformat_to(buf, ts, fmt, args); - return fmt::to_string(buf); -} - -template -inline auto format(const text_style& ts, wformat_string fmt, T&&... args) - -> std::wstring { - return fmt::vformat(ts, fmt, fmt::make_wformat_args(args...)); -} - -template -FMT_DEPRECATED void print(std::FILE* f, const text_style& ts, - wformat_string fmt, const T&... args) { - vprint(f, ts, fmt, fmt::make_wformat_args(args...)); -} - -template -FMT_DEPRECATED void print(const text_style& ts, wformat_string fmt, - const T&... args) { - return print(stdout, ts, fmt, args...); -} - -/// Converts `value` to `std::wstring` using the default format for type `T`. -template inline auto to_wstring(const T& value) -> std::wstring { - return format(FMT_STRING(L"{}"), value); -} -FMT_END_EXPORT -FMT_END_NAMESPACE - -#endif // FMT_XCHAR_H_ diff --git a/ttnn/CMakeLists.txt b/ttnn/CMakeLists.txt index 2cf781ccfe2..7e4ed72251c 100644 --- a/ttnn/CMakeLists.txt +++ b/ttnn/CMakeLists.txt @@ -550,8 +550,6 @@ set(TTNN_PRECOMPILED_HEADERS ${PROJECT_SOURCE_DIR}/tt_metal/tt_stl/reflection.hpp ${PROJECT_SOURCE_DIR}/ttnn/cpp/ttnn/operation.hpp ${PROJECT_SOURCE_DIR}/tt_metal/third_party/tracy/public/tracy/Tracy.hpp - ${PROJECT_SOURCE_DIR}/tt_metal/third_party/fmt/fmt/core.h - ${PROJECT_SOURCE_DIR}/tt_metal/third_party/fmt/fmt/format.h ${PROJECT_SOURCE_DIR}/tt_metal/third_party/umd/device/device_api_metal.h ${PROJECT_SOURCE_DIR}/tt_metal/third_party/umd/device/tt_device.h From aaa08a5425474a557902ff7ca6be48abf630144c Mon Sep 17 00:00:00 2001 From: Anil Mahmud Date: Tue, 8 Oct 2024 12:32:31 -0400 Subject: [PATCH 40/58] #13547: Removed unused arguments from release_dst and acquire_dst --- METALIUM_GUIDE.md | 20 ++++----- .../apis/kernel_apis/compute/acquire_dst.rst | 2 +- .../apis/kernel_apis/compute/release_dst.rst | 2 +- tests/tt_eager/ops/kernel/eltwise_sfpu.cpp | 4 +- ...m_large_block_zm_fused_bias_activation.cpp | 8 ++-- ...m_large_block_zm_fused_bias_activation.cpp | 8 ++-- .../old/matmul/kernels/compute_local_l1.cpp | 4 +- .../tt_metal/test_kernels/compute/bcast_h.cpp | 4 +- .../test_kernels/compute/bcast_hw.cpp | 4 +- .../tt_metal/test_kernels/compute/bcast_w.cpp | 4 +- .../tt_metal/test_kernels/compute/bmm.cpp | 4 +- .../compute/bmm_large_block_zm.cpp | 4 +- ...m_large_block_zm_fused_bias_activation.cpp | 8 ++-- .../bmm_large_block_zm_mixed_precision.cpp | 4 +- .../compute/bmm_tilize_untilize.cpp | 12 ++--- .../test_kernels/compute/broadcast.cpp | 4 +- .../tt_metal/test_kernels/compute/cumsum.cpp | 4 +- .../test_kernels/compute/dropout_sfpu.cpp | 4 +- .../test_kernels/compute/eltwise_copy.cpp | 4 +- .../compute/eltwise_copy_block.cpp | 4 +- .../eltwise_copy_block_matmul_partials.cpp | 4 +- .../test_kernels/compute/eltwise_sfpi.cpp | 4 +- .../test_kernels/compute/layernorm.cpp | 4 +- .../tt_metal/test_kernels/compute/matmul.cpp | 4 +- .../test_kernels/compute/matmul_block.cpp | 4 +- .../compute/matmul_large_block.cpp | 8 ++-- .../matmul_large_block_generalized.cpp | 8 ++-- .../compute/matmul_large_block_zm.cpp | 4 +- .../test_kernels/compute/matmul_with_bias.cpp | 8 ++-- .../test_kernels/compute/max_pool.cpp | 4 +- .../compute/max_pool_multi_core.cpp | 4 +- .../test_kernels/compute/reconfig.cpp | 4 +- .../test_kernels/compute/reduce_h.cpp | 4 +- .../test_kernels/compute/reduce_hw.cpp | 4 +- .../test_kernels/compute/reduce_w.cpp | 4 +- .../tt_metal/test_kernels/compute/rmsnorm.cpp | 4 +- .../test_kernels/compute/rotary_embedding.cpp | 4 +- .../tt_metal/test_kernels/compute/softmax.cpp | 4 +- .../compute/transformer_attn_matmul.cpp | 4 +- .../test_kernels/compute/transpose_wh.cpp | 4 +- .../unit_tests/matmul/multi_block_compute.cpp | 4 +- .../unit_tests/matmul/multi_tile_compute.cpp | 4 +- .../unit_tests/matmul/single_tile_compute.cpp | 4 +- .../test_kernels/compute/unpack_tilizeA_B.cpp | 4 +- tt_metal/include/compute_kernel_api/reg_api.h | 16 ++----- tt_metal/kernels/compute/eltwise_sfpu.cpp | 4 +- .../kernels/compute/add_2_tiles.cpp | 4 +- .../contributed/vecadd/kernels/add.cpp | 4 +- .../matmul_common/kernels/compute/bmm.cpp | 4 +- .../kernels/compute/bmm_large_block_zm.cpp | 4 +- .../tt_dnn/kernels/compute/eltwise_copy.cpp | 4 +- .../tt_dnn/kernels/compute/moreh_common.hpp | 4 +- .../tt_dnn/kernels/compute/transpose_wh.cpp | 4 +- .../single_core/kernels/moreh_dot.cpp | 4 +- .../kernels/moreh_dot_backward.cpp | 4 +- .../kernels/compute_depthwise_conv1d.cpp | 4 +- .../bcast/device/kernels/compute/bcast_h.cpp | 4 +- .../compute/bcast_h_sharded_optimised.cpp | 4 +- .../bcast/device/kernels/compute/bcast_hw.cpp | 4 +- .../bcast/device/kernels/compute/bcast_w.cpp | 4 +- .../device/kernels/compute/eltwise_copy.cpp | 4 +- .../device/kernels/compute/transpose_wh.cpp | 4 +- .../kernels/compute/rotary_embedding.cpp | 4 +- .../compute/rotary_embedding_llama.cpp | 4 +- .../kernels/compute/transpose_wh_sharded.cpp | 4 +- .../matmul/device/kernels/compute/bmm.cpp | 4 +- .../kernels/compute/bmm_large_block_zm.cpp | 4 +- .../device/kernels/moreh_cumsum_nc.cpp | 4 +- .../moreh_dot/device/kernels/moreh_dot.cpp | 4 +- .../device/kernels/moreh_dot_backward.cpp | 4 +- .../device/kernels/compute/layernorm.cpp | 4 +- .../compute/layernorm_post_allgather.cpp | 4 +- .../compute/layernorm_pre_allgather.cpp | 4 +- .../device/kernels/compute/softmax.cpp | 4 +- .../kernels/compute/softmax_sharded.cpp | 4 +- .../device/kernels/compute/reduce_h.cpp | 4 +- .../device/kernels/compute/reduce_hw.cpp | 4 +- .../device/kernels/compute/reduce_w.cpp | 4 +- .../moe/device/kernels/compute/moe.cpp | 40 ++++++++--------- .../topk/device/kernels/compute/topk.cpp | 16 +++---- .../device/kernels/compute/topk_final.cpp | 20 ++++----- .../device/kernels/compute/topk_local.cpp | 16 +++---- .../sdpa/device/kernels/compute/sdpa.cpp | 36 +++++++-------- .../device/kernels/compute/sdpa_noncausal.cpp | 36 +++++++-------- .../kernels/compute/sdpa_flash_decode.cpp | 44 +++++++++---------- 85 files changed, 284 insertions(+), 292 deletions(-) diff --git a/METALIUM_GUIDE.md b/METALIUM_GUIDE.md index 55ae85c85c3..a68f5ba129e 100644 --- a/METALIUM_GUIDE.md +++ b/METALIUM_GUIDE.md @@ -126,7 +126,7 @@ kernel: namespace NAMESPACE { void MAIN { mm_init(); - acquire_dst(tt::DstMode::Tile); + acquire_dst(); cb_wait_front(tt::CB::c_in0, /* number of tiles */ 1); cb_wait_front(tt::CB::c_in1, /* number of tiles */ 1); @@ -140,7 +140,7 @@ void MAIN { pack_tile(0, tt::CB::c_out0); cb_push_back(tt::CB::c_out0, /* number of tiles */ 1); - release_dst(tt::DstMode::Tile); + release_dst(); } } // namespace NAMESPACE ``` @@ -149,7 +149,7 @@ It takes two matrix tiles from `tt::CB::c_in0` and `tt::CB::c_in0` L1 and conducts a single-tile matrix multiplication. Finally, it packs the result to `tt::CB::c_out0`. -Note that tile registers are acquired by `acquire_dst(..)`, but actually we can +Note that tile registers are acquired by `acquire_dst()`, but actually we can use `tile_regs_..()` functions for the more fine-grained tile register lock mechanism. At the end of this section, we will explain more details. @@ -226,10 +226,10 @@ inline __attribute__((always_inline)) void cb_wait_front(uint32_t cbid, uint32_t } ``` -Another interesting function is `acquire_dst(tt::DstMode mode)`: +Another interesting function is `acquire_dst()`: * The UNPACK kernel has an empty one: ``` -inline __attribute__((always_inline)) void acquire_dst(tt::DstMode mode) { +inline __attribute__((always_inline)) void acquire_dst() { ; ; @@ -237,7 +237,7 @@ inline __attribute__((always_inline)) void acquire_dst(tt::DstMode mode) { ``` * The MATH kernel waits for DEST available: ``` -inline __attribute__((always_inline)) void acquire_dst(tt::DstMode mode) { +inline __attribute__((always_inline)) void acquire_dst() { ( llk_math_wait_for_dest_available() ); ; @@ -245,7 +245,7 @@ inline __attribute__((always_inline)) void acquire_dst(tt::DstMode mode) { ``` * The UNPACK kernel waits for the end of MATH kernel: ``` -inline __attribute__((always_inline)) void acquire_dst(tt::DstMode mode) { +inline __attribute__((always_inline)) void acquire_dst() { ; ( llk_packer_wait_for_math_done() ); @@ -254,14 +254,14 @@ inline __attribute__((always_inline)) void acquire_dst(tt::DstMode mode) { [Its implementation](https://github.com/tenstorrent/tt-metal/blob/6d4951a20ca4c392888f924f038ae0780a8cc656/tt_metal/include/compute_kernel_api/reg_api.h#L28-L32) matches the preprocessed code: ``` -ALWI void acquire_dst(tt::DstMode mode) { +ALWI void acquire_dst() { MATH(( llk_math_wait_for_dest_available() )); PACK(( llk_packer_wait_for_math_done() )); } ``` -Based on the implementation of `acquire_dst(..)`, if we use it, we can guess it +Based on the implementation of `acquire_dst()`, if we use it, we can guess it executes UNPACK, MATH, PACK in order, which will help you to follow the execution order and instructions that actually run on each kernel. @@ -292,7 +292,7 @@ ALWI void tile_regs_release() { } ``` -We can replace `acquire_dst(..)` and `release_dst(..)` from the above example +We can replace `acquire_dst()` from the above example with `tile_regs_..()` functions like: ``` namespace NAMESPACE { diff --git a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/acquire_dst.rst b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/acquire_dst.rst index c6b725cf683..10c52fc84d1 100644 --- a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/acquire_dst.rst +++ b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/acquire_dst.rst @@ -1,4 +1,4 @@ acquire_dst =========== -.. doxygenfunction:: acquire_dst(tt::DstMode mode) +.. doxygenfunction:: acquire_dst() diff --git a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/release_dst.rst b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/release_dst.rst index 481b642fac9..804b60edf61 100644 --- a/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/release_dst.rst +++ b/docs/source/tt-metalium/tt_metal/apis/kernel_apis/compute/release_dst.rst @@ -1,4 +1,4 @@ release_dst =========== -.. doxygenfunction:: release_dst(tt::DstMode mode) +.. doxygenfunction:: release_dst() diff --git a/tests/tt_eager/ops/kernel/eltwise_sfpu.cpp b/tests/tt_eager/ops/kernel/eltwise_sfpu.cpp index 6ec7b2de5b0..d65c4bdf818 100644 --- a/tests/tt_eager/ops/kernel/eltwise_sfpu.cpp +++ b/tests/tt_eager/ops/kernel/eltwise_sfpu.cpp @@ -20,7 +20,7 @@ void MAIN { uint32_t block_index = 0; cb_reserve_back(tt::CB::c_out0, per_core_block_dim); uint32_t tile_index = 0; - acquire_dst(tt::DstMode::Half); + acquire_dst(); // Pop tile after tile, copy to DST and pack cb_wait_front(tt::CB::c_in0, 1); @@ -36,7 +36,7 @@ void MAIN { cb_pop_front(tt::CB::c_in0, 1); - release_dst(tt::DstMode::Half); + release_dst(); cb_push_back(tt::CB::c_out0, per_core_block_dim); diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation.cpp index 731a9d21fae..79ef28eca0f 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/1_compute_mm/kernels/bmm_large_block_zm_fused_bias_activation.cpp @@ -58,7 +58,7 @@ void MAIN { int in1_index_subblock_offset = 0; for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); if (enable_reload) { // Reconfigure input @@ -99,7 +99,7 @@ void MAIN { pack_tile(i, mm_bias_intermediate_cb_id); } cb_push_back(mm_bias_intermediate_cb_id, out_subblock_num_tiles); - release_dst(tt::DstMode::Half); + release_dst(); // Redundant wait since we know data was just pushed cb_wait_front(mm_bias_intermediate_cb_id, out_subblock_num_tiles); @@ -109,7 +109,7 @@ void MAIN { unpack_reconfig_data_format(mm_bias_intermediate_cb_id, bias_cb_id); // reconfigure packer df for out pack_reconfig_data_format(out_cb_id); - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t i = 0, j = 0; j < out_subblock_h; j++) { uint32_t bcast_tile_idx = in1_index_subblock_offset; for (uint32_t k = 0; k < out_subblock_w; k++, i++) { @@ -158,7 +158,7 @@ void MAIN { cb_push_back(mm_partials_cb_id, out_subblock_num_tiles); } - release_dst(tt::DstMode::Half); + release_dst(); in1_index_subblock_offset += out_subblock_w; } in0_index_subblock_offset += in0_subblock_num_tiles; diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/old/matmul/kernels/bmm_large_block_zm_fused_bias_activation.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/old/matmul/kernels/bmm_large_block_zm_fused_bias_activation.cpp index 47657c7059d..6722056c2bf 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/old/matmul/kernels/bmm_large_block_zm_fused_bias_activation.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/old/matmul/kernels/bmm_large_block_zm_fused_bias_activation.cpp @@ -58,7 +58,7 @@ void MAIN { int in1_index_subblock_offset = 0; for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); if (enable_reload) { // Reconfigure input @@ -98,7 +98,7 @@ void MAIN { pack_tile(i, mm_bias_intermediate_cb_id); } cb_push_back(mm_bias_intermediate_cb_id, out_subblock_num_tiles); - release_dst(tt::DstMode::Half); + release_dst(); // Redundant wait since we know data was just pushed cb_wait_front(mm_bias_intermediate_cb_id, out_subblock_num_tiles); @@ -108,7 +108,7 @@ void MAIN { unpack_reconfig_data_format(mm_bias_intermediate_cb_id, bias_cb_id); // reconfigure packer df for out pack_reconfig_data_format(out_cb_id); - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t i = 0, j = 0; j < out_subblock_h; j++) { uint32_t bcast_tile_idx = in1_index_subblock_offset; for (uint32_t k = 0; k < out_subblock_w; k++, i++) { @@ -150,7 +150,7 @@ void MAIN { cb_push_back(mm_partials_cb_id, out_subblock_num_tiles); } - release_dst(tt::DstMode::Half); + release_dst(); in1_index_subblock_offset += out_subblock_w; } in0_index_subblock_offset += in0_subblock_num_tiles; diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/old/matmul/kernels/compute_local_l1.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/old/matmul/kernels/compute_local_l1.cpp index 56bb8f17a4f..60d3136267d 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/old/matmul/kernels/compute_local_l1.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/old/matmul/kernels/compute_local_l1.cpp @@ -16,14 +16,14 @@ void MAIN { for (uint32_t mt = 0; mt < sub_Mt; ++mt) { for (uint32_t nt = 0; nt < sub_Nt; ++nt) { - acquire_dst(tt::DstMode::Full); + acquire_dst(); for (uint32_t kt = 0; kt < Kt; ++kt) { matmul_tiles(tt::CB::c_in0, tt::CB::c_in1, mt * Kt + kt, nt * Kt + kt, 0, false); } cb_reserve_back(tt::CB::c_out0, onetile); pack_tile(0, tt::CB::c_out0); cb_push_back(tt::CB::c_out0, onetile); - release_dst(tt::DstMode::Full); + release_dst(); } } } diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/bcast_h.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/bcast_h.cpp index 747765489ac..1220f3e935d 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/bcast_h.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/bcast_h.cpp @@ -24,7 +24,7 @@ void MAIN { cb_reserve_back(tt::CB::c_out0, onetile); - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_wait_front(tt::CB::c_in0, onetile); @@ -33,7 +33,7 @@ void MAIN { cb_pop_front(tt::CB::c_in0, onetile); - release_dst(tt::DstMode::Half); + release_dst(); cb_push_back(tt::CB::c_out0, onetile); cb_pop_front(tt::CB::c_in1, onetile); diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/bcast_hw.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/bcast_hw.cpp index 230ee8b9c36..499afa82fad 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/bcast_hw.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/bcast_hw.cpp @@ -26,7 +26,7 @@ void MAIN { #endif cb_reserve_back(tt::CB::c_out0, onetile); - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_wait_front(tt::CB::c_in0, onetile); @@ -37,7 +37,7 @@ void MAIN { #ifndef BCAST_SCALAR cb_pop_front(tt::CB::c_in1, onetile); #endif - release_dst(tt::DstMode::Half); + release_dst(); cb_push_back(tt::CB::c_out0, onetile); } } } diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/bcast_w.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/bcast_w.cpp index 0de0e2f82c0..ec6f71c0023 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/bcast_w.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/bcast_w.cpp @@ -23,14 +23,14 @@ void MAIN { cb_reserve_back(tt::CB::c_out0, onetile); - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_wait_front(tt::CB::c_in0, onetile); BCAST_OP(tt::CB::c_in0, tt::CB::c_in1, 0, 0, 0); pack_tile(0, tt::CB::c_out0); cb_pop_front(tt::CB::c_in0, onetile); - release_dst(tt::DstMode::Half); + release_dst(); cb_push_back(tt::CB::c_out0, onetile); diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/bmm.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/bmm.cpp index 6e42eb29d49..d62a8e06e98 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/bmm.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/bmm.cpp @@ -31,7 +31,7 @@ void MAIN { for (uint32_t mt_C = 0; mt_C < Mt; ++mt_C) // output tile of C for (uint32_t nt_C = 0; nt_C < Nt; ++nt_C) // output tile index of C { - acquire_dst(tt::DstMode::Full); + acquire_dst(); for (uint32_t kt = 0; kt < Kt; kt++) { cb_wait_front(tt::CB::c_in0, onetile); cb_wait_front(tt::CB::c_in1, onetile); @@ -46,7 +46,7 @@ void MAIN { pack_tile(0, tt::CB::c_out0); cb_push_back(tt::CB::c_out0, onetile); - release_dst(tt::DstMode::Full); + release_dst(); } diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm.cpp index ec293c8c7bb..2ab808f2f32 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm.cpp @@ -41,7 +41,7 @@ void MAIN { int in1_index_subblock_offset = 0; for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); if (enable_reload) { copy_tile_to_dst_init_short(); @@ -91,7 +91,7 @@ void MAIN { cb_push_back(tt::CB::c_intermed0, out_subblock_num_tiles); } - release_dst(tt::DstMode::Half); + release_dst(); in1_index_subblock_offset += out_subblock_w; } in0_index_subblock_offset += in0_subblock_num_tiles; diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp index 47657c7059d..6722056c2bf 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp @@ -58,7 +58,7 @@ void MAIN { int in1_index_subblock_offset = 0; for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); if (enable_reload) { // Reconfigure input @@ -98,7 +98,7 @@ void MAIN { pack_tile(i, mm_bias_intermediate_cb_id); } cb_push_back(mm_bias_intermediate_cb_id, out_subblock_num_tiles); - release_dst(tt::DstMode::Half); + release_dst(); // Redundant wait since we know data was just pushed cb_wait_front(mm_bias_intermediate_cb_id, out_subblock_num_tiles); @@ -108,7 +108,7 @@ void MAIN { unpack_reconfig_data_format(mm_bias_intermediate_cb_id, bias_cb_id); // reconfigure packer df for out pack_reconfig_data_format(out_cb_id); - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t i = 0, j = 0; j < out_subblock_h; j++) { uint32_t bcast_tile_idx = in1_index_subblock_offset; for (uint32_t k = 0; k < out_subblock_w; k++, i++) { @@ -150,7 +150,7 @@ void MAIN { cb_push_back(mm_partials_cb_id, out_subblock_num_tiles); } - release_dst(tt::DstMode::Half); + release_dst(); in1_index_subblock_offset += out_subblock_w; } in0_index_subblock_offset += in0_subblock_num_tiles; diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm_mixed_precision.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm_mixed_precision.cpp index d2b042f7238..632fc69018f 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm_mixed_precision.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/bmm_large_block_zm_mixed_precision.cpp @@ -46,7 +46,7 @@ void MAIN { int in1_index_subblock_offset = 0; for (uint32_t in1_subblock = 0; in1_subblock < in1_num_subblocks; in1_subblock++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); if (enable_reload) { // Reconfigure input @@ -98,7 +98,7 @@ void MAIN { cb_push_back(mm_partials_cb_id, out_subblock_num_tiles); } - release_dst(tt::DstMode::Half); + release_dst(); in1_index_subblock_offset += out_subblock_w; } in0_index_subblock_offset += in0_subblock_num_tiles; diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/bmm_tilize_untilize.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/bmm_tilize_untilize.cpp index b12c847b1df..3c972131d6b 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/bmm_tilize_untilize.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/bmm_tilize_untilize.cpp @@ -71,10 +71,10 @@ inline void tilize_in( for (uint32_t n = 0; n < num_out_subblocks_in_col; n++) { for (uint32_t w = 0; w < out_subblock_w; w++) { uint32_t tile_index = block_offset + within_block_index + w; - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(interm_cb_id, tile_index, 0); pack_tile(0, reblock_cb_id); - release_dst(tt::DstMode::Half); + release_dst(); } block_offset += out_subblock_num_tiles; } @@ -165,7 +165,7 @@ void MAIN { for (uint32_t in0_subblock_i = 0; in0_subblock_i < in0_num_subblocks; ++in0_subblock_i) { int in1_index_subblock_offset = 0; for (uint32_t in1_subblock_i = 0; in1_subblock_i < in1_num_subblocks; ++in1_subblock_i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); if (enable_reload) { // Reconfigure input copy_tile_to_dst_init_short_with_dt(in1_cb_id, matmul_partials_cb); @@ -201,7 +201,7 @@ void MAIN { if (last_out) { // first move the current result from dst to interim CB pack_matmul_subblock(out_for_bias_cb_id, out_subblock_num_tiles); - release_dst(tt::DstMode::Half); + release_dst(); // reconfig unpacker df for src B // unpack_reconfig_data_format(out_for_bias_cb_id, bias_cb_id); // bcast add data from bias_cb_id @@ -210,7 +210,7 @@ void MAIN { add_bcast_rows_init_short(); // reconfig packer df for out // pack_reconfig_data_format(out_cb_id); - acquire_dst(tt::DstMode::Half); + acquire_dst(); uint32_t i = 0; for (uint32_t h = 0; h < out_subblock_h; ++ h) { uint32_t bcast_tile_i = bias_block_offset + in1_index_subblock_offset; @@ -244,7 +244,7 @@ void MAIN { : out_cb_id) : matmul_partials_cb; pack_matmul_subblock(curr_matmul_out_cb, out_subblock_num_tiles); - release_dst(tt::DstMode::Half); + release_dst(); in1_index_subblock_offset += out_subblock_w; } // for in1_num_subblocks #ifndef FUSE_BIAS diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/broadcast.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/broadcast.cpp index cf60e0652a1..267be6ebc2e 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/broadcast.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/broadcast.cpp @@ -19,7 +19,7 @@ void MAIN { cb_wait_front(tt::CB::c_in1, onetile); cb_reserve_back(tt::CB::c_out0, onetile); - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_wait_front(tt::CB::c_in0, onetile); #ifndef BCAST_SPECIFIC @@ -30,7 +30,7 @@ void MAIN { pack_tile(0, tt::CB::c_out0); cb_pop_front(tt::CB::c_in0, onetile); - release_dst(tt::DstMode::Half); + release_dst(); cb_push_back(tt::CB::c_out0, onetile); cb_pop_front(tt::CB::c_in1, onetile); } diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/cumsum.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/cumsum.cpp index 5cd4b60b14f..c72464280c5 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/cumsum.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/cumsum.cpp @@ -30,7 +30,7 @@ void MAIN { for(uint32_t wt = 0; wt < Wt; ++wt) { for(uint32_t ht = 0; ht < Ht; ++ht) { cb_reserve_back(tt::CB::c_out0, onetile); - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_wait_front(tt::CB::c_in0, onetile); #ifndef ROWWISE @@ -48,7 +48,7 @@ void MAIN { pack_tile(0, tt::CB::c_out0); cb_pop_front(tt::CB::c_in0, onetile); - release_dst(tt::DstMode::Half); + release_dst(); cb_push_back(tt::CB::c_out0, onetile); } } diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/dropout_sfpu.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/dropout_sfpu.cpp index a55ce9ba155..5f43fc0b346 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/dropout_sfpu.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/dropout_sfpu.cpp @@ -21,7 +21,7 @@ void MAIN { for (uint32_t block_index = 0; block_index < per_core_block_cnt; block_index++) { cb_reserve_back(tt::CB::c_out0, per_core_block_dim); for(uint32_t tile_index = 0; tile_index < per_core_block_dim; ++tile_index) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); // Pop tile after tile, copy to DST and pack cb_wait_front(tt::CB::c_in0, 1); @@ -34,7 +34,7 @@ void MAIN { cb_pop_front(tt::CB::c_in0, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_push_back(tt::CB::c_out0, per_core_block_dim); } diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/eltwise_copy.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/eltwise_copy.cpp index 1e7c029d9a3..41e494d29b8 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/eltwise_copy.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/eltwise_copy.cpp @@ -15,7 +15,7 @@ void MAIN { unary_op_init_common(tt::CB::c_in0); for(uint32_t b=0;b 0) { copy_tile_to_dst_init_short(); cb_wait_front(partials_cb, out_block_num_tiles); @@ -68,7 +68,7 @@ void MAIN { cb_push_back(partials_cb, out_block_num_tiles); } } - release_dst(tt::DstMode::Half); + release_dst(); } } } // namespace NAMESPACE diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/unit_tests/matmul/multi_tile_compute.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/unit_tests/matmul/multi_tile_compute.cpp index 6190640c8bd..8c50ca5a30e 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/unit_tests/matmul/multi_tile_compute.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/unit_tests/matmul/multi_tile_compute.cpp @@ -23,7 +23,7 @@ void MAIN { // we are looking at block // out = in0[r x k]*in1[k x c] mm_init(); - acquire_dst(tt::DstMode::Half); + acquire_dst(); uint32_t out_tile_index = 0; uint32_t in0_index_r_offset = 0; @@ -50,6 +50,6 @@ void MAIN { pack_tile(tile_index, out_cb); } cb_push_back(out_cb, out_num_tiles); - release_dst(tt::DstMode::Half); + release_dst(); } } // namespace NAMESPACE diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/unit_tests/matmul/single_tile_compute.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/unit_tests/matmul/single_tile_compute.cpp index 8792a0af75e..cb8eb194d98 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/unit_tests/matmul/single_tile_compute.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/unit_tests/matmul/single_tile_compute.cpp @@ -21,14 +21,14 @@ void MAIN { const bool transpose = false; mm_init(); cb_reserve_back(out_cb, num_out_tiles); - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_wait_front(in0_cb, num_in0_tiles); cb_wait_front(in1_cb, num_in1_tiles); matmul_tiles(in0_cb, in1_cb, in0_tile_index, in1_tile_index, out_tile_index, transpose); pack_tile(0, out_cb); cb_pop_front(in0_cb, num_in0_tiles); cb_pop_front(in1_cb, num_in1_tiles); - release_dst(tt::DstMode::Half); + release_dst(); cb_push_back(out_cb, num_out_tiles); } } // namespace NAMESPACE diff --git a/tests/tt_metal/tt_metal/test_kernels/compute/unpack_tilizeA_B.cpp b/tests/tt_metal/tt_metal/test_kernels/compute/unpack_tilizeA_B.cpp index 6a49900ac5c..dc56a879feb 100644 --- a/tests/tt_metal/tt_metal/test_kernels/compute/unpack_tilizeA_B.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/compute/unpack_tilizeA_B.cpp @@ -41,11 +41,11 @@ void MAIN { unpack_tilizeA_B_block(tt::CB::c_in0, tt::CB::c_in1, per_core_block_tile_cnt, b); for(uint i=0; i(tt::CB::c_in0, tt::CB::c_in1, current_index, 0, htr); pack_tile(htr, tt::CB::c_out0, current_index); } - release_dst(tt::DstMode::Half); + release_dst(); } cb_pop_front(tt::CB::c_in1, onetile); } diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/compute/bcast_hw.cpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/compute/bcast_hw.cpp index 230ee8b9c36..499afa82fad 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/compute/bcast_hw.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/compute/bcast_hw.cpp @@ -26,7 +26,7 @@ void MAIN { #endif cb_reserve_back(tt::CB::c_out0, onetile); - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_wait_front(tt::CB::c_in0, onetile); @@ -37,7 +37,7 @@ void MAIN { #ifndef BCAST_SCALAR cb_pop_front(tt::CB::c_in1, onetile); #endif - release_dst(tt::DstMode::Half); + release_dst(); cb_push_back(tt::CB::c_out0, onetile); } } } diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/compute/bcast_w.cpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/compute/bcast_w.cpp index 0de0e2f82c0..ec6f71c0023 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/compute/bcast_w.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/compute/bcast_w.cpp @@ -23,14 +23,14 @@ void MAIN { cb_reserve_back(tt::CB::c_out0, onetile); - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_wait_front(tt::CB::c_in0, onetile); BCAST_OP(tt::CB::c_in0, tt::CB::c_in1, 0, 0, 0); pack_tile(0, tt::CB::c_out0); cb_pop_front(tt::CB::c_in0, onetile); - release_dst(tt::DstMode::Half); + release_dst(); cb_push_back(tt::CB::c_out0, onetile); diff --git a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/compute/eltwise_copy.cpp b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/compute/eltwise_copy.cpp index 118723f4640..a20b181877e 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/compute/eltwise_copy.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/sharded/device/kernels/compute/eltwise_copy.cpp @@ -15,7 +15,7 @@ void MAIN { unary_op_init_common(tt::CB::c_in0); for(uint32_t b=0;b ALWI void calc_numeric_stable(uint32_t cb_in, uint32_t cb_bcast_scaler, uint32_t cb_max, uint32_t cb_out) { diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp index 0bd45dcdbdb..21e7a4f704e 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_h.cpp @@ -24,7 +24,7 @@ void MAIN { // tiles are expected to be coming in in NCWH order (H-contiguous) // reducing in W means out[0][w] = sum(h=0..H-1, in[h][w]) // in this case we just sequentially add to accumulator all the H-tiles in a column - acquire_dst(tt::DstMode::Half); + acquire_dst(); for(uint32_t ht = 0; ht < Ht; ++ht) { cb_wait_front(tt::CB::c_in0, onetile); // REDUCE_OP is expected to come from add_define @@ -35,7 +35,7 @@ void MAIN { cb_reserve_back(tt::CB::c_out0, onetile); pack_tile(reduce_dst_idx, tt::CB::c_out0); cb_push_back(tt::CB::c_out0, onetile); - release_dst(tt::DstMode::Half); + release_dst(); } } } diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_hw.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_hw.cpp index e6989f6b357..e493d76ab02 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_hw.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_hw.cpp @@ -19,7 +19,7 @@ void MAIN { for (uint32_t nc = 0; nc < NC; nc++) { constexpr int onetile = 1; int reduce_dst_idx = 0; - acquire_dst(tt::DstMode::Half); + acquire_dst(); for(uint32_t ht = 0; ht < Ht; ++ht) { // tiles are expected to be coming in in NCHW order (W-contiguous) // reducing in W means out[h][0] = sum(w=0..W-1, in[h][w]) @@ -34,7 +34,7 @@ void MAIN { cb_reserve_back(tt::CB::c_out0, onetile); pack_tile(reduce_dst_idx, tt::CB::c_out0); cb_push_back(tt::CB::c_out0, onetile); - release_dst(tt::DstMode::Half); + release_dst(); } } } diff --git a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_w.cpp b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_w.cpp index dde4ae13bd6..95f7f9ce1dc 100644 --- a/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_w.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/generic/device/kernels/compute/reduce_w.cpp @@ -32,7 +32,7 @@ void MAIN { // tiles are expected to be coming in in NCHW order (W-contiguous) // reducing in W means out[h][0] = sum(w=0..W-1, in[h][w]) // in this case we just sequentially add to accumulator all the W-tiles in a row - acquire_dst(tt::DstMode::Half); + acquire_dst(); for(uint32_t wt = 0; wt < Wt; ++wt) { cb_wait_front(tt::CB::c_in0, onetile); // REDUCE_OP is expected to come from add_define @@ -47,7 +47,7 @@ void MAIN { cb_reserve_back(tt::CB::c_out0, onetile); pack_tile(reduce_dst_idx, tt::CB::c_out0); cb_push_back(tt::CB::c_out0, onetile); - release_dst(tt::DstMode::Half); + release_dst(); } } } diff --git a/ttnn/cpp/ttnn/operations/reduction/moe/device/kernels/compute/moe.cpp b/ttnn/cpp/ttnn/operations/reduction/moe/device/kernels/compute/moe.cpp index d86140851aa..2df4663dfd1 100644 --- a/ttnn/cpp/ttnn/operations/reduction/moe/device/kernels/compute/moe.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/moe/device/kernels/compute/moe.cpp @@ -75,14 +75,14 @@ void add_block_bcast_rows_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t row cb_wait_front(in1_cb, cols); for (uint32_t i = 0; i < rows; ++i) { for (uint32_t j = 0; j < cols; ++j) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); add_tiles_bcast_rows(in0_cb, in1_cb, 0, j, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_reconfig_data_format(in0_cb); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } cb_pop_front(in1_cb, cols); @@ -96,14 +96,14 @@ void mul_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); mul_tiles(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_reconfig_data_format(in0_cb); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } void mul_block_bcast_cols_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t rows, uint32_t cols) { @@ -118,13 +118,13 @@ void mul_block_bcast_cols_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t row cb_wait_front(in1_cb, rows); for (uint32_t i = 0; i < rows; ++i) { for (uint32_t j = 0; j < cols; ++j) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); mul_tiles_bcast_cols(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } cb_pop_front(in1_cb, rows); @@ -138,14 +138,14 @@ void eqz_block_inplace(uint32_t in0_cb, uint32_t num_tiles) { eqz_tile_init(); cb_wait_front(in0_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); eqz_tile(0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_reconfig_data_format(in0_cb); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -157,14 +157,14 @@ void recip_block_inplace(uint32_t in_cb, uint32_t num_tiles) { cb_wait_front(in_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in_cb, 0, 0); cb_pop_front(in_cb, 1); recip_tile(0); cb_reserve_back(in_cb, 1); pack_tile(0, in_cb); cb_push_back(in_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -187,7 +187,7 @@ void reduce_c() { constexpr uint32_t reduce_dst_idx = 0; for (uint32_t i = 0; i < rows; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t j = 0; j < cols; j++) { reduce_tile(in0_cb, scale_cb, i*cols+j, 0, reduce_dst_idx); } @@ -196,7 +196,7 @@ void reduce_c() { pack_reconfig_data_format(out_cb); pack_tile(reduce_dst_idx, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } reduce_revert_delta(out_cb); @@ -223,7 +223,7 @@ void top_k() { // streaming in input and index tiles to transpose and bitonic local sort them, two tiles at a time for (uint32_t wt = 0; wt < Wt; wt+=2) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); // local sort into k groups cb_wait_front(input_cb_index, 2); cb_wait_front(index_cb_index, 2); @@ -253,7 +253,7 @@ void top_k() { cb_pop_front(input_cb_index, 2); cb_pop_front(index_cb_index, 2); - release_dst(tt::DstMode::Half); + release_dst(); } cb_push_back(input_transposed_cb_index, Wt); @@ -271,7 +271,7 @@ void top_k() { for (uint32_t left_ind = 0; left_ind < Wt - (1 << m_iter); left_ind += 2 << m_iter) { uint32_t right_ind = left_ind + (1 << m_iter); - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile_to_dst_init_short_with_dt(index_transposed_cb_index, input_transposed_cb_index); copy_tile(input_transposed_cb_index, left_ind, input_dest_start); @@ -295,7 +295,7 @@ void top_k() { // pack index tiles in-place in the single-buffered cb_intermed1, we only need the upper 32 values for topk, which was in index_dest_start pack_reconfig_data_format(index_transposed_cb_index); pack_tile(index_dest_start, index_transposed_cb_index, left_ind); - release_dst(tt::DstMode::Half); + release_dst(); a = !a; } @@ -317,12 +317,12 @@ void top_k() { pack_reconfig_data_format(input_transposed_cb_index); cb_wait_front(input_transposed_cb_index, Kt); for (uint32_t i = 0; i < Kt; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_reserve_back(values_cb_index, 1); transpose_wh_tile(input_transposed_cb_index, i, 0); pack_tile(0, values_cb_index); cb_push_back(values_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(input_transposed_cb_index, Wt); cb_pop_front(input_transposed_cb_index, Wt); @@ -333,12 +333,12 @@ void top_k() { pack_reconfig_data_format(index_transposed_cb_index); cb_wait_front(index_transposed_cb_index, Kt); for (uint32_t i = 0; i < Kt; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_reserve_back(output_ind_cb_index, 1); transpose_wh_tile(index_transposed_cb_index, i, 0); pack_tile(0, output_ind_cb_index); cb_push_back(output_ind_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(index_transposed_cb_index, Wt); cb_pop_front(index_transposed_cb_index, Wt); diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp index 0b47c50202a..b6e0652cfaa 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk.cpp @@ -46,7 +46,7 @@ void MAIN { // streaming in input and index tiles to transpose and bitonic local sort them, two tiles at a time for (uint32_t wt = 0; wt < Wt; wt+=2) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); // local sort into k groups cb_wait_front(input_cb_index, 2); cb_wait_front(index_cb_index, 2); @@ -76,7 +76,7 @@ void MAIN { cb_pop_front(input_cb_index, 2); cb_pop_front(index_cb_index, 2); - release_dst(tt::DstMode::Half); + release_dst(); } cb_push_back(input_transposed_cb_index, Wt); @@ -94,7 +94,7 @@ void MAIN { for (uint32_t left_ind = 0; left_ind < Wt - (1 << m_iter); left_ind += 2 << m_iter) { uint32_t right_ind = left_ind + (1 << m_iter); - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile_to_dst_init_short_with_dt(index_transposed_cb_index, input_transposed_cb_index); copy_tile(input_transposed_cb_index, left_ind, input_dest_start); @@ -118,7 +118,7 @@ void MAIN { // pack index tiles in-place in the single-buffered cb_intermed1, we only need the upper 32 values for topk, which was in index_dest_start pack_reconfig_data_format(index_transposed_cb_index); pack_tile(index_dest_start, index_transposed_cb_index, left_ind); - release_dst(tt::DstMode::Half); + release_dst(); a = !a; } cb_reserve_back(input_transposed_cb_index, Wt); @@ -139,12 +139,12 @@ void MAIN { pack_reconfig_data_format(input_transposed_cb_index); cb_wait_front(input_transposed_cb_index, Kt); for (uint32_t i = 0; i < Kt; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_reserve_back(values_cb_index, 1); transpose_wh_tile(input_transposed_cb_index, i, 0); pack_tile(0, values_cb_index); cb_push_back(values_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(input_transposed_cb_index, Wt); cb_pop_front(input_transposed_cb_index, Wt); @@ -155,12 +155,12 @@ void MAIN { pack_reconfig_data_format(index_transposed_cb_index); cb_wait_front(index_transposed_cb_index, Kt); for (uint32_t i = 0; i < Kt; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_reserve_back(output_ind_cb_index, 1); transpose_wh_tile(index_transposed_cb_index, i, 0); pack_tile(0, output_ind_cb_index); cb_push_back(output_ind_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(index_transposed_cb_index, Wt); cb_pop_front(index_transposed_cb_index, Wt); diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_final.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_final.cpp index d3b42dc722c..7a6e9eff113 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_final.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_final.cpp @@ -52,14 +52,14 @@ void MAIN { pack_reconfig_data_format(input_transposed_cb_index); // streaming in input and index tiles to transpose and bitonic local sort them, two tiles at a time for (uint32_t wt = 0; wt < Wt; wt++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); // copy in inputs from input_cb_index - TODO: figure out how to optimize this out cb_reserve_back(input_transposed_cb_index, 1); copy_tile(input_cb_index, wt, 0); // pack value tiles into cb_intermed2 pack_tile(0, input_transposed_cb_index); cb_push_back(input_transposed_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(input_transposed_cb_index, Wt); cb_pop_front(input_cb_index, Wt); @@ -67,14 +67,14 @@ void MAIN { copy_tile_to_dst_init_short_with_dt(input_cb_index, index_cb_index); pack_reconfig_data_format(index_transposed_cb_index); for (uint32_t wt = 0; wt < Wt; wt++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); // copy in inputs from index_cb_index cb_reserve_back(index_transposed_cb_index, 1); copy_tile(index_cb_index, wt, 0); // pack value tiles into cb_intermed3 pack_tile(0, index_transposed_cb_index); cb_push_back(index_transposed_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(index_transposed_cb_index, Wt); cb_pop_front(index_cb_index, Wt); @@ -92,7 +92,7 @@ void MAIN { uint32_t stride = 1 << m_iter; for (uint32_t left_ind = 0; left_ind < Wt - stride; left_ind += 2 << m_iter) { uint32_t right_ind = left_ind + stride; - acquire_dst(tt::DstMode::Half); + acquire_dst(); // unpack values into dest copy_tile_to_dst_init_short_with_dt(index_transposed_cb_index, input_transposed_cb_index); @@ -116,7 +116,7 @@ void MAIN { // pack index tiles in-place in the single-buffered cb_intermed1, we only need the upper 32 values for topk, which was in index_dest_start pack_reconfig_data_format(index_transposed_cb_index); pack_tile(index_dest_start, index_transposed_cb_index, left_ind); - release_dst(tt::DstMode::Half); + release_dst(); direction = !direction; } cb_reserve_back(input_transposed_cb_index, Wt); @@ -135,12 +135,12 @@ void MAIN { pack_reconfig_data_format(input_transposed_cb_index); cb_wait_front(input_transposed_cb_index, Kt); for (uint32_t i = 0; i < Kt; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_reserve_back(values_cb_index, 1); transpose_wh_tile(input_transposed_cb_index, i, 0); pack_tile(0, values_cb_index); cb_push_back(values_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(input_transposed_cb_index, Wt); cb_pop_front(input_transposed_cb_index, Wt); @@ -151,12 +151,12 @@ void MAIN { pack_reconfig_data_format(index_transposed_cb_index); cb_wait_front(index_transposed_cb_index, Kt); for (uint32_t i = 0; i < Kt; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_reserve_back(output_ind_cb_index, 1); transpose_wh_tile(index_transposed_cb_index, i, 0); pack_tile(0, output_ind_cb_index); cb_push_back(output_ind_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(index_transposed_cb_index, Wt); cb_pop_front(index_transposed_cb_index, Wt); diff --git a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp index a3a71d5ef1b..1276c35ac0f 100644 --- a/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/topk/device/kernels/compute/topk_local.cpp @@ -51,7 +51,7 @@ void MAIN { // streaming in input and index tiles to transpose and bitonic local sort them, two tiles at a time for (uint32_t wt = 0; wt < Wt; wt+=2) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); // transpose tiles and then local sort into k groups cb_wait_front(input_cb_index, 2); cb_wait_front(index_cb_index, 2); @@ -81,7 +81,7 @@ void MAIN { cb_pop_front(input_cb_index, 2); cb_pop_front(index_cb_index, 2); - release_dst(tt::DstMode::Half); + release_dst(); } cb_push_back(input_transposed_cb_index, Wt); @@ -99,7 +99,7 @@ void MAIN { uint32_t stride = 1 << m_iter; for (uint32_t left_ind = 0; left_ind < Wt - stride; left_ind += 2 << m_iter) { uint32_t right_ind = left_ind + stride; - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile_to_dst_init_short_with_dt(index_transposed_cb_index, input_transposed_cb_index); copy_tile(input_transposed_cb_index, left_ind, input_dest_start); @@ -123,7 +123,7 @@ void MAIN { // pack index tiles in-place in the single-buffered cb_intermed1, we only need the upper 32 values for topk, which was in index_dest_start pack_reconfig_data_format(index_transposed_cb_index); pack_tile(index_dest_start, index_transposed_cb_index, left_ind); - release_dst(tt::DstMode::Half); + release_dst(); direction = !direction; } cb_reserve_back(input_transposed_cb_index, Wt); @@ -142,12 +142,12 @@ void MAIN { pack_reconfig_data_format(input_transposed_cb_index); cb_wait_front(input_transposed_cb_index, Kt); for (uint32_t i = 0; i < Kt; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_reserve_back(values_cb_index, 1); copy_tile(input_transposed_cb_index, i, 0); pack_tile(0, values_cb_index); cb_push_back(values_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(input_transposed_cb_index, Wt); cb_pop_front(input_transposed_cb_index, Wt); @@ -158,12 +158,12 @@ void MAIN { pack_reconfig_data_format(index_transposed_cb_index); cb_wait_front(index_transposed_cb_index, Kt); for (uint32_t i = 0; i < Kt; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); cb_reserve_back(output_ind_cb_index, 1); copy_tile(index_transposed_cb_index, i, 0); pack_tile(0, output_ind_cb_index); cb_push_back(output_ind_cb_index, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_wait_front(index_transposed_cb_index, Wt); cb_pop_front(index_transposed_cb_index, Wt); diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp index 0617c7db17b..0338c8a10be 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa.cpp @@ -29,7 +29,7 @@ void max_block_inplace() { cb_wait_front(in0, num_tiles); cb_wait_front(in1, num_tiles); for (uint32_t i = 0; i < num_tiles; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in0, 0, dst_reg_0); copy_tile(in1, i, dst_reg_1); cb_pop_front(in0, 1); @@ -37,7 +37,7 @@ void max_block_inplace() { max_tile(dst_reg_0, dst_reg_1); pack_tile(dst_reg_0, in0); cb_push_back(in0, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -60,7 +60,7 @@ void reduce_c() { constexpr uint32_t reduce_dst_idx = 0; for (uint32_t i = 0; i < rows; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t j = 0; j < cols; j++) { reduce_tile(in0_cb, scale_cb, i*cols+j, 0, reduce_dst_idx); } @@ -68,7 +68,7 @@ void reduce_c() { cb_reserve_back(out_cb, 1); pack_tile(reduce_dst_idx, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } reduce_revert_delta(out_cb); @@ -82,14 +82,14 @@ void recip_block_inplace(uint32_t in_cb, uint32_t num_tiles) { cb_wait_front(in_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in_cb, 0, 0); cb_pop_front(in_cb, 1); recip_tile(0); cb_reserve_back(in_cb, 1); pack_tile(0, in_cb); cb_push_back(in_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -140,13 +140,13 @@ void mul_block_bcast_cols_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t row cb_wait_front(in1_cb, rows); for (uint32_t i = 0; i < rows; ++i) { for (uint32_t j = 0; j < cols; ++j) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); mul_tiles_bcast_cols(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } cb_pop_front(in1_cb, rows); @@ -166,7 +166,7 @@ void mul_block_bcast_scalar_inplace() { cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_scalar_cb, 1); for (uint32_t g = 0; g < granularity; ++g) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t i = 0; i < dst_tiles; ++i) { mul_tiles_bcast_scalar(in0_cb, in1_scalar_cb, i, 0, i); } @@ -176,7 +176,7 @@ void mul_block_bcast_scalar_inplace() { pack_tile(i, in0_cb); } cb_push_back(in0_cb, dst_tiles); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -189,13 +189,13 @@ void add_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); add_tiles(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_pop_front(in1_cb, num_tiles); @@ -210,13 +210,13 @@ void mul_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); mul_tiles(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -233,7 +233,7 @@ void sub_exp_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t n for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); sub_tiles(in0_cb, in1_cb, i, i, 0); @@ -242,7 +242,7 @@ void sub_exp_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t n pack_tile(0, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -259,11 +259,11 @@ void copy_block(uint32_t in_cb, uint32_t out_cb, uint32_t num_tiles) { #pragma GCC unroll 0 for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in_cb, i, 0/*dst*/); pack_tile(0, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_pop_front(in_cb, num_tiles); } diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa_noncausal.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa_noncausal.cpp index 15a9df120de..9d612ebb137 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa_noncausal.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa/device/kernels/compute/sdpa_noncausal.cpp @@ -29,7 +29,7 @@ void max_block_inplace() { cb_wait_front(in0, num_tiles); cb_wait_front(in1, num_tiles); for (uint32_t i = 0; i < num_tiles; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in0, 0, dst_reg_0); copy_tile(in1, i, dst_reg_1); cb_pop_front(in0, 1); @@ -37,7 +37,7 @@ void max_block_inplace() { max_tile(dst_reg_0, dst_reg_1); pack_tile(dst_reg_0, in0); cb_push_back(in0, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -60,7 +60,7 @@ void reduce_c() { constexpr uint32_t reduce_dst_idx = 0; for (uint32_t i = 0; i < rows; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t j = 0; j < cols; j++) { reduce_tile(in0_cb, scale_cb, i*cols+j, 0, reduce_dst_idx); } @@ -68,7 +68,7 @@ void reduce_c() { cb_reserve_back(out_cb, 1); pack_tile(reduce_dst_idx, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } reduce_revert_delta(out_cb); @@ -83,14 +83,14 @@ void recip_block_inplace(uint32_t in_cb, uint32_t num_tiles) { cb_wait_front(in_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in_cb, 0, 0); cb_pop_front(in_cb, 1); recip_tile(0); cb_reserve_back(in_cb, 1); pack_tile(0, in_cb); cb_push_back(in_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -141,13 +141,13 @@ void mul_block_bcast_cols_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t row cb_wait_front(in1_cb, rows); for (uint32_t i = 0; i < rows; ++i) { for (uint32_t j = 0; j < cols; ++j) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); mul_tiles_bcast_cols(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } cb_pop_front(in1_cb, rows); @@ -167,7 +167,7 @@ void mul_block_bcast_scalar_inplace() { cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_scalar_cb, 1); for (uint32_t g = 0; g < granularity; ++g) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t i = 0; i < dst_tiles; ++i) { mul_tiles_bcast_scalar(in0_cb, in1_scalar_cb, i, 0, i); } @@ -177,7 +177,7 @@ void mul_block_bcast_scalar_inplace() { pack_tile(i, in0_cb); } cb_push_back(in0_cb, dst_tiles); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -190,13 +190,13 @@ void add_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); add_tiles(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_pop_front(in1_cb, num_tiles); @@ -211,13 +211,13 @@ void mul_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); mul_tiles(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -234,7 +234,7 @@ void sub_exp_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t n for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); sub_tiles(in0_cb, in1_cb, i, i, 0); @@ -243,7 +243,7 @@ void sub_exp_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t n pack_tile(0, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -259,11 +259,11 @@ void copy_block(uint32_t in_cb, uint32_t out_cb, uint32_t num_tiles) { cb_reserve_back(out_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in_cb, i, 0/*dst*/); pack_tile(0, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_pop_front(in_cb, num_tiles); } diff --git a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp index 04f9a751503..c96a9dadfd0 100644 --- a/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp +++ b/ttnn/cpp/ttnn/operations/transformer/sdpa_decode/device/kernels/compute/sdpa_flash_decode.cpp @@ -29,7 +29,7 @@ void max_block_inplace(uint32_t in0, uint32_t in1, uint32_t num_tiles) { cb_wait_front(in0, num_tiles); cb_wait_front(in1, num_tiles); for (uint32_t i = 0; i < num_tiles; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in0, 0, dst_reg_0); copy_tile(in1, i, dst_reg_1); cb_pop_front(in0, 1); @@ -37,7 +37,7 @@ void max_block_inplace(uint32_t in0, uint32_t in1, uint32_t num_tiles) { max_tile(dst_reg_0, dst_reg_1); pack_tile(dst_reg_0, in0); cb_push_back(in0, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -52,12 +52,12 @@ void max_block(uint32_t in0, uint32_t in1, uint32_t out_cb, uint32_t num_tiles) cb_wait_front(in1, num_tiles); cb_reserve_back(out_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in0, i, dst_reg_0); copy_tile(in1, i, dst_reg_1); max_tile(dst_reg_0, dst_reg_1); pack_tile(dst_reg_0, out_cb, i); - release_dst(tt::DstMode::Half); + release_dst(); } cb_push_back(out_cb, num_tiles); } @@ -81,7 +81,7 @@ void reduce_c() { constexpr uint32_t reduce_dst_idx = 0; for (uint32_t i = 0; i < rows; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t j = 0; j < cols; j++) { reduce_tile(in0_cb, scale_cb, i*cols+j, 0, reduce_dst_idx); } @@ -89,7 +89,7 @@ void reduce_c() { cb_reserve_back(out_cb, 1); pack_tile(reduce_dst_idx, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } reduce_revert_delta(out_cb); @@ -103,14 +103,14 @@ void recip_block_inplace(uint32_t in_cb, uint32_t num_tiles) { cb_wait_front(in_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; ++i) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in_cb, 0, 0); cb_pop_front(in_cb, 1); recip_tile(0); cb_reserve_back(in_cb, 1); pack_tile(0, in_cb); cb_push_back(in_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -160,13 +160,13 @@ void mul_block_bcast_cols_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t row cb_wait_front(in1_cb, rows); for (uint32_t i = 0; i < rows; ++i) { for (uint32_t j = 0; j < cols; ++j) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); mul_tiles_bcast_cols(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } cb_pop_front(in1_cb, rows); @@ -185,7 +185,7 @@ void mul_block_bcast_scalar_inplace(uint32_t in0_cb, uint32_t in1_scalar_cb, uin cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_scalar_cb, 1); for (uint32_t g = 0; g < granularity; ++g) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); for (uint32_t i = 0; i < dst_tiles; ++i) { mul_tiles_bcast_scalar(in0_cb, in1_scalar_cb, i, 0, i); } @@ -195,7 +195,7 @@ void mul_block_bcast_scalar_inplace(uint32_t in0_cb, uint32_t in1_scalar_cb, uin pack_tile(i, in0_cb); } cb_push_back(in0_cb, dst_tiles); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -209,13 +209,13 @@ void add_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); add_tiles(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } if (pop_in1) cb_pop_front(in1_cb, num_tiles); } @@ -230,10 +230,10 @@ void add_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t num_t cb_wait_front(in1_cb, num_tiles); cb_reserve_back(out_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); add_tiles(in0_cb, in1_cb, i, i, 0); pack_tile(0, out_cb, i); - release_dst(tt::DstMode::Half); + release_dst(); } cb_push_back(out_cb, num_tiles); @@ -250,13 +250,13 @@ void mul_block_inplace(uint32_t in0_cb, uint32_t in1_cb, uint32_t num_tiles) { cb_wait_front(in0_cb, num_tiles); cb_wait_front(in1_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); mul_tiles(in0_cb, in1_cb, 0, i, 0); cb_pop_front(in0_cb, 1); cb_reserve_back(in0_cb, 1); pack_tile(0, in0_cb); cb_push_back(in0_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -271,12 +271,12 @@ void sub_exp_block(uint32_t in0_cb, uint32_t in1_cb, uint32_t out_cb, uint32_t n cb_reserve_back(out_cb, num_tiles); for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); sub_tiles(in0_cb, in1_cb, i, i, 0); exp_tile(0); pack_tile(0, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } } @@ -293,11 +293,11 @@ void copy_block(uint32_t in_cb, uint32_t out_cb, uint32_t num_tiles) { #pragma GCC unroll 0 for (uint32_t i = 0; i < num_tiles; i++) { - acquire_dst(tt::DstMode::Half); + acquire_dst(); copy_tile(in_cb, i, 0/*dst*/); pack_tile(0, out_cb); cb_push_back(out_cb, 1); - release_dst(tt::DstMode::Half); + release_dst(); } cb_pop_front(in_cb, num_tiles); } From 7c5b6a976ad5768500ca3fb13fd748f4a5209695 Mon Sep 17 00:00:00 2001 From: Milos Trajkovic Date: Fri, 4 Oct 2024 17:24:06 -0400 Subject: [PATCH 41/58] #11239: Add APIs to read and clear FPU/SFPU special number registers --- .../metal/llk_api/llk_math_common_api.h | 25 ++++++++++++++++++ .../metal/llk_api/llk_math_common_api.h | 20 ++++++++++++++ .../metal/llk_api/llk_math_common_api.h | 26 +++++++++++++++++++ tt_metal/hw/inc/wormhole/tensix.h | 1 + tt_metal/include/compute_kernel_api.h | 26 +++++++++++++++++++ tt_metal/third_party/tt_llk_blackhole | 2 +- tt_metal/third_party/tt_llk_wormhole_b0 | 2 +- 7 files changed, 100 insertions(+), 2 deletions(-) diff --git a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_common_api.h b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_common_api.h index 5a0c5c8b04b..e50268e1b4c 100644 --- a/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_common_api.h +++ b/tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_math_common_api.h @@ -112,3 +112,28 @@ inline void llk_math_reconfig_data_format_srcb( llk_math_reconfig_data_format_srcb(srcb_new_operand); } } + +inline std::uint32_t llk_math_get_compute_special_value_flags() { + return _llk_math_get_compute_special_value_flags_(); +} + +inline std::uint32_t llk_math_get_compute_special_value_flags_fpu(std::uint32_t special_value_flags_reg) { + constexpr std::uint32_t special_value_flags_fpu_mask = 0xf; + constexpr std::uint32_t special_value_flags_fpu_shift = 4; + return (special_value_flags_reg & special_value_flags_fpu_mask) >> special_value_flags_fpu_shift; +} + +inline std::uint32_t llk_math_get_compute_special_value_flags_sfpu(std::uint32_t special_value_flags_reg) { + constexpr std::uint32_t special_value_flags_sfpu_mask = 0xf; + constexpr std::uint32_t special_value_flags_sfpu_shift = 0; + return (special_value_flags_reg & special_value_flags_sfpu_mask) >> special_value_flags_sfpu_shift; +} + +inline void llk_math_clear_compute_special_value_flags() { + _llk_math_clear_compute_special_value_flags_(); +} + +inline void llk_math_store_compute_special_value_flags_to_l1(std::uint32_t l1_addr) { + volatile tt_l1_ptr std::uint32_t* l1_addr_ptr = reinterpret_cast(l1_addr); + l1_addr_ptr[0] = _llk_math_get_compute_special_value_flags_(); +} diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h index ee1f4715f15..2a0736300bd 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h @@ -81,3 +81,23 @@ inline void llk_math_reconfig_data_format_srcb( const std::uint32_t srcb_old_operand, const std::uint32_t srcb_new_operand) { _llk_math_reconfig_data_format_srcb_(); } + +inline std::uint32_t llk_math_get_compute_special_value_flags() { + static_assert(false && "API not supported in Grayskull"); +} + +inline std::uint32_t llk_math_get_compute_special_value_flags_fpu(std::uint32_t special_value_flags_reg) { + static_assert(false && "API not supported in Grayskull"); +} + +inline std::uint32_t llk_math_get_compute_special_value_flags_sfpu(std::uint32_t special_value_flags_reg) { + static_assert(false && "API not supported in Grayskull"); +} + +inline std::uint32_t llk_math_clear_compute_special_value_flags() { + static_assert(false && "API not supported in Grayskull"); +} + +inline void llk_math_store_compute_special_value_flags_to_l1(std::uint32_t l1_addr) { + static_assert(false && "API not supported in Grayskull"); +} diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_common_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_common_api.h index 90d724edbf4..f7ae56652be 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_common_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_common_api.h @@ -110,3 +110,29 @@ inline void llk_math_reconfig_data_format_srcb( llk_math_reconfig_data_format_srcb(srcb_new_operand); } } + + +inline std::uint32_t llk_math_get_compute_special_value_flags() { + return _llk_math_get_compute_special_value_flags_(); +} + +inline std::uint32_t llk_math_get_compute_special_value_flags_fpu(std::uint32_t special_value_flags_reg) { + constexpr std::uint32_t special_value_flags_fpu_mask = 0x7; + constexpr std::uint32_t special_value_flags_fpu_shift = 4; + return (special_value_flags_reg & special_value_flags_fpu_mask) >> special_value_flags_fpu_shift; +} + +inline std::uint32_t llk_math_get_compute_special_value_flags_sfpu(std::uint32_t special_value_flags_reg) { + constexpr std::uint32_t special_value_flags_sfpu_mask = 0xf; + constexpr std::uint32_t special_value_flags_sfpu_shift = 0; + return (special_value_flags_reg & special_value_flags_sfpu_mask) >> special_value_flags_sfpu_shift; +} + +inline void llk_math_clear_compute_special_value_flags() { + _llk_math_clear_compute_special_value_flags_(); +} + +inline void llk_math_store_compute_special_value_flags_to_l1(std::uint32_t l1_addr) { + volatile tt_l1_ptr std::uint32_t* l1_addr_ptr = reinterpret_cast(l1_addr); + l1_addr_ptr[0] = _llk_math_get_compute_special_value_flags_(); +} diff --git a/tt_metal/hw/inc/wormhole/tensix.h b/tt_metal/hw/inc/wormhole/tensix.h index 83d781e263c..71e3ca58a81 100644 --- a/tt_metal/hw/inc/wormhole/tensix.h +++ b/tt_metal/hw/inc/wormhole/tensix.h @@ -139,6 +139,7 @@ typedef std::uint8_t byte; #define RISCV_DEBUG_REG_INSTRN_BUF_CTRL0 (RISCV_DEBUG_REGS_START_ADDR | 0x0A0) #define RISCV_DEBUG_REG_INSTRN_BUF_CTRL1 (RISCV_DEBUG_REGS_START_ADDR | 0x0A4) #define RISCV_DEBUG_REG_INSTRN_BUF_STATUS (RISCV_DEBUG_REGS_START_ADDR | 0x0A8) +#define RISCV_DEBUG_REG_FPU_STICKY_BITS (RISCV_DEBUG_REGS_START_ADDR | 0x0B4) #define RISCV_DEBUG_REG_PERF_CNT_TDMA_PACK0 (RISCV_DEBUG_REGS_START_ADDR | 0x0F0) #define RISCV_DEBUG_REG_PERF_CNT_TDMA_PACK1 (RISCV_DEBUG_REGS_START_ADDR | 0x0F4) #define RISCV_DEBUG_REG_PERF_CNT_TDMA_PACK2 (RISCV_DEBUG_REGS_START_ADDR | 0x0F8) diff --git a/tt_metal/include/compute_kernel_api.h b/tt_metal/include/compute_kernel_api.h index 606de2fa712..afb4a2bb1e1 100644 --- a/tt_metal/include/compute_kernel_api.h +++ b/tt_metal/include/compute_kernel_api.h @@ -888,5 +888,31 @@ ALWI void unary_lt_tile_init() { MATH(( llk_math_eltwise_unary_sfpu_unary_lt_init() )); } +ALWI uint32_t get_compute_special_value_flags() { + uint32_t ret_val = 0; + MATH(( ret_val = llk_math_get_compute_special_value_flags() )); + return ret_val; +} + +ALWI uint32_t get_compute_special_value_flags_fpu(uint32_t special_value_flags_reg) { + uint32_t ret_val = 0; + MATH (( ret_val = llk_math_get_compute_special_value_flags_fpu(special_value_flags_reg) )); + return ret_val; +} + +ALWI uint32_t get_compute_special_value_flags_sfpu(uint32_t special_value_flags_reg) { + uint32_t ret_val = 0; + MATH (( ret_val = llk_math_get_compute_special_value_flags_sfpu(special_value_flags_reg) )); + return ret_val; +} + +ALWI void clear_compute_special_value_flags() { + MATH (( llk_math_clear_compute_special_value_flags() )); +} + +ALWI void store_compute_special_value_flags_to_l1(uint32_t l1_addr) { + MATH (( llk_math_store_compute_special_value_flags_to_l1(l1_addr) )); +} + } // namespace ckernel diff --git a/tt_metal/third_party/tt_llk_blackhole b/tt_metal/third_party/tt_llk_blackhole index 05709f423aa..9a68fd09d8e 160000 --- a/tt_metal/third_party/tt_llk_blackhole +++ b/tt_metal/third_party/tt_llk_blackhole @@ -1 +1 @@ -Subproject commit 05709f423aa713fd299f52f4779d09e791a3228e +Subproject commit 9a68fd09d8ee2d81c445c576861cf146c9b54810 diff --git a/tt_metal/third_party/tt_llk_wormhole_b0 b/tt_metal/third_party/tt_llk_wormhole_b0 index 47bc7d232ed..166515054c0 160000 --- a/tt_metal/third_party/tt_llk_wormhole_b0 +++ b/tt_metal/third_party/tt_llk_wormhole_b0 @@ -1 +1 @@ -Subproject commit 47bc7d232edd7d7974938ec539a5661e689f5b53 +Subproject commit 166515054c09553317f569d3689198c3891cefe0 From 39537a089a947618e8442e0d3c1d0e4eca92adaf Mon Sep 17 00:00:00 2001 From: Milos Trajkovic Date: Fri, 4 Oct 2024 22:40:13 +0000 Subject: [PATCH 42/58] #11239: fix gs functions --- .../ckernels/grayskull/metal/llk_api/llk_math_common_api.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h index 2a0736300bd..423328c6bc8 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h @@ -84,17 +84,20 @@ inline void llk_math_reconfig_data_format_srcb( inline std::uint32_t llk_math_get_compute_special_value_flags() { static_assert(false && "API not supported in Grayskull"); + return 0; } inline std::uint32_t llk_math_get_compute_special_value_flags_fpu(std::uint32_t special_value_flags_reg) { static_assert(false && "API not supported in Grayskull"); + return 0; } inline std::uint32_t llk_math_get_compute_special_value_flags_sfpu(std::uint32_t special_value_flags_reg) { static_assert(false && "API not supported in Grayskull"); + return 0; } -inline std::uint32_t llk_math_clear_compute_special_value_flags() { +inline void llk_math_clear_compute_special_value_flags() { static_assert(false && "API not supported in Grayskull"); } From 27ba1f1cb58de1161d6e47e4692ed78973d08229 Mon Sep 17 00:00:00 2001 From: Milos Date: Mon, 7 Oct 2024 17:05:56 +0000 Subject: [PATCH 43/58] #11239: remove static asserts and only add comments that functions are unused --- .../metal/llk_api/llk_math_common_api.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h index 423328c6bc8..01c319cab44 100644 --- a/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h +++ b/tt_metal/hw/ckernels/grayskull/metal/llk_api/llk_math_common_api.h @@ -83,24 +83,24 @@ inline void llk_math_reconfig_data_format_srcb( } inline std::uint32_t llk_math_get_compute_special_value_flags() { - static_assert(false && "API not supported in Grayskull"); - return 0; + // API not supported in Grayskull + return 0xFFFFFFFF; } inline std::uint32_t llk_math_get_compute_special_value_flags_fpu(std::uint32_t special_value_flags_reg) { - static_assert(false && "API not supported in Grayskull"); - return 0; + // API not supported in Grayskull + return 0xFFFFFFFF; } inline std::uint32_t llk_math_get_compute_special_value_flags_sfpu(std::uint32_t special_value_flags_reg) { - static_assert(false && "API not supported in Grayskull"); - return 0; + // API not supported in Grayskull + return 0xFFFFFFFF; } inline void llk_math_clear_compute_special_value_flags() { - static_assert(false && "API not supported in Grayskull"); + // API not supported in Grayskull } inline void llk_math_store_compute_special_value_flags_to_l1(std::uint32_t l1_addr) { - static_assert(false && "API not supported in Grayskull"); + // API not supported in Grayskull } From ac3fa9aed7f5cd28f3ddb72bd15613f049f47b51 Mon Sep 17 00:00:00 2001 From: thanhnguyen-moreh Date: Wed, 2 Oct 2024 06:35:45 +0000 Subject: [PATCH 44/58] #13273: Delete deprecated version --- docs/source/ttnn/ttnn/dependencies/tt_lib.rst | 12 +- .../misc/test_moreh_logsoftmax.py | 414 --------------- .../unit_testing/misc/test_moreh_softmax.py | 479 ------------------ .../unit_testing/misc/test_moreh_softmin.py | 407 --------------- .../operations/test_moreh_logsoftmax.py | 184 +++++-- .../operations/test_moreh_softmax.py | 214 ++++---- .../operations/test_moreh_softmin.py | 233 ++++++--- ttnn/cpp/pybind11/__init__.cpp | 1 - .../tt_dnn/op_library/CMakeLists.txt | 12 - .../kernels/moreh_softmax_c_large.cpp | 129 ----- .../moreh_softmax/kernels/moreh_softmax_h.cpp | 168 ------ .../kernels/moreh_softmax_h_large.cpp | 173 ------- .../moreh_softmax/kernels/moreh_softmax_w.cpp | 173 ------- .../kernels/moreh_softmax_w_large.cpp | 177 ------- .../kernels/reader_moreh_softmax_c_large.cpp | 66 --- .../kernels/reader_moreh_softmax_h.cpp | 53 -- .../kernels/reader_moreh_softmax_h_large.cpp | 77 --- .../kernels/reader_moreh_softmax_w.cpp | 48 -- .../kernels/reader_moreh_softmax_w_large.cpp | 68 --- .../kernels/writer_moreh_softmax_c_large.cpp | 44 -- .../kernels/writer_moreh_softmax_h.cpp | 42 -- .../kernels/writer_moreh_softmax_h_large.cpp | 41 -- .../kernels/writer_moreh_softmax_w.cpp | 35 -- .../kernels/writer_moreh_softmax_w_large.cpp | 36 -- .../moreh_softmax/moreh_softmax_op.cpp | 263 ---------- .../moreh_softmax/moreh_softmax_op.hpp | 95 ---- .../softmax_c_large/softmax_c_large.cpp | 143 ------ .../softmax_h_large/softmax_h_large.cpp | 137 ----- .../softmax_h_small/softmax_h_small.cpp | 170 ------- .../softmax_w_large/softmax_w_large.cpp | 138 ----- .../softmax_w_small/softmax_w_small.cpp | 170 ------- .../moreh_softmax_backward_c_large.cpp | 90 ---- .../kernels/moreh_softmax_backward_h.cpp | 104 ---- .../moreh_softmax_backward_h_large.cpp | 110 ---- .../kernels/moreh_softmax_backward_w.cpp | 104 ---- .../moreh_softmax_backward_w_large.cpp | 108 ---- .../reader_moreh_softmax_backward_c.cpp | 82 --- .../reader_moreh_softmax_backward_h.cpp | 72 --- .../reader_moreh_softmax_backward_h_large.cpp | 95 ---- .../reader_moreh_softmax_backward_w.cpp | 67 --- .../reader_moreh_softmax_backward_w_large.cpp | 89 ---- .../writer_moreh_softmax_backward_c.cpp | 44 -- .../writer_moreh_softmax_backward_h.cpp | 41 -- .../writer_moreh_softmax_backward_w.cpp | 36 -- .../kernels/writer_moreh_softmax_h.cpp | 41 -- .../kernels/writer_moreh_softmax_w.cpp | 36 -- .../moreh_softmax_backward_op.cpp | 283 ----------- .../moreh_softmax_backward_op.hpp | 125 ----- .../softmax_backward_c_large.cpp | 146 ------ .../softmax_backward_h_large.cpp | 141 ------ .../softmax_backward_h_small.cpp | 163 ------ .../softmax_backward_w_large.cpp | 141 ------ .../softmax_backward_w_small.cpp | 164 ------ .../tt_lib/csrc/operations/primary/module.hpp | 84 --- .../device/moreh_softmax_device_operation.cpp | 49 +- .../device/moreh_softmax_device_operation.hpp | 4 +- .../softmax_c_large/softmax_c_large.cpp | 7 +- .../softmax_h_large/softmax_h_large.cpp | 5 +- .../softmax_h_small/softmax_h_small.cpp | 7 +- .../softmax_w_large/softmax_w_large.cpp | 7 +- .../softmax_w_small/softmax_w_small.cpp | 7 +- ...oreh_softmax_backward_device_operation.cpp | 10 +- .../softmax_backward_c_large.cpp | 2 +- .../softmax_backward_h_large.cpp | 2 +- .../softmax_backward_h_small.cpp | 2 +- .../softmax_backward_w_large.cpp | 2 +- .../softmax_backward_w_small.cpp | 2 +- .../normalization/softmax/softmax.cpp | 6 +- 68 files changed, 466 insertions(+), 6414 deletions(-) delete mode 100644 tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_logsoftmax.py delete mode 100644 tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_softmax.py delete mode 100644 tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_softmin.py delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_c_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_c_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_c_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_c_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_c.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_c.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_h.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_w.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_h.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_w.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp delete mode 100644 ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst index c6f12613744..73ae0517fe9 100644 --- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst +++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst @@ -209,17 +209,17 @@ TT-LIB API through ``tt_lib`` Primary Operations ================== -.. autofunction:: tt_lib.operations.primary.moreh_softmax +.. autofunction:: ttnn.operations.moreh.softmax -.. autofunction:: tt_lib.operations.primary.moreh_softmax_backward +.. autofunction:: ttnn.operations.moreh.softmax_backward -.. autofunction:: tt_lib.operations.primary.moreh_softmin +.. autofunction:: ttnn.operations.moreh.softmin -.. autofunction:: tt_lib.operations.primary.moreh_softmin_backward +.. autofunction:: ttnn.operations.moreh.softmin_backward -.. autofunction:: tt_lib.operations.primary.moreh_logsoftmax +.. autofunction:: ttnn.operations.moreh.logsoftmax -.. autofunction:: tt_lib.operations.primary.moreh_logsoftmax_backward +.. autofunction:: ttnn.operations.moreh.logsoftmax_backward .. autofunction:: ttnn.operations.moreh.mean diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_logsoftmax.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_logsoftmax.py deleted file mode 100644 index 61d73aa759d..00000000000 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_logsoftmax.py +++ /dev/null @@ -1,414 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import torch - -import ttnn -import pytest -from models.utility_functions import comp_allclose_and_pcc -from loguru import logger -import torch.nn.functional as F -from models.utility_functions import is_wormhole_b0 - -from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import ( - get_compute_kernel_options, - compute_kernel_options, - compute_kernel_ids, -) - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((50, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_logsoftmax_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + 100 - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).pad_to_tile(float("nan")).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = F.log_softmax(x, dim) - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax( - dev_x, dim, compute_kernel_config=compute_kernel_config - ) - - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - assert list(tt_dev.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_dev.to_torch().to(torch.bfloat16) - - rtol = atol = 0.1 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_logsoftmax_large_algorithm_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + 100 - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = F.log_softmax(x, dim) - strategy = ( - ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_W - if dim == 3 - else ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_H - ) - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax( - dev_x, dim, None, strategy, compute_kernel_config=compute_kernel_config - ) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.1 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_logsoftmax_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).pad_to_tile(float("nan")).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = F.log_softmax(x, dim) - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax( - dev_x, dim, compute_kernel_config=compute_kernel_config - ) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.to_torch().to(torch.bfloat16) - - rtol = atol = 0.1 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_logsoftmax_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + 100 - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).pad_to_tile(float("nan")).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = F.log_softmax(x, dim) - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax( - dev_x, dim, compute_kernel_config=compute_kernel_config - ) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.to_torch().to(torch.bfloat16) - - rtol = atol = 0.1 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 2), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 5, 32), 2), # multiple tiles per core - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_logsoftmax_backward_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.log_softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax_backward( - dev_y, dev_dy, dim, compute_kernel_config=compute_kernel_config - ) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.5 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_logsoftmax_backward_large_algorithm_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.log_softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - strategy = ( - ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_W - if dim == 3 - else ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_H - ) - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax_backward( - dev_y, dev_dy, dim, None, strategy, compute_kernel_config=compute_kernel_config - ) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.5 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_logsoftmax_backward_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.log_softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).pad_to_tile(float("10")).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).pad_to_tile(float("200")).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax_backward( - dev_y, dev_dy, dim, compute_kernel_config=compute_kernel_config - ) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.to_torch().to(torch.bfloat16) - - rtol = atol = 0.1 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_logsoftmax_backward_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.log_softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).pad_to_tile(float("10")).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).pad_to_tile(float("10")).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax_backward( - dev_y, dev_dy, dim, compute_kernel_config=compute_kernel_config - ) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to_torch().to(torch.bfloat16) - - rtol = atol = 0.5 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - (((32, 32), 1),), # single tile -) -@pytest.mark.parametrize( - "optional_output_tensor", - (True, False), -) -def test_logsoftmax_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() - - shape, dim = shape_dim - torch.manual_seed(0) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - # cpu calculation - tt_cpu = F.log_softmax(x, dim) - - # npu calculation - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - if optional_output_tensor: - dev_y = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax(dev_x, dim, dev_y) - else: - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax(dev_x, dim) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.info(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - (((32, 32), 1),), # single tile -) -@pytest.mark.parametrize( - "optional_output_tensor", - (True, False), -) -def test_logsoftmax_backward_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - # cpu calculation - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.log_softmax(x, dim) - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - y.backward(dy) - - # npu calculation - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - if optional_output_tensor: - dev_dx = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax_backward(dev_y, dev_dy, dim, dev_dx) - else: - tt_npu = ttnn.experimental.operations.primary.moreh_logsoftmax_backward(dev_y, dev_dy, dim) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.info(out) - assert passing diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_softmax.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_softmax.py deleted file mode 100644 index 498891c1432..00000000000 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_softmax.py +++ /dev/null @@ -1,479 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import torch - -import ttnn -import pytest -from models.utility_functions import comp_allclose_and_pcc -from loguru import logger -from models.utility_functions import is_wormhole_b0 - -from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import ( - get_compute_kernel_options, - compute_kernel_options, - compute_kernel_ids, -) - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmax_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + 100 - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = torch.softmax(x, dim) - tt_npu = ttnn.experimental.operations.primary.moreh_softmax(dev_x, dim, compute_kernel_config=compute_kernel_config) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmax_large_algorithm_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + 100 - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = torch.softmax(x, dim) - - strategy = ( - ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_W - if dim == 3 - else ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_H - ) - tt_npu = ttnn.experimental.operations.primary.moreh_softmax( - dev_x, dim, None, strategy, compute_kernel_config=compute_kernel_config - ) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmax_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).pad_to_tile(float("nan")).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = torch.softmax(x, dim) - tt_npu = ttnn.experimental.operations.primary.moreh_softmax(dev_x, dim, compute_kernel_config=compute_kernel_config) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmax_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + 100 - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).pad_to_tile(float("7")).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = torch.softmax(x, dim) - tt_npu = ttnn.experimental.operations.primary.moreh_softmax(dev_x, dim, compute_kernel_config=compute_kernel_config) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmax_backward_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = torch.softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - tt_npu = ttnn.experimental.operations.primary.moreh_softmax_backward( - dev_y, dev_dy, dim, compute_kernel_config=compute_kernel_config - ) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmax_backward_large_algorithmfor_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = torch.softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - - strategy = ( - ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_W - if dim == 3 - else ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_H - ) - tt_npu = ttnn.experimental.operations.primary.moreh_softmax_backward( - dev_y, dev_dy, dim, None, strategy, compute_kernel_config=compute_kernel_config - ) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmax_backward_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = torch.softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).pad_to_tile(float("10")).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).pad_to_tile(float("20")).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - tt_npu = ttnn.experimental.operations.primary.moreh_softmax_backward( - dev_y, dev_dy, dim, compute_kernel_config=compute_kernel_config - ) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((15, 32, 32), 0), # single tile c - ((15, 32 * 7, 32 * 5), 0), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmax_backward_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = torch.softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - tt_npu = ttnn.experimental.operations.primary.moreh_softmax_backward( - dev_y, dev_dy, dim, compute_kernel_config=compute_kernel_config - ) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT) - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim_strategy", - ( - ((32, 32), 1, ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.SMALL_W), - ((32, 32), 0, ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.SMALL_H), - ((32, 32), 1, ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_W), - ((32, 32), 0, ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_H), - ((1, 1, 32, 32), 1, ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_C), - ((1, 1, 32, 32), 0, ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_C), - ), -) -def test_softmax_callback(shape_dim_strategy, device): - device.enable_program_cache() - - shape, dim, strategy = shape_dim_strategy - torch.manual_seed(0) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = torch.softmax(x, dim) - for i in range(2): - tt_npu = ttnn.experimental.operations.primary.moreh_softmax(dev_x, dim, None, strategy) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim_strategy", - ( - ((32, 32), 1, ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.SMALL_W), - ((32, 32), 0, ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.SMALL_H), - ((32, 32), 1, ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_W), - ((32, 32), 0, ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_H), - ((1, 1, 32, 32), 1, ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_C), - ((1, 1, 32, 32), 0, ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_C), - ), -) -def test_softmax_backward_callback(shape_dim_strategy, device): - device.enable_program_cache() - shape, dim, strategy = shape_dim_strategy - torch.manual_seed(0) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = torch.softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - for i in range(2): - tt_npu = ttnn.experimental.operations.primary.moreh_softmax_backward(dev_y, dev_dy, dim, None, strategy) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - (((32, 32), 1),), # single tile -) -@pytest.mark.parametrize( - "optional_output_tensor", - (True, False), -) -def test_softmax_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() - - shape, dim = shape_dim - torch.manual_seed(0) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - # cpu calculation - tt_cpu = torch.softmax(x, dim) - - # npu calculation - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - if optional_output_tensor: - dev_y = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - tt_npu = ttnn.experimental.operations.primary.moreh_softmax(dev_x, dim, dev_y) - else: - tt_npu = ttnn.experimental.operations.primary.moreh_softmax(dev_x, dim) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.info(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - (((32, 32), 1),), # single tile -) -@pytest.mark.parametrize( - "optional_output_tensor", - (True, False), -) -def test_softmax_backward_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - # cpu calculation - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = torch.softmax(x, dim) - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - y.backward(dy) - - # npu calculation - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - if optional_output_tensor: - dev_dx = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - tt_npu = ttnn.experimental.operations.primary.moreh_softmax_backward(dev_y, dev_dy, dim, dev_dx) - else: - tt_npu = ttnn.experimental.operations.primary.moreh_softmax_backward(dev_y, dev_dy, dim) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.info(out) - assert passing diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_softmin.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_softmin.py deleted file mode 100644 index 19f77781990..00000000000 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_softmin.py +++ /dev/null @@ -1,407 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import torch - -import ttnn -import pytest -from models.utility_functions import comp_allclose_and_pcc -from loguru import logger -import torch.nn.functional as F -from models.utility_functions import is_wormhole_b0 - -from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import ( - get_compute_kernel_options, - compute_kernel_options, - compute_kernel_ids, -) - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmin_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = F.softmin(x, dim) - tt_npu = ttnn.experimental.operations.primary.moreh_softmin(dev_x, dim, compute_kernel_config=compute_kernel_config) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmin_large_algorithm_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = F.softmin(x, dim) - strategy = ( - ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_W - if dim == 3 - else ttnn.experimental.operations.primary.MorehSoftmaxOpParallelizationStrategy.LARGE_H - ) - tt_npu = ttnn.experimental.operations.primary.moreh_softmin( - dev_x, dim, None, strategy, compute_kernel_config=compute_kernel_config - ) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmin_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).pad_to_tile(float("nan")).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = F.softmin(x, dim) - tt_npu = ttnn.experimental.operations.primary.moreh_softmin(dev_x, dim, compute_kernel_config=compute_kernel_config) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmin_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - dev_x = ttnn.Tensor(x, ttnn.bfloat16).pad_to_tile(float("7")).to(ttnn.TILE_LAYOUT).to(device) - - tt_cpu = F.softmin(x, dim) - tt_npu = ttnn.experimental.operations.primary.moreh_softmin(dev_x, dim, compute_kernel_config=compute_kernel_config) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmin_backward_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.softmin(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - tt_npu = ttnn.experimental.operations.primary.moreh_softmin_backward( - dev_y, dev_dy, dim, compute_kernel_config=compute_kernel_config - ) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmin_backward_large_algorithmfor_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.softmin(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - strategy = ( - ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_W - if dim == 3 - else ttnn.experimental.operations.primary.MorehSoftmaxBackwardOpParallelizationStrategy.LARGE_H - ) - tt_npu = ttnn.experimental.operations.primary.moreh_softmin_backward( - dev_y, dev_dy, dim, None, strategy, compute_kernel_config=compute_kernel_config - ) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmin_backward_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.softmin(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).pad_to_tile(float("10")).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).pad_to_tile(float("20")).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - tt_npu = ttnn.experimental.operations.primary.moreh_softmin_backward( - dev_y, dev_dy, dim, compute_kernel_config=compute_kernel_config - ) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), -) -@pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) -def test_softmin_backward_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - compute_kernel_config = get_compute_kernel_options(compute_kernel_options) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.softmin(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).pad_to_tile(float("10")).to(ttnn.TILE_LAYOUT).to(device) - - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).pad_to_tile(float("10")).to(ttnn.TILE_LAYOUT).to(device) - - y.backward(dy) - tt_npu = ttnn.experimental.operations.primary.moreh_softmin_backward( - dev_y, dev_dy, dim, compute_kernel_config=compute_kernel_config - ) - tt_npu = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.debug(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - (((32, 32), 1),), # single tile -) -@pytest.mark.parametrize( - "optional_output_tensor", - (True, False), -) -def test_softmin_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() - - shape, dim = shape_dim - torch.manual_seed(0) - - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - - # cpu calculation - tt_cpu = F.softmin(x, dim) - - # npu calculation - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - if optional_output_tensor: - dev_y = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - tt_npu = ttnn.experimental.operations.primary.moreh_softmin(dev_x, dim, dev_y) - else: - tt_npu = ttnn.experimental.operations.primary.moreh_softmin(dev_x, dim) - - assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) - logger.info(out) - assert passing - - -@pytest.mark.parametrize( - "shape_dim", - (((32, 32), 1),), # single tile -) -@pytest.mark.parametrize( - "optional_output_tensor", - (True, False), -) -def test_softmin_backward_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() - shape, dim = shape_dim - torch.manual_seed(0) - - # cpu calculation - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - - y = F.softmin(x, dim) - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - y.backward(dy) - - # npu calculation - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - - if optional_output_tensor: - dev_dx = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - tt_npu = ttnn.experimental.operations.primary.moreh_softmin_backward(dev_y, dev_dy, dim, dev_dx) - else: - tt_npu = ttnn.experimental.operations.primary.moreh_softmin_backward(dev_y, dev_dy, dim) - - assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) - tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) - - rtol = atol = 0.05 - passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) - logger.info(out) - assert passing diff --git a/tests/ttnn/unit_tests/operations/test_moreh_logsoftmax.py b/tests/ttnn/unit_tests/operations/test_moreh_logsoftmax.py index 10ed7d6adbb..0a09b77ba22 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_logsoftmax.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_logsoftmax.py @@ -11,24 +11,25 @@ import torch.nn.functional as F from models.utility_functions import is_wormhole_b0 -from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import ( +from tests.ttnn.unit_tests.operations.test_utils import ( get_compute_kernel_options, compute_kernel_options, compute_kernel_ids, + to_npu, ) @pytest.mark.parametrize( "shape_dim", ( - ((50, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core + [[50, 32], 1], # single tile + [[3, 32, 32 * 5], 2], # mutiple tile with dim W + [[5, 6, 32, 32], 3], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 3], # multiple tiles per core + [[32, 32], 0], # single tile + [[3, 32 * 5, 32], 1], # mutiple tile with dim H + [[5, 6, 32, 32], 2], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 2], # multiple tiles per core ), ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) @@ -60,8 +61,8 @@ def test_logsoftmax_for_dim_hw(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim", ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), + [[2, 3, 32 * 4, 32 * 5], 3], + [[2, 3, 32 * 4, 32 * 5], 2], ), ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) @@ -99,10 +100,10 @@ def test_logsoftmax_large_algorithm_for_dim_hw(shape_dim, compute_kernel_options @pytest.mark.parametrize( "shape_dim", ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim + [[1, 1, 10, 15], 3], # single tile + [[1, 1, 10, 32 * 2 + 10], 3], # mutiple tile with dim + [[1, 1, 15, 10], 2], # single tile + [[1, 1, 32 * 2 + 10, 32], 2], # mutiple tile with dim ), ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) @@ -133,12 +134,12 @@ def test_logsoftmax_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_opti @pytest.mark.parametrize( "shape_dim", ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores + [[1, 15, 32, 32], 1], # single tile c + [[1, 15, 32 * 7, 32 * 5], 1], # mutiple cores + [[109, 15, 32, 32], 1], # mutiple tiles per cores + [[15, 1, 32, 32], 0], # single tile n + [[15, 1, 32 * 7, 32 * 5], 0], # mutiple cores + [[15, 109, 32 * 2, 32 * 2], 0], # mutiple tiles per cores ), ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) @@ -169,14 +170,14 @@ def test_logsoftmax_for_dim_nc(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim", ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 2), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 5, 32), 2), # multiple tiles per core + [[32, 32], 1], # single tile + [[3, 32, 32 * 2], 2], # mutiple tile with dim W + [[5, 6, 32, 32], 3], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 3], # multiple tiles per core + [[32, 32], 0], # single tile + [[3, 32 * 5, 32], 1], # mutiple tile with dim H + [[5, 6, 32, 32], 2], # multiple cores + [[10, 20, 32 * 5, 32], 2], # multiple tiles per core ), ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) @@ -210,8 +211,8 @@ def test_logsoftmax_backward_for_dim_hw(shape_dim, compute_kernel_options, devic @pytest.mark.parametrize( "shape_dim", ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), + [[2, 3, 32 * 4, 32 * 5], 3], + [[2, 3, 32 * 4, 32 * 5], 2], ), ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) @@ -252,10 +253,10 @@ def test_logsoftmax_backward_large_algorithm_for_dim_hw(shape_dim, compute_kerne @pytest.mark.parametrize( "shape_dim", ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim + [[1, 1, 10, 15], 3], # single tile + [[1, 1, 10, 32 * 2 + 10], 3], # mutiple tile with dim + [[1, 1, 15, 10], 2], # single tile + [[1, 1, 32 * 2 + 10, 32], 2], # mutiple tile with dim ), ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) @@ -290,12 +291,12 @@ def test_logsoftmax_backward_not_multiple_of_32_for_dim_hw(shape_dim, compute_ke @pytest.mark.parametrize( "shape_dim", ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores + [[1, 15, 32, 32], 1], # single tile c + [[1, 15, 32 * 7, 32 * 5], 1], # mutiple cores + [[109, 15, 32, 32], 1], # mutiple tiles per cores + [[15, 1, 32, 32], 0], # single tile n + [[15, 1, 32 * 7, 32 * 5], 0], # mutiple cores + [[15, 109, 32 * 2, 32 * 2], 0], # mutiple tiles per cores ), ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) @@ -328,7 +329,9 @@ def test_logsoftmax_backward_for_dim_nc(shape_dim, compute_kernel_options, devic @pytest.mark.parametrize( "shape_dim", - (((32, 32), 1),), # single tile + [ + [[32, 32], 1], + ], # single tile ) @pytest.mark.parametrize( "optional_output_tensor", @@ -365,7 +368,9 @@ def test_logsoftmax_optional_output_tensor(shape_dim, optional_output_tensor, de @pytest.mark.parametrize( "shape_dim", - (((32, 32), 1),), # single tile + [ + [[32, 32], 1], + ], # single tile ) @pytest.mark.parametrize( "optional_output_tensor", @@ -400,3 +405,96 @@ def test_logsoftmax_backward_optional_output_tensor(shape_dim, optional_output_t passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) logger.info(out) assert passing + + +@pytest.mark.parametrize( + "shape_dim_strategy", + ( + [[50, 32], 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.SMALL_W], + [[32, 32], 0, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.SMALL_H], + [[2, 3, 32 * 4, 32 * 5], 3, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_W], + [[2, 3, 32 * 4, 32 * 5], 2, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_H], + [[1, 15, 32, 32], 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_C], + ), +) +def test_logsoftmax_callback(shape_dim_strategy, device, use_program_cache): + shape, dim, strategy = shape_dim_strategy + torch.manual_seed(0) + + for i in range(2): + x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + 100 + i + tt_cpu = F.log_softmax(x, dim) + dev_x = ttnn.Tensor(x, ttnn.bfloat16).pad_to_tile(float("nan")).to(ttnn.TILE_LAYOUT).to(device) + tt_npu = ttnn.operations.moreh.logsoftmax(dev_x, dim, strategy=strategy) + if i == 0: + num_program_cache_entries = device.num_program_cache_entries() + assert num_program_cache_entries > 0 + else: + assert device.num_program_cache_entries() == num_program_cache_entries + torch_dummy = torch.randn([32, 32]) + tt_dummy = to_npu(torch_dummy, device) + + tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(shape) + assert list(tt_dev.shape.with_tile_padding()) == list(tt_cpu.shape) + tt_dev = tt_dev.to_torch().to(torch.bfloat16) + + rtol = atol = 0.1 + passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) + logger.debug(out) + assert passing + + +def logsoftmax_backward_step(dev_y, dev_dy, dim, strategy, device, num_program_cache_entries=None): + """ + Runs a single step of logsoftmax_backward and checks the program cache if needed. + """ + tt_npu = ttnn.operations.moreh.logsoftmax_backward(dev_y, dev_dy, dim, strategy=strategy) + + if num_program_cache_entries is not None: + assert device.num_program_cache_entries() == num_program_cache_entries + else: + num_program_cache_entries = device.num_program_cache_entries() + assert num_program_cache_entries > 0 + + torch_dummy = torch.randn([32, 32]) + tt_dummy = to_npu(torch_dummy, device) + + return tt_npu, num_program_cache_entries + + +@pytest.mark.parametrize( + "shape_dim_strategy", + ( + [[32, 32], 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.SMALL_W], + [[32, 32], 0, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.SMALL_H], + [[2, 3, 32 * 4, 32 * 5], 3, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_W], + [[2, 3, 32 * 4, 32 * 5], 2, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_H], + [[1, 15, 32, 32], 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_C], + ), +) +def test_logsoftmax_backward_callback(shape_dim_strategy, device, use_program_cache): + shape, dim, strategy = shape_dim_strategy + torch.manual_seed(0) + + num_program_cache_entries = None + for i in range(2): + x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) + y = F.log_softmax(x, dim) + dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) + + dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) + + y.backward(dy) + + tt_npu, num_program_cache_entries = logsoftmax_backward_step( + dev_y, dev_dy, dim, strategy, device, num_program_cache_entries + ) + + assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) + tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) + + rtol = atol = 0.5 + passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) + logger.debug(out) + assert passing diff --git a/tests/ttnn/unit_tests/operations/test_moreh_softmax.py b/tests/ttnn/unit_tests/operations/test_moreh_softmax.py index bd286abe891..58ad9c82e68 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_softmax.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_softmax.py @@ -10,30 +10,29 @@ from loguru import logger from models.utility_functions import is_wormhole_b0 -from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import ( +from tests.ttnn.unit_tests.operations.test_utils import ( get_compute_kernel_options, compute_kernel_options, compute_kernel_ids, + to_npu, ) @pytest.mark.parametrize( "shape_dim", - ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core - ), + [ + [[32, 32], 1], # single tile + [[3, 32, 32 * 5], 2], # mutiple tile with dim W + [[5, 6, 32, 32], 3], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 3], # multiple tiles per core + [[32, 32], 0], # single tile + [[3, 32 * 5, 32], 1], # mutiple tile with dim H + [[5, 6, 32, 32], 2], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 2], # multiple tiles per core + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmax_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim torch.manual_seed(0) @@ -57,15 +56,13 @@ def test_softmax_for_dim_hw(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), + [ + [[2, 3, 32 * 4, 32 * 5], 3], + [[2, 3, 32 * 4, 32 * 5], 2], + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmax_large_algorithm_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim torch.manual_seed(0) @@ -95,16 +92,15 @@ def test_softmax_large_algorithm_for_dim_hw(shape_dim, compute_kernel_options, d @pytest.mark.parametrize( "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), + [ + [[1, 1, 10, 15], 3], # single tile + [[1, 1, 10, 32 * 2 + 10], 3], # mutiple tile with dim + [[1, 1, 15, 10], 2], # single tile + [[1, 1, 32 * 2 + 10, 32], 2], # mutiple tile with dim + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmax_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -129,18 +125,17 @@ def test_softmax_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options @pytest.mark.parametrize( "shape_dim", - ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), + [ + [[1, 15, 32, 32], 1], # single tile c + [[1, 15, 32 * 7, 32 * 5], 1], # mutiple cores + [[109, 15, 32, 32], 1], # mutiple tiles per cores + [[15, 1, 32, 32], 0], # single tile n + [[15, 1, 32 * 7, 32 * 5], 0], # mutiple cores + [[15, 109, 32 * 2, 32 * 2], 0], # mutiple tiles per cores + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmax_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -165,20 +160,19 @@ def test_softmax_for_dim_nc(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim", - ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core - ), + [ + [[32, 32], 1], # single tile + [[3, 32, 32 * 5], 2], # mutiple tile with dim W + [[5, 6, 32, 32], 3], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 3], # multiple tiles per core + [[32, 32], 0], # single tile + [[3, 32 * 5, 32], 1], # mutiple tile with dim H + [[5, 6, 32, 32], 2], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 2], # multiple tiles per core + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmax_backward_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -206,14 +200,13 @@ def test_softmax_backward_for_dim_hw(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), + [ + [[2, 3, 32 * 4, 32 * 5], 3], + [[2, 3, 32 * 4, 32 * 5], 2], + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmax_backward_large_algorithmfor_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -249,16 +242,15 @@ def test_softmax_backward_large_algorithmfor_dim_hw(shape_dim, compute_kernel_op @pytest.mark.parametrize( "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), + [ + [[1, 1, 10, 15], 3], # single tile + [[1, 1, 10, 32 * 2 + 10], 3], # mutiple tile with dim + [[1, 1, 15, 10], 2], # single tile + [[1, 1, 32 * 2 + 10, 32], 2], # mutiple tile with dim + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmax_backward_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -287,18 +279,17 @@ def test_softmax_backward_not_multiple_of_32_for_dim_hw(shape_dim, compute_kerne @pytest.mark.parametrize( "shape_dim", - ( - ((15, 32, 32), 0), # single tile c - ((15, 32 * 7, 32 * 5), 0), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), + [ + [[15, 32, 32], 0], # single tile c + [[15, 32 * 7, 32 * 5], 0], # mutiple cores + [[109, 15, 32, 32], 1], # mutiple tiles per cores + [[15, 1, 32, 32], 0], # single tile n + [[15, 1, 32 * 7, 32 * 5], 0], # mutiple cores + [[15, 109, 32 * 2, 32 * 2], 0], # mutiple tiles per cores + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmax_backward_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -326,28 +317,33 @@ def test_softmax_backward_for_dim_nc(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim_strategy", - ( - ((32, 32), 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.SMALL_W), - ((32, 32), 0, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.SMALL_H), - ((32, 32), 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_W), - ((32, 32), 0, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_H), - ((1, 1, 32, 32), 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_C), - ((1, 1, 32, 32), 0, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_C), - ), + [ + [[32, 32], 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.SMALL_W], + [[32, 32], 0, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.SMALL_H], + [[32, 32], 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_W], + [[32, 32], 0, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_H], + [[1, 1, 32, 32], 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_C], + [[1, 1, 32, 32], 0, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_C], + ], ) -def test_softmax_callback(shape_dim_strategy, device): - device.enable_program_cache() - +def test_softmax_callback(shape_dim_strategy, device, use_program_cache): shape, dim, strategy = shape_dim_strategy torch.manual_seed(0) - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + for i in range(2): + x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) + dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - tt_cpu = torch.softmax(x, dim) - for i in range(2): + tt_cpu = torch.softmax(x, dim) tt_npu = ttnn.operations.moreh.softmax(dev_x, dim, strategy=strategy) + if i == 0: + num_program_cache_entries = device.num_program_cache_entries() + assert num_program_cache_entries > 0 + else: + assert device.num_program_cache_entries() == num_program_cache_entries + torch_dummy = torch.randn([32, 32]) + tt_dummy = to_npu(torch_dummy, device) assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) @@ -360,31 +356,36 @@ def test_softmax_callback(shape_dim_strategy, device): @pytest.mark.parametrize( "shape_dim_strategy", - ( - ((32, 32), 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.SMALL_W), - ((32, 32), 0, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.SMALL_H), - ((32, 32), 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_W), - ((32, 32), 0, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_H), - ((1, 1, 32, 32), 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_C), - ((1, 1, 32, 32), 0, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_C), - ), + [ + [[32, 32], 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.SMALL_W], + [[32, 32], 0, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.SMALL_H], + [[32, 32], 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_W], + [[32, 32], 0, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_H], + [[1, 1, 32, 32], 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_C], + ], ) -def test_softmax_backward_callback(shape_dim_strategy, device): - device.enable_program_cache() +def test_softmax_backward_callback(shape_dim_strategy, device, use_program_cache): shape, dim, strategy = shape_dim_strategy torch.manual_seed(0) - x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) + for i in range(2): + x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) - y = torch.softmax(x, dim) - dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) + y = torch.softmax(x, dim) + dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) - dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) + dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) - y.backward(dy) - for i in range(2): + y.backward(dy) tt_npu = ttnn.operations.moreh.softmax_backward(dev_y, dev_dy, dim, strategy=strategy) + if i == 0: + num_program_cache_entries = device.num_program_cache_entries() + assert num_program_cache_entries > 0 + else: + assert device.num_program_cache_entries() == num_program_cache_entries + torch_dummy = torch.randn([32, 32]) + tt_dummy = to_npu(torch_dummy, device) assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) @@ -397,15 +398,15 @@ def test_softmax_backward_callback(shape_dim_strategy, device): @pytest.mark.parametrize( "shape_dim", - (((32, 32), 1),), # single tile + [ + [[32, 32], 1], + ], # single tile ) @pytest.mark.parametrize( "optional_output_tensor", - (True, False), + [True, False], ) def test_softmax_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() - shape, dim = shape_dim torch.manual_seed(0) @@ -434,14 +435,15 @@ def test_softmax_optional_output_tensor(shape_dim, optional_output_tensor, devic @pytest.mark.parametrize( "shape_dim", - (((32, 32), 1),), # single tile + [ + [[32, 32], 1], + ], # single tile ) @pytest.mark.parametrize( "optional_output_tensor", - (True, False), + [True, False], ) def test_softmax_backward_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) diff --git a/tests/ttnn/unit_tests/operations/test_moreh_softmin.py b/tests/ttnn/unit_tests/operations/test_moreh_softmin.py index fb00e9d34c1..4e079c0c41c 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_softmin.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_softmin.py @@ -11,30 +11,29 @@ import torch.nn.functional as F from models.utility_functions import is_wormhole_b0 -from tests.tt_eager.python_api_testing.unit_testing.misc.test_utils import ( +from tests.ttnn.unit_tests.operations.test_utils import ( get_compute_kernel_options, compute_kernel_options, compute_kernel_ids, + to_npu, ) @pytest.mark.parametrize( "shape_dim", - ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core - ), + [ + [[32, 32], 1], # single tile + [[3, 32, 32 * 5], 2], # mutiple tile with dim W + [[5, 6, 32, 32], 3], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 3], # multiple tiles per core + [[32, 32], 0], # single tile + [[3, 32 * 5, 32], 1], # mutiple tile with dim H + [[5, 6, 32, 32], 2], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 2], # multiple tiles per core + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmin_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim torch.manual_seed(0) @@ -58,15 +57,13 @@ def test_softmin_for_dim_hw(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), + [ + [[2, 3, 32 * 4, 32 * 5], 3], + [[2, 3, 32 * 4, 32 * 5], 2], + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmin_large_algorithm_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() - shape, dim = shape_dim torch.manual_seed(0) @@ -95,16 +92,15 @@ def test_softmin_large_algorithm_for_dim_hw(shape_dim, compute_kernel_options, d @pytest.mark.parametrize( "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), + [ + [[1, 1, 10, 15], 3], # single tile + [[1, 1, 10, 32 * 2 + 10], 3], # mutiple tile with dim + [[1, 1, 15, 10], 2], # single tile + [[1, 1, 32 * 2 + 10, 32], 2], # mutiple tile with dim + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmin_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -129,18 +125,17 @@ def test_softmin_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options @pytest.mark.parametrize( "shape_dim", - ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), + [ + [[1, 15, 32, 32], 1], # single tile c + [[1, 15, 32 * 7, 32 * 5], 1], # mutiple cores + [[109, 15, 32, 32], 1], # mutiple tiles per cores + [[15, 1, 32, 32], 0], # single tile n + [[15, 1, 32 * 7, 32 * 5], 0], # mutiple cores + [[15, 109, 32 * 2, 32 * 2], 0], # mutiple tiles per cores + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmin_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -165,20 +160,19 @@ def test_softmin_for_dim_nc(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim", - ( - ((32, 32), 1), # single tile - ((3, 32, 32 * 5), 2), # mutiple tile with dim W - ((5, 6, 32, 32), 3), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 3), # multiple tiles per core - ((32, 32), 0), # single tile - ((3, 32 * 5, 32), 1), # mutiple tile with dim H - ((5, 6, 32, 32), 2), # multiple cores - ((10, 20, 32 * 3, 32 * 5), 2), # multiple tiles per core - ), + [ + [[32, 32], 1], # single tile + [[3, 32, 32 * 5], 2], # mutiple tile with dim W + [[5, 6, 32, 32], 3], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 3], # multiple tiles per core + [[32, 32], 0], # single tile + [[3, 32 * 5, 32], 1], # mutiple tile with dim H + [[5, 6, 32, 32], 2], # multiple cores + [[10, 20, 32 * 3, 32 * 5], 2], # multiple tiles per core + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmin_backward_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -206,14 +200,13 @@ def test_softmin_backward_for_dim_hw(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim", - ( - ((2, 3, 32 * 4, 32 * 5), 3), - ((2, 3, 32 * 4, 32 * 5), 2), - ), + [ + [[2, 3, 32 * 4, 32 * 5], 3], + [[2, 3, 32 * 4, 32 * 5], 2], + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmin_backward_large_algorithmfor_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -248,16 +241,15 @@ def test_softmin_backward_large_algorithmfor_dim_hw(shape_dim, compute_kernel_op @pytest.mark.parametrize( "shape_dim", - ( - ((1, 1, 10, 15), 3), # single tile - ((1, 1, 10, 32 * 2 + 10), 3), # mutiple tile with dim - ((1, 1, 15, 10), 2), # single tile - ((1, 1, 32 * 2 + 10, 32), 2), # mutiple tile with dim - ), + [ + [[1, 1, 10, 15], 3], # single tile + [[1, 1, 10, 32 * 2 + 10], 3], # mutiple tile with dim + [[1, 1, 15, 10], 2], # single tile + [[1, 1, 32 * 2 + 10, 32], 2], # mutiple tile with dim + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmin_backward_not_multiple_of_32_for_dim_hw(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -286,18 +278,17 @@ def test_softmin_backward_not_multiple_of_32_for_dim_hw(shape_dim, compute_kerne @pytest.mark.parametrize( "shape_dim", - ( - ((1, 15, 32, 32), 1), # single tile c - ((1, 15, 32 * 7, 32 * 5), 1), # mutiple cores - ((109, 15, 32, 32), 1), # mutiple tiles per cores - ((15, 1, 32, 32), 0), # single tile n - ((15, 1, 32 * 7, 32 * 5), 0), # mutiple cores - ((15, 109, 32 * 2, 32 * 2), 0), # mutiple tiles per cores - ), + [ + [[1, 15, 32, 32], 1], # single tile c + [[1, 15, 32 * 7, 32 * 5], 1], # mutiple cores + [[109, 15, 32, 32], 1], # mutiple tiles per cores + [[15, 1, 32, 32], 0], # single tile n + [[15, 1, 32 * 7, 32 * 5], 0], # mutiple cores + [[15, 109, 32 * 2, 32 * 2], 0], # mutiple tiles per cores + ], ) @pytest.mark.parametrize("compute_kernel_options", compute_kernel_options, ids=compute_kernel_ids) def test_softmin_backward_for_dim_nc(shape_dim, compute_kernel_options, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -325,15 +316,15 @@ def test_softmin_backward_for_dim_nc(shape_dim, compute_kernel_options, device): @pytest.mark.parametrize( "shape_dim", - (((32, 32), 1),), # single tile + [ + [[32, 32], 1], + ], # single tile ) @pytest.mark.parametrize( "optional_output_tensor", - (True, False), + [True, False], ) def test_softmin_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() - shape, dim = shape_dim torch.manual_seed(0) @@ -362,14 +353,15 @@ def test_softmin_optional_output_tensor(shape_dim, optional_output_tensor, devic @pytest.mark.parametrize( "shape_dim", - (((32, 32), 1),), # single tile + [ + [[32, 32], 1], + ], # single tile ) @pytest.mark.parametrize( "optional_output_tensor", - (True, False), + [True, False], ) def test_softmin_backward_optional_output_tensor(shape_dim, optional_output_tensor, device): - device.enable_program_cache() shape, dim = shape_dim torch.manual_seed(0) @@ -397,3 +389,96 @@ def test_softmin_backward_optional_output_tensor(shape_dim, optional_output_tens passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) logger.info(out) assert passing + + +@pytest.mark.parametrize( + "shape_dim_strategy", + [ + [[32, 32], 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.SMALL_W], + [[32, 32], 0, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.SMALL_H], + [[2, 3, 32 * 4, 32 * 5], 3, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_W], + [[2, 3, 32 * 4, 32 * 5], 2, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_H], + [[1, 15, 32, 32], 1, ttnn.operations.moreh.SoftmaxOpParallelizationStrategy.LARGE_C], + ], +) +def test_softmin_callback(shape_dim_strategy, device, use_program_cache): + shape, dim, strategy = shape_dim_strategy + torch.manual_seed(0) + + for i in range(2): + x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + i + + dev_x = ttnn.Tensor(x, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) + + tt_cpu = F.softmin(x, dim) + tt_npu = ttnn.operations.moreh.softmin(dev_x, dim, strategy=strategy) + if i == 0: + num_program_cache_entries = device.num_program_cache_entries() + assert num_program_cache_entries > 0 + else: + assert device.num_program_cache_entries() == num_program_cache_entries + torch_dummy = torch.randn([32, 32]) + tt_dummy = to_npu(torch_dummy, device) + + assert list(tt_npu.shape.with_tile_padding()) == list(tt_cpu.shape) + tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) + + rtol = atol = 0.05 + passing, out = comp_allclose_and_pcc(tt_cpu, tt_dev, rtol=rtol, atol=atol) + logger.debug(out) + assert passing + + +def softmin_backward_step(dev_y, dev_dy, dim, strategy, device, num_program_cache_entries=None): + """ + Runs a single step of softmin_backward and checks the program cache if needed. + """ + tt_npu = ttnn.operations.moreh.softmin_backward(dev_y, dev_dy, dim, strategy=strategy) + + if num_program_cache_entries is not None: + assert device.num_program_cache_entries() == num_program_cache_entries + else: + num_program_cache_entries = device.num_program_cache_entries() + assert num_program_cache_entries > 0 + + torch_dummy = torch.randn([32, 32]) + tt_dummy = to_npu(torch_dummy, device) + + return tt_npu, num_program_cache_entries + + +@pytest.mark.parametrize( + "shape_dim_strategy", + [ + [[32, 32], 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.SMALL_W], + [[32, 32], 0, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.SMALL_H], + [[2, 3, 32 * 4, 32 * 5], 3, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_W], + [[2, 3, 32 * 4, 32 * 5], 2, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_H], + [[1, 15, 32, 32], 1, ttnn.operations.moreh.SoftmaxBackwardOpParallelizationStrategy.LARGE_C], + ], +) +def test_softmin_backward_callback(shape_dim_strategy, device, use_program_cache): + shape, dim, strategy = shape_dim_strategy + torch.manual_seed(0) + + num_program_cache_entries = None + for i in range(2): + x = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16).requires_grad_(True) + y = F.softmin(x, dim) + dev_y = ttnn.Tensor(y, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) + + dy = torch.randint(low=0, high=4, size=shape).to(torch.bfloat16) + dev_dy = ttnn.Tensor(dy, ttnn.bfloat16).to(ttnn.TILE_LAYOUT).to(device) + + y.backward(dy) + tt_npu, num_program_cache_entries = softmin_backward_step( + dev_y, dev_dy, dim, strategy, device, num_program_cache_entries + ) + + assert list(tt_npu.shape.with_tile_padding()) == list(x.grad.shape) + tt_dev = tt_npu.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch().to(torch.bfloat16) + + rtol = atol = 0.05 + passing, out = comp_allclose_and_pcc(x.grad, tt_dev, rtol=rtol, atol=atol) + logger.debug(out) + assert passing diff --git a/ttnn/cpp/pybind11/__init__.cpp b/ttnn/cpp/pybind11/__init__.cpp index b83e380b121..39c2d977307 100644 --- a/ttnn/cpp/pybind11/__init__.cpp +++ b/ttnn/cpp/pybind11/__init__.cpp @@ -52,7 +52,6 @@ PYBIND11_MODULE(_ttnn, module) { ttnn::tensor::pytensor_module_types(m_tensor); ttnn::graph::py_graph_module_types(m_graph); - tt::operations::primary::py_module_types(m_primary_ops); ttnn::types::py_module_types(m_types); ttnn::activation::py_module_types(m_activation); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt index 86ec5e6a663..1be253ed036 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/CMakeLists.txt @@ -7,18 +7,6 @@ set(TT_DNN_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/moreh_clip_grad_norm/moreh_clip_grad_norm_step1/moreh_clip_grad_norm_step1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_clip_grad_norm/moreh_clip_grad_norm_step2/moreh_clip_grad_norm_step2.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_clip_grad_norm/moreh_clip_grad_norm_step3/moreh_clip_grad_norm_step3.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/moreh_softmax_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/softmax_w_small/softmax_w_small.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/softmax_h_small/softmax_h_small.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/softmax_w_large/softmax_w_large.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/softmax_h_large/softmax_h_large.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax/softmax_c_large/softmax_c_large.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax_backward/moreh_softmax_backward_op.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_h_impl/moreh_sum_h_impl.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_h_impl/moreh_int_sum_h_impl.cpp ${CMAKE_CURRENT_SOURCE_DIR}/moreh_sum/moreh_sum_w_impl/moreh_sum_w_impl.cpp diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_c_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_c_large.cpp deleted file mode 100644 index edbd39002aa..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_c_large.cpp +++ /dev/null @@ -1,129 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_ROW - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { -void MAIN { - constexpr auto cb_in0 = tt::CB::c_in0; - constexpr auto cb_out0 = tt::CB::c_out0; - constexpr auto cb_exps = tt::CB::c_intermed0; - constexpr auto cb_recipsumexps = tt::CB::c_intermed1; - constexpr auto cb_add = tt::CB::c_intermed2; - constexpr auto cb_max = tt::CB::c_intermed3; - constexpr auto cb_tmp = tt::CB::c_intermed4; - - constexpr uint32_t onetile = 1; - constexpr int dst0 = 0; - constexpr int dst1 = 1; - - uint32_t N = get_compile_time_arg_val(0); - uint32_t dim_size = get_compile_time_arg_val(1); - - binary_op_init_common(cb_in0, cb_exps); - - for (uint32_t n = 0; n < N; ++n) { - // find max - for (uint32_t i = 0; i < dim_size; ++i) { - if (i == 0) { - copy_tile_to_cb(cb_in0, cb_max); - } else { - cb_wait_front(cb_in0, onetile); - cb_wait_front(cb_max, onetile); - - tile_regs_acquire(); - - copy_tile_init_with_dt(cb_in0); - copy_tile(cb_in0, 0, dst0); - - copy_tile_init_with_dt(cb_max); - copy_tile(cb_max, 0, dst1); - - max_tile_init(); - max_tile(dst0, dst1); - tile_regs_commit(); - - cb_pop_front(cb_max, onetile); - cb_reserve_back(cb_max, onetile); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_max); - tile_regs_release(); - - cb_push_back(cb_max, onetile); - cb_pop_front(cb_in0, onetile); - } - } - - // compute exp(x - max(x)) - for (uint32_t i = 0; i < dim_size; ++i) { - #ifdef SOFTMAX - sub_tiles_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - exp_tile_to_cb(cb_tmp, cb_exps); - #else - sub_tiles_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - rexp_tile_to_cb(cb_tmp, cb_exps); - #endif - - if (i == 0) { - copy_tile_to_cb(cb_exps, cb_add); - } else { - add_tiles_to_cb(cb_add, cb_exps, cb_add); - } - } - -#ifdef LOG - // compute log(sum) - log_tile_to_cb(cb_add, cb_recipsumexps); -#else - // compute 1/sum(exp(x)) - recip_tile_to_cb(cb_add, cb_recipsumexps); -#endif - - // step 3, compute final result - cb_wait_front(cb_recipsumexps, onetile); - for (uint32_t i = 0; i < dim_size; ++i) { - #ifdef LOG - #ifdef SOFTMAX - // x - max - log(sum) - sub_tiles_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - sub_tiles_to_cb(cb_tmp, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #else - // -x + max - log(sum) - sub_tiles_to_cb(cb_max, cb_in0, cb_tmp, 0, 0, /*pop0=*/0, /*pop1=*/1); - - sub_tiles_to_cb(cb_tmp, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #endif - #else - #ifdef SOFTMAX - // exp(x - max) / sum - sub_tiles_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - exp_tile_to_cb(cb_tmp, cb_exps); - - mul_tiles_to_cb(cb_exps, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #else - // rexp(x - max) / sum - sub_tiles_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - rexp_tile_to_cb(cb_tmp, cb_exps); - - mul_tiles_to_cb(cb_exps, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #endif - #endif - } - - cb_pop_front(cb_recipsumexps, onetile); - cb_pop_front(cb_max, onetile); - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h.cpp deleted file mode 100644 index 2c576fdfeb3..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h.cpp +++ /dev/null @@ -1,168 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_COL - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { -void MAIN { - constexpr auto cb_in0 = tt::CB::c_in0; - constexpr auto cb_mask = tt::CB::c_in1; - constexpr auto cb_bcast_scaler = tt::CB::c_in2; - constexpr auto cb_out0 = tt::CB::c_out0; - constexpr auto cb_exps = tt::CB::c_intermed0; - constexpr auto cb_recipsumexps = tt::CB::c_intermed1; - constexpr auto cb_max = tt::CB::c_intermed2; - constexpr auto cb_x_m_max = tt::CB::c_intermed3; - constexpr auto cb_tmp = tt::CB::c_intermed4; - - constexpr int dst0 = 0; - constexpr int dst1 = 1; - constexpr uint32_t onetile = 1; - - binary_op_init_common(cb_in0, cb_bcast_scaler); - - uint32_t N = get_compile_time_arg_val(0); - uint32_t Ht = get_compile_time_arg_val(1); - - cb_wait_front(cb_mask, onetile); - cb_wait_front(cb_bcast_scaler, onetile); - - for (uint32_t n = 0; n < N; ++n) { - // find max value - if (Ht == 1) { - mask_tile_to_cb(cb_in0, cb_mask, cb_tmp, 0, 0, /*pop0=*/0, /*popm=*/0); - - reduce_tile_to_cb(cb_tmp, cb_bcast_scaler, cb_max, Ht, /*pop0=*/1, /*pop1=*/0); - } else { - reduce_tile_to_cb(cb_in0, cb_bcast_scaler, cb_max, Ht - 1, /*pop0=*/0, /*pop1=*/0); - - mask_tile_to_cb(cb_in0, cb_mask, cb_tmp, Ht - 1, 0, /*pop0=*/0, /*popm=*/0); - - cb_wait_front(cb_max, 1); - cb_wait_front(cb_tmp, 1); - - tile_regs_acquire(); - copy_tile_init_with_dt(cb_max); - copy_tile(cb_max, 0, dst0); - - constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB - reduce_init_delta_with_dt(cb_max, cb_tmp, cb_bcast_scaler); - reduce_tile(cb_tmp, cb_bcast_scaler, 0, bcast_scaler0, dst0); - reduce_revert_delta(cb_max); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_max); - tile_regs_release(); - - cb_pop_front(cb_max, 1); - cb_pop_front(cb_tmp, 1); - cb_push_back(cb_max, 1); - } - - // compute x - max(x) - cb_reserve_back(cb_x_m_max, Ht); - cb_wait_front(cb_in0, Ht); - cb_wait_front(cb_max, 1); - - for (uint32_t h = 0; h < Ht; ++h) { - tile_regs_acquire(); - sub_bcast_rows_init_short_with_dt(cb_in0, cb_max); - sub_tiles_bcast(cb_in0, cb_max, h, 0, dst0); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_x_m_max); - tile_regs_release(); - } - cb_pop_front(cb_max, 1); - cb_pop_front(cb_in0, Ht); - cb_push_back(cb_x_m_max, Ht); - - // compute exp(x - max(x)) - cb_reserve_back(cb_exps, Ht); - cb_wait_front(cb_x_m_max, Ht); - for (uint32_t h = 0; h < Ht; ++h) { - tile_regs_acquire(); - copy_tile_init_with_dt(cb_x_m_max); - copy_tile(cb_x_m_max, h, dst0); - -#ifndef SOFTMAX - negative_tile_init(); - negative_tile(dst0); -#endif - - exp_tile_init(); - exp_tile(dst0); - - if (h == Ht - 1) { - copy_tile_init_with_dt(cb_mask); - copy_tile(cb_mask, 0, dst1); - - mask_tile_init(); - mask_tile(dst0, dst1); - } - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_exps); - tile_regs_release(); - } - cb_push_back(cb_exps, Ht); - - -#ifdef LOG - // log(sum) - reduce_and_log_tile_to_cb(cb_exps, cb_bcast_scaler, cb_recipsumexps, Ht, /*pop0=*/Ht, /*pop1=*/0); -#else - // 1/sum - reduce_and_recip_tile_to_cb(cb_exps, cb_bcast_scaler, cb_recipsumexps, Ht, /*pop0=*/0, /*pop1=*/0); -#endif - - // compute final result - cb_reserve_back(cb_out0, Ht); - cb_wait_front(cb_x_m_max, Ht); - cb_wait_front(cb_recipsumexps, 1); -#ifndef LOG - cb_wait_front(cb_exps, Ht); -#endif - - for (uint32_t h = 0; h < Ht; h += onetile) { -#ifdef LOG - // x - max - log(sum) - tile_regs_acquire(); - sub_bcast_rows_init_short_with_dt(cb_x_m_max, cb_recipsumexps); - sub_tiles_bcast(cb_x_m_max, cb_recipsumexps, h, 0, dst0); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_out0); - tile_regs_release(); -#else - // exp(x - max) / psum - tile_regs_acquire(); - mul_bcast_rows_init_short_with_dt(cb_exps, cb_recipsumexps); - mul_tiles_bcast_rows(cb_exps, cb_recipsumexps, h, 0, dst0); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_out0); - tile_regs_release(); -#endif - } - - cb_pop_front(cb_recipsumexps, 1); - cb_pop_front(cb_x_m_max, Ht); - cb_push_back(cb_out0, Ht); -#ifndef LOG - cb_pop_front(cb_exps, Ht); -#endif - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h_large.cpp deleted file mode 100644 index c8a945e7958..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h_large.cpp +++ /dev/null @@ -1,173 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_COL - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { -void MAIN { - constexpr auto cb_in0 = tt::CB::c_in0; - constexpr auto cb_mask = tt::CB::c_in1; - constexpr auto cb_bcast_scaler = tt::CB::c_in2; - constexpr auto cb_out0 = tt::CB::c_out0; - constexpr auto cb_exps = tt::CB::c_intermed0; - constexpr auto cb_recipsumexps = tt::CB::c_intermed1; - constexpr auto cb_add = tt::CB::c_intermed2; - constexpr auto cb_max = tt::CB::c_intermed3; - constexpr auto cb_tmp = tt::CB::c_intermed4; - - binary_op_init_common(cb_in0, cb_bcast_scaler); - - constexpr uint32_t onetile = 1; - constexpr int dst0 = 0; - - uint32_t N = get_compile_time_arg_val(0); - uint32_t Ht = get_compile_time_arg_val(1); - - for (uint32_t n = 0; n < N; ++n) { - - // find max - if (Ht == 1) { - mask_tile_to_cb(cb_in0, cb_mask, cb_tmp, 0, 0, /*pop0=*/1, /*popm=*/0); - - reduce_tile_to_cb(cb_tmp, cb_bcast_scaler, cb_max, Ht, /*pop0=*/1, /*pop1=*/0); - } else { - cb_reserve_back(cb_max, onetile); - - tile_regs_acquire(); - reduce_init_delta_with_dt(cb_max, cb_in0, cb_bcast_scaler); - for (uint32_t h = 0; h < Ht - 1; ++h) { - cb_wait_front(cb_in0, onetile); - - constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB - reduce_tile(cb_in0, cb_bcast_scaler, 0, bcast_scaler0, dst0); - - cb_pop_front(cb_in0, onetile); - } - reduce_revert_delta(cb_max); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_max); - tile_regs_release(); - - cb_push_back(cb_max, onetile); - - - mask_tile_to_cb(cb_in0, cb_mask, cb_tmp, 0, 0, /*pop0=*/1, /*popm=*/0); - - - cb_wait_front(cb_max, onetile); - cb_wait_front(cb_tmp, onetile); - - tile_regs_acquire(); - copy_tile_init_with_dt(cb_max); - copy_tile(cb_max, 0, dst0); - - constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB - reduce_init_delta_with_dt(cb_max, cb_tmp, cb_bcast_scaler); - reduce_tile(cb_tmp, cb_bcast_scaler, 0, bcast_scaler0, dst0); - reduce_revert_delta(cb_max); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_max); - tile_regs_release(); - - cb_pop_front(cb_max, onetile); - cb_pop_front(cb_tmp, onetile); - cb_push_back(cb_max, onetile); - } - - for (uint32_t h = 0; h < Ht; h += onetile) { - // compute exp(x - max(x)) - if (h == Ht - 1) { - #ifdef SOFTMAX - sub_tiles_bcast_rows_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - exp_tile_and_mask_tile_to_cb( - cb_tmp, - cb_mask, - cb_exps, - /*itile=*/0, - /*mtile=*/0, - /*pop=*/1, - /*popm=*/0); - #else - rexp_tile_and_mask_tile_to_cb( - cb_in0, - cb_mask, - cb_exps, - /*itile=*/0, - /*mtile=*/0, - /*pop=*/1, - /*popm=*/0); - #endif - } else { - #ifdef SOFTMAX - sub_tiles_bcast_rows_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - exp_tile_to_cb(cb_tmp, cb_exps); - #else - sub_tiles_bcast_rows_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - rexp_tile_to_cb(cb_tmp, cb_exps); - #endif - } - - if (h == 0) { - copy_tile_to_cb(cb_exps, cb_add); - } else { - add_tiles_to_cb(cb_add, cb_exps, cb_add); - } - } - -#ifdef LOG - // compute log(sum) - reduce_and_log_tile_to_cb(cb_add, cb_bcast_scaler, cb_recipsumexps, /*size=*/1, /*pop0=*/1, /*pop1=*/0); -#else - // compute 1/sum(exp(x)) - reduce_and_recip_tile_to_cb(cb_add, cb_bcast_scaler, cb_recipsumexps, /*size=*/1, /*pop0=*/1, /*pop1=*/0); -#endif - - // step 3, compute final result - for (uint32_t h = 0; h < Ht; h += onetile) { - #ifdef LOG - #ifdef SOFTMAX - // x - max - log(sum) - sub_tiles_bcast_rows_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - sub_tiles_bcast_rows_to_cb(cb_tmp, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #else - // -x + max - log(sum) - // logsoftmin not implemented - #endif - #else - #ifdef SOFTMAX - // exp(x - max) / sum - sub_tiles_bcast_rows_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - exp_tile_to_cb(cb_tmp, cb_exps); - - mul_tiles_bcast_rows_to_cb(cb_exps, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #else - // rexp(x - max) / sum - sub_tiles_bcast_rows_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - rexp_tile_to_cb(cb_tmp, cb_exps); - - mul_tiles_bcast_rows_to_cb(cb_exps, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #endif - #endif - } - - cb_pop_front(cb_recipsumexps, onetile); - cb_pop_front(cb_max, onetile); - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w.cpp deleted file mode 100644 index bd4b1753e3e..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w.cpp +++ /dev/null @@ -1,173 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_ROW - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { - -void MAIN { - constexpr auto cb_in0 = tt::CB::c_in0; - constexpr auto cb_mask = tt::CB::c_in1; - constexpr auto cb_bcast_scaler = tt::CB::c_in2; - constexpr auto cb_out0 = tt::CB::c_out0; - constexpr auto cb_exps = tt::CB::c_intermed0; - constexpr auto cb_recipsumexps = tt::CB::c_intermed1; - constexpr auto cb_max = tt::CB::c_intermed2; - constexpr auto cb_x_m_max = tt::CB::c_intermed3; - constexpr auto cb_tmp = tt::CB::c_intermed4; - - binary_op_init_common(cb_in0, cb_bcast_scaler); - - constexpr int dst0 = 0; - constexpr int dst1 = 1; - constexpr uint32_t onetile = 1; - - uint32_t N = get_compile_time_arg_val(0); - uint32_t Wt = get_compile_time_arg_val(1); - - cb_wait_front(cb_mask, onetile); - cb_wait_front(cb_bcast_scaler, onetile); - - for (uint32_t n = 0; n < N; ++n) { - - // find max value - if (Wt == 1) { - mask_tile_to_cb(cb_in0, cb_mask, cb_tmp, 0, 0, /*pop0=*/0, /*popm=*/0); - - reduce_tile_to_cb(cb_tmp, cb_bcast_scaler, cb_max, Wt, /*pop0=*/1, /*pop1=*/0); - } else { - reduce_tile_to_cb(cb_in0, cb_bcast_scaler, cb_max, Wt - 1, /*pop0=*/0, /*pop1=*/0); - - mask_tile_to_cb(cb_in0, cb_mask, cb_tmp, Wt - 1, 0, /*pop0=*/0, /*popm=*/0); - - cb_wait_front(cb_max, 1); - cb_wait_front(cb_tmp, 1); - - tile_regs_acquire(); - copy_tile_init_with_dt(cb_max); - copy_tile(cb_max, 0, dst0); - - constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB - reduce_init_delta_with_dt(cb_max, cb_tmp, cb_bcast_scaler); - reduce_tile(cb_tmp, cb_bcast_scaler, 0, bcast_scaler0, dst0); - reduce_revert_delta(cb_max); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_max); - tile_regs_release(); - - cb_pop_front(cb_max, 1); - cb_pop_front(cb_tmp, 1); - cb_push_back(cb_max, 1); - } - - - // compute x - max(x) - cb_reserve_back(cb_x_m_max, Wt); - cb_wait_front(cb_in0, Wt); - cb_wait_front(cb_max, 1); - - for (uint32_t w = 0; w < Wt; ++w) { - tile_regs_acquire(); - sub_bcast_cols_init_short_with_dt(cb_in0, cb_max); - sub_tiles_bcast(cb_in0, cb_max, w, 0, dst0); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_x_m_max); - tile_regs_release(); - } - cb_pop_front(cb_max, 1); - cb_pop_front(cb_in0, Wt); - cb_push_back(cb_x_m_max, Wt); - - - // compute exp(x - max(x)) - cb_reserve_back(cb_exps, Wt); - cb_wait_front(cb_x_m_max, Wt); - for (uint32_t w = 0; w < Wt; ++w) { - tile_regs_acquire(); - copy_tile_init_with_dt(cb_x_m_max); - copy_tile(cb_x_m_max, w, dst0); - -#ifndef SOFTMAX - negative_tile_init(); - negative_tile(dst0); -#endif - - exp_tile_init(); - exp_tile(dst0); - - if (w == Wt - 1){ - copy_tile_init_with_dt(cb_mask); - copy_tile(cb_mask, 0, dst1); - - mask_tile_init(); - mask_tile(dst0, dst1); - } - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_exps); - tile_regs_release(); - } - cb_push_back(cb_exps, Wt); - - -#ifdef LOG - // log(sum) - reduce_and_log_tile_to_cb(cb_exps, cb_bcast_scaler, cb_recipsumexps, Wt, /*pop0=*/Wt, /*pop1=*/0); -#else - // 1/sum - reduce_and_recip_tile_to_cb(cb_exps, cb_bcast_scaler, cb_recipsumexps, Wt, /*pop0=*/0, /*pop1=*/0); -#endif - - // compute final result - cb_reserve_back(cb_out0, Wt); - cb_wait_front(cb_x_m_max, Wt); - cb_wait_front(cb_recipsumexps, 1); - -#ifndef LOG - cb_wait_front(cb_exps, Wt); -#endif - - for (uint32_t w = 0; w < Wt; w += onetile) { -#ifdef LOG - // x - max - log(sum) - tile_regs_acquire(); - sub_bcast_cols_init_short_with_dt(cb_x_m_max, cb_recipsumexps); - sub_tiles_bcast(cb_x_m_max, cb_recipsumexps, w, 0, dst0); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_out0); - tile_regs_release(); -#else - // exp(x - max) / psum - tile_regs_acquire(); - mul_bcast_cols_init_short_with_dt(cb_exps, cb_recipsumexps); - mul_tiles_bcast_cols(cb_exps, cb_recipsumexps, w, 0, dst0); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_out0); - tile_regs_release(); -#endif - } - - cb_pop_front(cb_recipsumexps, 1); - cb_pop_front(cb_x_m_max, Wt); - cb_push_back(cb_out0, Wt); -#ifndef LOG - cb_pop_front(cb_exps, Wt); -#endif - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w_large.cpp deleted file mode 100644 index 1bb14fbaddb..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w_large.cpp +++ /dev/null @@ -1,177 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_ROW - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { -void MAIN { - constexpr auto cb_in0 = tt::CB::c_in0; - constexpr auto cb_mask = tt::CB::c_in1; - constexpr auto cb_bcast_scaler = tt::CB::c_in2; - constexpr auto cb_out0 = tt::CB::c_out0; - constexpr auto cb_exps = tt::CB::c_intermed0; - constexpr auto cb_recipsumexps = tt::CB::c_intermed1; - constexpr auto cb_add = tt::CB::c_intermed2; - constexpr auto cb_max = tt::CB::c_intermed3; - constexpr auto cb_tmp = tt::CB::c_intermed4; - - binary_op_init_common(cb_in0, cb_bcast_scaler); - - constexpr uint32_t onetile = 1; - constexpr int dst0 = 0; - - uint32_t N = get_compile_time_arg_val(0); - uint32_t Wt = get_compile_time_arg_val(1); - - for (uint32_t n = 0; n < N; ++n) { - - // find max - if (Wt == 1) { - mask_tile_to_cb(cb_in0, cb_mask, cb_tmp, 0, 0, /*pop0=*/1, /*popm=*/0); - - reduce_tile_to_cb( - cb_tmp, cb_bcast_scaler, cb_max, Wt, /*pop0=*/1, /*pop1=*/0); - } else { - cb_reserve_back(cb_max, onetile); - - tile_regs_acquire(); - reduce_init_delta_with_dt(cb_max, cb_in0, cb_bcast_scaler); - for (uint32_t w = 0; w < Wt - 1; ++w) { - cb_wait_front(cb_in0, onetile); - - constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB - reduce_tile(cb_in0, cb_bcast_scaler, 0, bcast_scaler0, dst0); - - cb_pop_front(cb_in0, onetile); - } - reduce_revert_delta(cb_max); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_max); - tile_regs_release(); - - cb_push_back(cb_max, onetile); - - - mask_tile_to_cb(cb_in0, cb_mask, cb_tmp, 0, 0, /*pop0=*/1, /*popm=*/0); - - - cb_wait_front(cb_max, onetile); - cb_wait_front(cb_tmp, onetile); - - tile_regs_acquire(); - copy_tile_init_with_dt(cb_max); - copy_tile(cb_max, 0, dst0); - - constexpr uint32_t bcast_scaler0 = 0; // 0th index from bcast_scaler CB - reduce_init_delta_with_dt(cb_max, cb_tmp, cb_bcast_scaler); - reduce_tile(cb_tmp, cb_bcast_scaler, 0, bcast_scaler0, dst0); - reduce_revert_delta(cb_max); - tile_regs_commit(); - - tile_regs_wait(); - pack_tile_with_dt(dst0, cb_max); - tile_regs_release(); - - cb_pop_front(cb_max, onetile); - cb_pop_front(cb_tmp, onetile); - cb_push_back(cb_max, onetile); - } - - // step 1 - for (uint32_t w = 0; w < Wt; ++w) { - // compute exp(x) - if (w == Wt - 1) { - #ifdef SOFTMAX - sub_tiles_bcast_cols_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - exp_tile_and_mask_tile_to_cb( - cb_tmp, - cb_mask, - cb_exps, - /*itile=*/0, - /*mtile=*/0, - /*pop=*/1, - /*popm=*/0); - #else - rexp_tile_and_mask_tile_to_cb( - cb_in0, - cb_mask, - cb_exps, - /*itile=*/0, - /*mtile=*/0, - /*pop=*/1, - /*popm=*/0); - #endif - } else { - #ifdef SOFTMAX - sub_tiles_bcast_cols_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - exp_tile_to_cb(cb_tmp, cb_exps); - #else - sub_tiles_bcast_cols_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - rexp_tile_to_cb(cb_tmp, cb_exps); - #endif - } - - if (w == 0) { - copy_tile_to_cb(cb_exps, cb_add); - } else { - add_tiles_to_cb(cb_add, cb_exps, cb_add); - } - } - -#ifdef LOG - // compute log(sum) - reduce_and_log_tile_to_cb( - cb_add, cb_bcast_scaler, cb_recipsumexps, /*size=*/1, /*pop0=*/1, /*pop1=*/0); -#else - // compute 1/sum(exp(x)) - reduce_and_recip_tile_to_cb( - cb_add, cb_bcast_scaler, cb_recipsumexps, /*size=*/1, /*pop0=*/1, /*pop1=*/0); -#endif - - // step 3, compute final result - for (uint32_t w = 0; w < Wt; w += onetile) { - #ifdef LOG - #ifdef SOFTMAX - // x - max - log(sum) - sub_tiles_bcast_cols_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - sub_tiles_bcast_cols_to_cb(cb_tmp, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #else - // -x + max - log(sum) - // logsoftmin not implemented - #endif - #else - #ifdef SOFTMAX - // exp(x - max) / sum - sub_tiles_bcast_cols_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - exp_tile_to_cb(cb_tmp, cb_exps); - - mul_tiles_bcast_cols_to_cb(cb_exps, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #else - // rexp(x - max) / sum - sub_tiles_bcast_cols_to_cb(cb_in0, cb_max, cb_tmp, 0, 0, /*pop0=*/1, /*pop1=*/0); - - rexp_tile_to_cb(cb_tmp, cb_exps); - - mul_tiles_bcast_cols_to_cb(cb_exps, cb_recipsumexps, cb_out0, 0, 0, /*pop0=*/1, /*pop1=*/0); - #endif - #endif - } - - cb_pop_front(cb_recipsumexps, onetile); - cb_pop_front(cb_max, onetile); - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_c_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_c_large.cpp deleted file mode 100644 index 2aa6412b95c..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_c_large.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t src_addr = get_arg_val(0); - uint32_t num_tiles = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t outer_stride = get_arg_val(3); - uint32_t inner_size = get_arg_val(4); - uint32_t dim_size = get_arg_val(5); - - constexpr auto cb_in = tt::CB::c_in0; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t src_in_tile_bytes = get_tile_size(cb_in); - const DataFormat src_in_data_format = get_dataformat(cb_in); - - constexpr bool in_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast src_in = { - .bank_base_address = src_addr, .page_size = src_in_tile_bytes, .data_format = src_in_data_format}; - - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < num_tiles; i += onetile) { - uint32_t outer_idx = curr_tile / (inner_size); - uint32_t inner_idx = curr_tile % inner_size; - uint32_t tile_idx = outer_idx * outer_stride + inner_idx; - - uint32_t dim_stride = inner_size; - for (uint32_t d = 0; d < dim_size; d++) { - cb_reserve_back(cb_in, onetile); - l1_write_addr_in = get_write_ptr(cb_in); - noc_async_read_tile(tile_idx, src_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_in, onetile); - tile_idx += dim_stride; - } - - tile_idx = outer_idx * outer_stride + inner_idx; - for (uint32_t d = 0; d < dim_size; d++) { - cb_reserve_back(cb_in, onetile); - l1_write_addr_in = get_write_ptr(cb_in); - noc_async_read_tile(tile_idx, src_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_in, onetile); - tile_idx += dim_stride; - } - - tile_idx = outer_idx * outer_stride + inner_idx; - for (uint32_t d = 0; d < dim_size; d++) { - cb_reserve_back(cb_in, onetile); - l1_write_addr_in = get_write_ptr(cb_in); - noc_async_read_tile(tile_idx, src_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_in, onetile); - tile_idx += dim_stride; - } - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h.cpp deleted file mode 100644 index 2c2d074bc0c..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" - -void kernel_main() { - uint32_t src_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Ht = get_arg_val(3); - uint32_t Wt = get_arg_val(4); - uint32_t scaler = get_arg_val(5); - uint32_t mask_h = get_arg_val(6); - - constexpr auto cb_in = tt::CB::c_in0; - constexpr auto cb_mask = tt::CB::c_in1; - constexpr auto cb_scaler = tt::CB::c_in2; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t src_in_tile_bytes = get_tile_size(cb_in); - const DataFormat src_in_data_format = get_dataformat(cb_in); - - constexpr bool in_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast src_in = { - .bank_base_address = src_addr, .page_size = src_in_tile_bytes, .data_format = src_in_data_format}; - - // TODO(AP): cleanup, probably with named args/param pack/reflection. - generate_bcast_scaler(cb_scaler, scaler); - generate_mask_h(cb_mask, mask_h); - - // read ublocks from src0 to CB0, then push ublocks to compute (unpacker) - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i += onetile) { - uint32_t w_idx = curr_tile % Wt; - uint32_t nc_idx = curr_tile / Wt; - uint32_t tile_idx = nc_idx * Ht * Wt + w_idx; - cb_reserve_back(cb_in, Ht); - l1_write_addr_in = get_write_ptr(cb_in); - for (uint32_t h = 0; h < Ht; h++) { - noc_async_read_tile(tile_idx, src_in, l1_write_addr_in); - l1_write_addr_in += src_in_tile_bytes; - tile_idx += Wt; - } - noc_async_read_barrier(); - cb_push_back(cb_in, Ht); - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h_large.cpp deleted file mode 100644 index 2cd5a54e485..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h_large.cpp +++ /dev/null @@ -1,77 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" - -void kernel_main() { - uint32_t src_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Ht = get_arg_val(3); - uint32_t Wt = get_arg_val(4); - uint32_t scaler = get_arg_val(5); - uint32_t mask_h = get_arg_val(6); - - constexpr auto cb_in = tt::CB::c_in0; - constexpr auto cb_mask = tt::CB::c_in1; - constexpr auto cb_scaler = tt::CB::c_in2; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t src_in_tile_bytes = get_tile_size(cb_in); - const DataFormat src_in_data_format = get_dataformat(cb_in); - - constexpr bool in_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast src_in = { - .bank_base_address = src_addr, .page_size = src_in_tile_bytes, .data_format = src_in_data_format}; - - // TODO(AP): cleanup, probably with named args/param pack/reflection. - generate_bcast_scaler(cb_scaler, scaler); - generate_mask_h(cb_mask, mask_h); - - // read ublocks from src0 to CB0, then push ublocks to compute (unpacker) - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i += onetile) { - uint32_t w_idx = curr_tile % Wt; - uint32_t nc_idx = curr_tile / Wt; - uint32_t tile_idx = nc_idx * Ht * Wt + w_idx; - for (uint32_t h = 0; h < Ht; h++) { - cb_reserve_back(cb_in, onetile); - l1_write_addr_in = get_write_ptr(cb_in); - noc_async_read_tile(tile_idx, src_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_in, onetile); - tile_idx += Wt; - } - - w_idx = curr_tile % Wt; - nc_idx = curr_tile / Wt; - tile_idx = nc_idx * Ht * Wt + w_idx; - for (uint32_t h = 0; h < Ht; h++) { - cb_reserve_back(cb_in, onetile); - l1_write_addr_in = get_write_ptr(cb_in); - noc_async_read_tile(tile_idx, src_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_in, onetile); - tile_idx += Wt; - } - - w_idx = curr_tile % Wt; - nc_idx = curr_tile / Wt; - tile_idx = nc_idx * Ht * Wt + w_idx; - for (uint32_t h = 0; h < Ht; h++) { - cb_reserve_back(cb_in, onetile); - l1_write_addr_in = get_write_ptr(cb_in); - noc_async_read_tile(tile_idx, src_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_in, onetile); - tile_idx += Wt; - } - - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w.cpp deleted file mode 100644 index 3ea949d6d61..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w.cpp +++ /dev/null @@ -1,48 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" - -void kernel_main() { - uint32_t src_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Wt = get_arg_val(3); - uint32_t scaler = get_arg_val(4); - uint32_t mask_w = get_arg_val(5); - - constexpr auto cb_in = tt::CB::c_in0; - constexpr auto cb_mask = tt::CB::c_in1; - constexpr auto cb_scaler = tt::CB::c_in2; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t src_in_tile_bytes = get_tile_size(cb_in); - const DataFormat src_in_data_format = get_dataformat(cb_in); - - constexpr bool in_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast src_in = { - .bank_base_address = src_addr, .page_size = src_in_tile_bytes, .data_format = src_in_data_format}; - - // TODO(AP): cleanup, probably with named args/param pack/reflection. - generate_bcast_scaler(cb_scaler, scaler); - generate_mask_w(cb_mask, mask_w); - - // read ublocks from src0 to CB0, then push ublocks to compute (unpacker) - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i += onetile) { - cb_reserve_back(cb_in, Wt); - l1_write_addr_in = get_write_ptr(cb_in); - for (uint32_t w = 0; w < Wt; w++) { - noc_async_read_tile(curr_tile, src_in, l1_write_addr_in); - l1_write_addr_in += src_in_tile_bytes; - curr_tile++; - } - noc_async_read_barrier(); - cb_push_back(cb_in, Wt); - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w_large.cpp deleted file mode 100644 index 188d8ce6e1d..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w_large.cpp +++ /dev/null @@ -1,68 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" - -void kernel_main() { - uint32_t src_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Wt = get_arg_val(3); - uint32_t scaler = get_arg_val(4); - uint32_t mask_w = get_arg_val(5); - - constexpr auto cb_in = tt::CB::c_in0; - constexpr auto cb_mask = tt::CB::c_in1; - constexpr auto cb_scaler = tt::CB::c_in2; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t src_in_tile_bytes = get_tile_size(cb_in); - const DataFormat src_in_data_format = get_dataformat(cb_in); - - constexpr bool in_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast src_in = { - .bank_base_address = src_addr, .page_size = src_in_tile_bytes, .data_format = src_in_data_format}; - - // TODO(AP): cleanup, probably with named args/param pack/reflection. - generate_bcast_scaler(cb_scaler, scaler); - generate_mask_w(cb_mask, mask_w); - - // read ublocks from src0 to CB0, then push ublocks to compute (unpacker) - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i += onetile) { - uint32_t curr_offset_i = curr_tile; - for (uint32_t w = 0; w < Wt; w++) { - cb_reserve_back(cb_in, onetile); - l1_write_addr_in = get_write_ptr(cb_in); - noc_async_read_tile(curr_tile, src_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_in, onetile); - curr_tile++; - } - - curr_tile = curr_offset_i; - for (uint32_t w = 0; w < Wt; w++) { - cb_reserve_back(cb_in, onetile); - l1_write_addr_in = get_write_ptr(cb_in); - noc_async_read_tile(curr_tile, src_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_in, onetile); - curr_tile++; - } - - curr_tile = curr_offset_i; - for (uint32_t w = 0; w < Wt; w++) { - cb_reserve_back(cb_in, onetile); - l1_write_addr_in = get_write_ptr(cb_in); - noc_async_read_tile(curr_tile, src_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_in, onetile); - curr_tile++; - } - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_c_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_c_large.cpp deleted file mode 100644 index fcec659ae37..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_c_large.cpp +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t num_tiles = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t outer_stride = get_arg_val(3); - uint32_t inner_size = get_arg_val(4); - uint32_t dim_size = get_arg_val(5); - - constexpr auto cb_out = tt::CB::c_out0; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t dst_out_tile_bytes = get_tile_size(cb_out); - const DataFormat dst_out_data_format = get_dataformat(cb_out); - - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast dst_out = { - .bank_base_address = dst_addr, .page_size = dst_out_tile_bytes, .data_format = dst_out_data_format}; - - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < num_tiles; i += onetile) { - uint32_t outer_idx = curr_tile / (inner_size); - uint32_t inner_idx = curr_tile % inner_size; - uint32_t tile_idx = outer_idx * outer_stride + inner_idx; - - uint32_t dim_stride = inner_size; - for (uint32_t d = 0; d < dim_size; d++) { - cb_wait_front(cb_out, onetile); - uint32_t l1_read_addr = get_read_ptr(cb_out); - noc_async_write_tile(tile_idx, dst_out, l1_read_addr); - noc_async_write_barrier(); - cb_pop_front(cb_out, onetile); - tile_idx += dim_stride; - } - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h.cpp deleted file mode 100644 index 72a281e2dfd..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Ht = get_arg_val(3); - uint32_t Wt = get_arg_val(4); - - constexpr uint32_t cb_id_out = tt::CB::c_out0; - constexpr uint32_t onetile = 1; - uint32_t tile_bytes = get_tile_size(cb_id_out); - - const DataFormat data_format = get_dataformat(cb_id_out); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast s = { - .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; - - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i++) { - uint32_t w_idx = curr_tile % Wt; - uint32_t nc_idx = curr_tile / Wt; - uint32_t tile_idx = nc_idx * Ht * Wt + w_idx; - - - cb_wait_front(cb_id_out, Ht); - auto l1_read_addr = get_read_ptr(cb_id_out); - for (uint32_t h = 0; h < Ht; h++) { - noc_async_write_tile(tile_idx, s, l1_read_addr); - l1_read_addr += tile_bytes; - tile_idx += Wt; - } - noc_async_write_barrier(); - cb_pop_front(cb_id_out, Ht); - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h_large.cpp deleted file mode 100644 index bbad6708df3..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h_large.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Ht = get_arg_val(3); - uint32_t Wt = get_arg_val(4); - - constexpr uint32_t cb_id_out = tt::CB::c_out0; - constexpr uint32_t onetile = 1; - uint32_t tile_bytes = get_tile_size(cb_id_out); - - const DataFormat data_format = get_dataformat(cb_id_out); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast s = { - .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; - - uint32_t blk = 1; - - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i++) { - uint32_t w_idx = curr_tile % Wt; - uint32_t nc_idx = curr_tile / Wt; - uint32_t tile_idx = nc_idx * Ht * Wt + w_idx; - for (uint32_t h = 0; h < Ht; h++) { - cb_wait_front(cb_id_out, blk); - uint32_t l1_read_addr = get_read_ptr(cb_id_out); - noc_async_write_tile(tile_idx, s, l1_read_addr); - noc_async_write_barrier(); - cb_pop_front(cb_id_out, blk); - tile_idx += Wt; - } - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w.cpp deleted file mode 100644 index 45cf65005e2..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w.cpp +++ /dev/null @@ -1,35 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Wt = get_arg_val(3); - - constexpr uint32_t cb_id_out = tt::CB::c_out0; - constexpr uint32_t onetile = 1; - uint32_t tile_bytes = get_tile_size(cb_id_out); - - const DataFormat data_format = get_dataformat(cb_id_out); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast s = { - .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; - - uint32_t tile_id = tile_offset; - for (uint32_t i = 0; i < N; i++) { - cb_wait_front(cb_id_out, Wt); - auto l1_read_addr = get_read_ptr(cb_id_out); - for (uint32_t w = 0; w < Wt; w++) { - noc_async_write_tile(tile_id, s, l1_read_addr); - l1_read_addr += tile_bytes; - tile_id++; - } - noc_async_write_barrier(); - cb_pop_front(cb_id_out, Wt); - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w_large.cpp deleted file mode 100644 index 883562f2c8f..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w_large.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Wt = get_arg_val(3); - - constexpr uint32_t cb_id_out = tt::CB::c_out0; - constexpr uint32_t onetile = 1; - uint32_t tile_bytes = get_tile_size(cb_id_out); - - const DataFormat data_format = get_dataformat(cb_id_out); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast s = { - .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; - - uint32_t blk = 1; - - uint32_t tile_id = tile_offset; - for (uint32_t i = 0; i < N; i++) { - for (uint32_t w = 0; w < Wt; w++) { - cb_wait_front(cb_id_out, blk); - uint32_t l1_read_addr = get_read_ptr(cb_id_out); - noc_async_write_tile(tile_id, s, l1_read_addr); - noc_async_write_barrier(); - cb_pop_front(cb_id_out, blk); - tile_id++; - } - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.cpp deleted file mode 100644 index f81696a2d1c..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.cpp +++ /dev/null @@ -1,263 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" - -#include "ttnn/run_operation.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/host_api.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -void MorehSoftmax::validate_with_output_tensors( - const std::vector& input_tensors, const std::vector>& output_tensors) const { - // validate input tensor - auto& input_tensor = input_tensors.at(0); - TT_ASSERT(input_tensor.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!"); - TT_ASSERT(input_tensor.buffer() != nullptr, "Operands to softmax need to be allocated in buffers on device!"); - TT_ASSERT((input_tensor.get_layout() == Layout::TILE), "Inputs to softmax must be tilized"); - TT_ASSERT(input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::BFLOAT8_B); - - // validate parameters - auto rank = input_tensor.get_legacy_shape().rank(); - - TT_ASSERT( - this->dim >= 0 && this->dim < rank, - "dim {} should be less than output tensor rank {}", this->dim, rank); - - if (output_tensors.empty() || !output_tensors.at(0).has_value()) { - // If the user decided to not use any optional output tensors, then this would be empty or would be a nullptr. - return; - } - TT_ASSERT(input_tensors.size() == 1, "Must have 1 input tensors"); - TT_ASSERT(output_tensors.size() == 1, "Must have 1 output tensors"); -} - -std::vector MorehSoftmax::compute_output_shapes(const std::vector& input_tensors) const { - return {input_tensors.at(0).get_legacy_shape()}; -} - -std::vector MorehSoftmax::create_output_tensors( - const std::vector& input_tensors, const std::vector>& output_tensors) const { - if (!output_tensors.empty() && output_tensors.at(0).has_value()) { - return {output_tensors.at(0).value()}; - } - const auto& output_shape = input_tensors.at(0).get_legacy_shape(); - - return {operation::generic_create_output_tensors( - *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config)}; -} - -operation::ProgramWithCallbacks MorehSoftmax::create_program( - const std::vector& input_tensors, std::vector& output_tensors) const { - auto& input = input_tensors.at(0); - auto& output = output_tensors.at(0); - - auto parallelization_strategy = this->get_parallelization_strategy(input_tensors); - - switch (parallelization_strategy) { - case MorehSoftmaxOpParallelizationStrategy::SMALL_W: - return {moreh_softmax_w_small(input, output, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxOpParallelizationStrategy::SMALL_H: - return {moreh_softmax_h_small(input, output, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxOpParallelizationStrategy::LARGE_W: - return {moreh_softmax_w_large(input, output, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxOpParallelizationStrategy::LARGE_H: - return {moreh_softmax_h_large(input, output, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxOpParallelizationStrategy::LARGE_C: - return {moreh_softmax_c_large( - input, output, this->dim, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxOpParallelizationStrategy::NONE: - default: break; - } - - return {moreh_softmax_h_large(input, output, this->core_range, this->op, this->compute_kernel_config)}; -} - -MorehSoftmaxOpParallelizationStrategy MorehSoftmax::get_parallelization_strategy( - const std::vector& input_tensors) const { - const auto& input = input_tensors.at(0); - - auto rank = input.get_legacy_shape().rank(); - if (this->strategy == MorehSoftmaxOpParallelizationStrategy::NONE) { - if (rank - 1 == this->dim) { - if (is_moreh_softmax_w_small_available(input, this->compute_kernel_config)) { - return MorehSoftmaxOpParallelizationStrategy::SMALL_W; - } - return MorehSoftmaxOpParallelizationStrategy::LARGE_W; - } - if (rank - 2 == this->dim) { - if (is_moreh_softmax_h_small_available(input, this->compute_kernel_config)) { - return MorehSoftmaxOpParallelizationStrategy::SMALL_H; - } - return MorehSoftmaxOpParallelizationStrategy::LARGE_H; - } - return MorehSoftmaxOpParallelizationStrategy::LARGE_C; - } - - if (rank - 2 == this->dim) { - TT_ASSERT( - this->strategy == MorehSoftmaxOpParallelizationStrategy::SMALL_H || - this->strategy == MorehSoftmaxOpParallelizationStrategy::LARGE_H, - "Invalid parallelization strategy. {} is not for dim H", this->strategy); - - if (this->strategy == MorehSoftmaxOpParallelizationStrategy::SMALL_H) { - TT_ASSERT( - is_moreh_softmax_h_small_available(input, this->compute_kernel_config), - "not enough circular buffer memory for {}", this->strategy); - } - } else if (rank - 1 == this->dim) { - TT_ASSERT( - this->strategy == MorehSoftmaxOpParallelizationStrategy::SMALL_W || - this->strategy == MorehSoftmaxOpParallelizationStrategy::LARGE_W, - "Invalid parallelization strategy. {} is not for dim W", this->strategy); - - if (this->strategy == MorehSoftmaxOpParallelizationStrategy::SMALL_W) { - TT_ASSERT( - is_moreh_softmax_w_small_available(input, this->compute_kernel_config), - "not enough circular buffer memory for {}", this->strategy); - } - } else { - TT_ASSERT( - this->strategy == MorehSoftmaxOpParallelizationStrategy::LARGE_C, - "Invalid parallelization strategy. large c is for dim 0 to (rank - 3)"); - } - - return this->strategy; -} - -Tensor moreh_softmax( - const Tensor& input_tensor, - uint32_t dim, - std::optional output_tensor, - const MorehSoftmaxOpParallelizationStrategy strategy, - const MemoryConfig& output_mem_config, - std::optional compute_kernel_config) { - auto device = input_tensor.device(); - auto grid_coord = device->compute_with_storage_grid_size(); - const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); - - auto kernel_config_val = - init_device_compute_kernel_config(device->arch(), compute_kernel_config, MathFidelity::HiFi4); - - std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; - - operation::launch_op( - [dim, all_cores, strategy, output_mem_config, kernel_config_val]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) mutable -> std::vector { - return operation::run( - MorehSoftmax{ - .dim = dim, - .core_range = all_cores, - .op = MorehSoftmaxOp::SOFTMAX, - .strategy = strategy, - .output_mem_config = output_mem_config, - .compute_kernel_config = kernel_config_val}, - input_tensors, - optional_input_tensors, - optional_output_tensors); - }, - {input_tensor}, - output_tensors, - {}, - {output_tensor}); - - return output_tensors.at(0); -} - -Tensor moreh_softmin( - const Tensor& input_tensor, - uint32_t dim, - std::optional output_tensor, - const MorehSoftmaxOpParallelizationStrategy strategy, - const MemoryConfig& output_mem_config, - std::optional compute_kernel_config) { - auto device = input_tensor.device(); - auto grid_coord = device->compute_with_storage_grid_size(); - const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); - - auto kernel_config_val = - init_device_compute_kernel_config(device->arch(), compute_kernel_config, MathFidelity::HiFi4); - - std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; - - operation::launch_op( - [dim, all_cores, strategy, output_mem_config, kernel_config_val]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) mutable -> std::vector { - return operation::run( - MorehSoftmax{ - .dim = dim, - .core_range = all_cores, - .op = MorehSoftmaxOp::SOFTMIN, - .strategy = strategy, - .output_mem_config = output_mem_config, - .compute_kernel_config = kernel_config_val}, - input_tensors, - optional_input_tensors, - optional_output_tensors); - }, - {input_tensor}, - output_tensors, - {}, - {output_tensor}); - - return output_tensors.at(0); -} - -Tensor moreh_logsoftmax( - const Tensor& input_tensor, - uint32_t dim, - std::optional output_tensor, - const MorehSoftmaxOpParallelizationStrategy strategy, - const MemoryConfig& output_mem_config, - std::optional compute_kernel_config) { - auto device = input_tensor.device(); - auto grid_coord = device->compute_with_storage_grid_size(); - const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); - - auto kernel_config_val = - init_device_compute_kernel_config(device->arch(), compute_kernel_config, MathFidelity::HiFi4); - - std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input_tensor}))}; - - operation::launch_op( - [dim, all_cores, strategy, output_mem_config, kernel_config_val]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) mutable -> std::vector { - return operation::run( - MorehSoftmax{ - .dim = dim, - .core_range = all_cores, - .op = MorehSoftmaxOp::LOGSOFTMAX, - .strategy = strategy, - .output_mem_config = output_mem_config, - .compute_kernel_config = kernel_config_val}, - input_tensors, - optional_input_tensors, - optional_output_tensors); - }, - {input_tensor}, - output_tensors, - {}, - {output_tensor}); - - return output_tensors.at(0); -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp deleted file mode 100644 index 97ea874759d..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp +++ /dev/null @@ -1,95 +0,0 @@ -/* - * SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - * - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include "ttnn/operation.hpp" -#include "ttnn/tensor/tensor.hpp" -#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" -#include - -namespace tt { -namespace operations { -namespace primary { - -using namespace tt_metal; - -enum class MorehSoftmaxOpParallelizationStrategy { - NONE, - SMALL_W, - SMALL_H, - LARGE_W, - LARGE_H, - LARGE_C, -}; - -enum class MorehSoftmaxOp { - SOFTMAX, - SOFTMIN, - LOGSOFTMAX, -}; - -bool is_moreh_softmax_w_small_available(const Tensor &tensor, const ttnn::DeviceComputeKernelConfig& compute_kernel_config); -bool is_moreh_softmax_h_small_available(const Tensor &tensor, const ttnn::DeviceComputeKernelConfig& compute_kernel_config); - -operation::ProgramWithCallbacks moreh_softmax_w_small( - const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config); -operation::ProgramWithCallbacks moreh_softmax_w_large( - const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config); -operation::ProgramWithCallbacks moreh_softmax_h_small( - const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config); -operation::ProgramWithCallbacks moreh_softmax_h_large( - const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config); -operation::ProgramWithCallbacks moreh_softmax_c_large( - const Tensor &input, const Tensor &output, uint32_t dim, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config); - -struct MorehSoftmax { - const uint32_t dim; - const CoreRange core_range; // unused for now - const MorehSoftmaxOp op; - const MorehSoftmaxOpParallelizationStrategy strategy; - const MemoryConfig output_mem_config; - const ttnn::DeviceComputeKernelConfig compute_kernel_config; - - void validate_with_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; - operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; - MorehSoftmaxOpParallelizationStrategy get_parallelization_strategy(const std::vector &input_tensors) const; - static constexpr auto attribute_names = std::make_tuple("dim", "op", "strategy", "output_mem_config", "compute_kernel_config"); - const auto attribute_values() const { - return std::make_tuple(std::cref(this->dim), std::cref(this->op), std::cref(this->strategy), std::cref(this->output_mem_config), std::cref(this->compute_kernel_config)); - } -}; - -// const ref prevents -Tensor moreh_softmax( - const Tensor &input_tensor, - uint32_t dim, - std::optional output_tensor = std::nullopt, - const MorehSoftmaxOpParallelizationStrategy strategy = MorehSoftmaxOpParallelizationStrategy::NONE, - const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - std::optional compute_kernel_config = std::nullopt); - -Tensor moreh_softmin( - const Tensor &input_tensor, - uint32_t dim, - std::optional output_tensor = std::nullopt, - const MorehSoftmaxOpParallelizationStrategy strategy = MorehSoftmaxOpParallelizationStrategy::NONE, - const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - std::optional compute_kernel_config = std::nullopt); - -Tensor moreh_logsoftmax( - const Tensor &input_tensor, - uint32_t dim, - std::optional output_tensor = std::nullopt, - const MorehSoftmaxOpParallelizationStrategy strategy = MorehSoftmaxOpParallelizationStrategy::NONE, - const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - std::optional compute_kernel_config = std::nullopt); - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp deleted file mode 100644 index c7bebf27af8..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp +++ /dev/null @@ -1,143 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -operation::ProgramWithCallbacks moreh_softmax_c_large(const Tensor &input, const Tensor &output, uint32_t dim, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Large tensor algorithm selected"); - // split work - auto shape = input.get_legacy_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - uint32_t num_tiles = input.volume() / shape[dim] / H / W * Ht * Wt; - - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_tiles); - - auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - auto data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, 2}, // input - {CB::c_out0, 2}, // output - {CB::c_intermed0, 1, intermed_data_format}, // exp(x) - {CB::c_intermed1, 1, intermed_data_format}, // recips - {CB::c_intermed2, 2, intermed_data_format}, // add - {CB::c_intermed3, 1}, // max - {CB::c_intermed4, 1, intermed_data_format}, // tmp - }); - - // create read/wrtie kernel - bool src_is_dram = input.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dst_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_c_large.cpp", all_cores, {src_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_c_large.cpp", all_cores, {dst_is_dram}, writer_defines); - - auto outer_stride = Ht * Wt; - for(int i = dim ; i < shape.rank() - 2; i++ ) { - outer_stride *= shape[i]; - } - auto dim_size = shape[dim]; - auto inner_size = outer_stride / dim_size; - - std::map compute_defines; - if (op == MorehSoftmaxOp::SOFTMAX || op == MorehSoftmaxOp::LOGSOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxOp::LOGSOFTMAX) { - compute_defines["LOG"] = "1"; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_c_large.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, dim_size}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, dim_size}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - vector reader_args = { - input.buffer()->address(), num_tiles_per_core, tile_offset, - outer_stride, inner_size, - dim_size}; - - vector writer_args = {output.buffer()->address(), num_tiles_per_core, tile_offset, - outer_stride, inner_size, - dim_size}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp deleted file mode 100644 index d2ae2936535..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp +++ /dev/null @@ -1,137 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -operation::ProgramWithCallbacks moreh_softmax_h_large(const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Large tensor algorithm selected"); - // split work - auto shape = input.get_padded_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - auto num = input.volume() / H / W; - uint32_t num_cols_tiles = num * Wt; - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_cols_tiles); - - auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - auto data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, 2}, // input - {CB::c_in1, 1}, // mask - {CB::c_in2, 1}, // scaler - {CB::c_out0, 2}, // output - {CB::c_intermed0, 2, intermed_data_format}, // exp(x) - {CB::c_intermed1, 1, intermed_data_format}, // reduce - {CB::c_intermed2, 1, intermed_data_format}, // syn - {CB::c_intermed3, 1, intermed_data_format}, // max - {CB::c_intermed4, 1, intermed_data_format}, // tmp - }); - - // create read/wrtie kernel - bool src_is_dram = input.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dst_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h_large.cpp", all_cores, {src_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h_large.cpp", all_cores, {dst_is_dram}, writer_defines); - - std::map compute_defines; - if (op == MorehSoftmaxOp::SOFTMAX || op == MorehSoftmaxOp::LOGSOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxOp::LOGSOFTMAX) { - compute_defines["LOG"] = "1"; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h_large.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, Ht}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, Ht}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - float scaler = 1.0f; - uint32_t mask_h = input.get_logical_shape()[-2] % TILE_HEIGHT; - if(mask_h == 0) mask_h = TILE_HEIGHT; - vector reader_args = { - input.buffer()->address(), num_tiles_per_core, tile_offset, Ht, Wt, *reinterpret_cast(&scaler), mask_h}; - - vector writer_args = {output.buffer()->address(), num_tiles_per_core, tile_offset, Ht, Wt}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp deleted file mode 100644 index 2a63ac8af88..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp +++ /dev/null @@ -1,170 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -#define L1_512KB (512 * 1024) - -bool is_moreh_softmax_h_small_available(const Tensor &tensor, const ttnn::DeviceComputeKernelConfig& compute_kernel_config) { - auto h = tensor.get_legacy_shape()[-2]; - int32_t Ht = (h + TILE_HEIGHT - 1) / TILE_HEIGHT; - - auto arch = tensor.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - auto data_format = tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); - auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; - - auto tile_size = tt_metal::detail::TileSize(data_format); - auto intermed_tile_size = tt_metal::detail::TileSize(intermed_data_format); - - int32_t cb_usage = 0; // bytes - cb_usage += Ht * tile_size; // input; - cb_usage += 1 * tile_size; // mask; - cb_usage += 1 * tile_size; // scaler; - - cb_usage += Ht * tile_size; // output; - - cb_usage += Ht * intermed_tile_size; // exp(x); - cb_usage += 1 * intermed_tile_size; // reduce; - cb_usage += 1 * intermed_tile_size; // max; - cb_usage += Ht * intermed_tile_size; // x - max; - cb_usage += 1 * intermed_tile_size; // tmp; - - return (tensor.device()->get_base_allocator_addr(HalMemType::L1) + cb_usage <= L1_512KB); -} - -operation::ProgramWithCallbacks moreh_softmax_h_small(const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Small tensor algorithm selected"); - // split work - auto shape = input.get_padded_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - auto num = input.volume() / H / W; - uint32_t num_cols_tiles = num * Wt; - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_cols_tiles); - - auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - auto data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, Ht}, // input - {CB::c_in1, 1}, // mask - {CB::c_in2, 1}, // scaler - {CB::c_out0, Ht}, // output - {CB::c_intermed0, Ht, intermed_data_format}, // exp(x) - {CB::c_intermed1, 1, intermed_data_format}, // reduce - {CB::c_intermed2, 1, intermed_data_format}, // max - {CB::c_intermed3, Ht, intermed_data_format}, // x - max - {CB::c_intermed4, 1, intermed_data_format} // tmp - }); - - // create read/wrtie kernel - bool src_is_dram = input.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dst_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_h.cpp", all_cores, {src_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_h.cpp", all_cores, {dst_is_dram}, writer_defines); - - std::map compute_defines; - if (op == MorehSoftmaxOp::SOFTMAX || op == MorehSoftmaxOp::LOGSOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxOp::LOGSOFTMAX) { - compute_defines["LOG"] = "1"; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_h.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, Ht}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, Ht}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - float scaler = 1.0f; - uint32_t mask_h = input.get_logical_shape()[-2] % TILE_HEIGHT; - if(mask_h == 0) mask_h = TILE_HEIGHT; - vector reader_args = { - input.buffer()->address(), num_tiles_per_core, tile_offset, Ht, Wt, *reinterpret_cast(&scaler), mask_h}; - - vector writer_args = {output.buffer()->address(), num_tiles_per_core, tile_offset, Ht, Wt}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp deleted file mode 100644 index d29342441ae..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp +++ /dev/null @@ -1,138 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -operation::ProgramWithCallbacks moreh_softmax_w_large(const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Large tensor algorithm selected"); - // split work - auto shape = input.get_padded_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - auto num = input.volume() / H / W; - - uint32_t num_kernel_rows = num * Ht; - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_kernel_rows); - - auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - auto data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, 2}, // input - {CB::c_in1, 1}, // mask - {CB::c_in2, 1}, // scaler - {CB::c_out0, 2}, // output - {CB::c_intermed0, 2, intermed_data_format}, // exp(x) - {CB::c_intermed1, 1, intermed_data_format}, // reduce - {CB::c_intermed2, 1, intermed_data_format}, // syn - {CB::c_intermed3, 1, intermed_data_format}, // max - {CB::c_intermed4, 1, intermed_data_format}, // tmp - }); - - // create read/wrtie kernel - bool src_is_dram = input.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dst_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w_large.cpp", all_cores, {src_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w_large.cpp", all_cores, {dst_is_dram}, writer_defines); - - std::map compute_defines; - if (op == MorehSoftmaxOp::SOFTMAX || op == MorehSoftmaxOp::LOGSOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxOp::LOGSOFTMAX) { - compute_defines["LOG"] = "1"; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w_large.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, Wt}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, Wt}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - float scaler = 1.0f; - uint32_t mask_w = input.get_logical_shape()[-1] % TILE_WIDTH; - if(mask_w == 0) mask_w = TILE_WIDTH; - vector reader_args = { - input.buffer()->address(), num_tiles_per_core, tile_offset, Wt, *reinterpret_cast(&scaler), mask_w}; - - vector writer_args = {output.buffer()->address(), num_tiles_per_core, tile_offset, Wt}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core * Wt; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp deleted file mode 100644 index 2a05e42095a..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp +++ /dev/null @@ -1,170 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -#define L1_512KB (512 * 1024) - -bool is_moreh_softmax_w_small_available(const Tensor &tensor, const ttnn::DeviceComputeKernelConfig& compute_kernel_config) { - auto w = tensor.get_legacy_shape()[-1]; - int32_t Wt = (w + TILE_WIDTH - 1) / TILE_WIDTH; - - auto arch = tensor.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - auto data_format = tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); - auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; - - auto tile_size = tt_metal::detail::TileSize(data_format); - auto intermed_tile_size = tt_metal::detail::TileSize(intermed_data_format); - - int32_t cb_usage = 0; // bytes - cb_usage += Wt * tile_size; // input; - cb_usage += 1 * tile_size; // mask; - cb_usage += 1 * tile_size; // scaler; - - cb_usage += Wt * tile_size; // output; - - cb_usage += Wt * intermed_tile_size; // exp(x); - cb_usage += 1 * intermed_tile_size; // reduce; - cb_usage += 1 * intermed_tile_size; // max; - cb_usage += Wt * intermed_tile_size; // x - max; - cb_usage += 1 * intermed_tile_size; // tmp; - - return (tensor.device()->get_base_allocator_addr(HalMemType::L1) + cb_usage <= L1_512KB); -} - -operation::ProgramWithCallbacks moreh_softmax_w_small(const Tensor &input, const Tensor &output, const CoreRange core_range, const MorehSoftmaxOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Small tensor algorithm selected"); - // split work - auto shape = input.get_padded_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - auto num = input.volume() / H / W; - - uint32_t num_kernel_rows = num * Ht; - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_kernel_rows); - - auto arch = input.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - auto data_format = tt_metal::datatype_to_dataformat_converter(input.get_dtype()); - auto intermed_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, Wt}, // input - {CB::c_in1, 1}, // mask - {CB::c_in2, 1}, // scaler - {CB::c_out0, Wt}, // output - {CB::c_intermed0, Wt, intermed_data_format}, // exp(x) - {CB::c_intermed1, 1, intermed_data_format}, // reduce - {CB::c_intermed2, 1, intermed_data_format}, // max - {CB::c_intermed3, Wt, intermed_data_format}, // x - max - {CB::c_intermed4, 1, intermed_data_format} // tmp - }); - - // create read/wrtie kernel - bool src_is_dram = input.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dst_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/reader_moreh_softmax_w.cpp", all_cores, {src_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/writer_moreh_softmax_w.cpp", all_cores, {dst_is_dram}, writer_defines); - - std::map compute_defines; - if (op == MorehSoftmaxOp::SOFTMAX || op == MorehSoftmaxOp::LOGSOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxOp::LOGSOFTMAX) { - compute_defines["LOG"] = "1"; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax/kernels/moreh_softmax_w.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, Wt}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, Wt}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - float scaler = 1.0f; - uint32_t mask_w = input.get_logical_shape()[-1] % TILE_WIDTH; - if(mask_w == 0) mask_w = TILE_WIDTH; - vector reader_args = { - input.buffer()->address(), num_tiles_per_core, tile_offset, Wt, *reinterpret_cast(&scaler), mask_w}; - - vector writer_args = {output.buffer()->address(), num_tiles_per_core, tile_offset, Wt}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core * Wt; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_c_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_c_large.cpp deleted file mode 100644 index 7fe79390b21..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_c_large.cpp +++ /dev/null @@ -1,90 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_ROW - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { - -void MAIN { - constexpr uint32_t onetile = 1; - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - constexpr auto cb_dx = tt::CB::c_out0; - - constexpr auto cb_ydy = tt::CB::c_intermed0; // y * dy - constexpr auto cb_sum = tt::CB::c_intermed1; - constexpr auto cb_dy_m_sum = tt::CB::c_intermed2; // dy - sum - - uint32_t N = get_compile_time_arg_val(0); - uint32_t dim_size = get_compile_time_arg_val(1); - - binary_op_init_common(cb_dy, cb_y); - - constexpr int dst0 = 0; - for (uint32_t n = 0; n < N; ++n) { - #ifdef LOG - for (uint32_t i = 0; i < dim_size; ++i) { - if (i == 0) { - copy_tile_to_cb(cb_dy, cb_sum); - } else { - add_tiles_to_cb(cb_sum, cb_dy, cb_sum); - } - } - - for (uint32_t i = 0; i < dim_size; ++i) { - // exp(y) - constexpr auto cb_exp = tt::CB::c_intermed0; - exp_tile_to_cb(cb_y, cb_exp); - - // sum * exp(y) - constexpr auto cb_inter2 = tt::CB::c_intermed2; - mul_tiles_to_cb(cb_sum, cb_exp, cb_inter2, 0, 0, /*pop0=*/0, /*pop1=*/1); - - // dy - sum * exp(y) - sub_tiles_to_cb(cb_dy, cb_inter2, cb_dx); - } - cb_pop_front(cb_sum, onetile); - #else - // compute sum(y * dy) - for (uint32_t i = 0; i < dim_size; ++i) { - mul_tiles_to_cb(cb_y, cb_dy, cb_ydy); - - if (i == 0) { - copy_tile_to_cb(cb_ydy, cb_sum); - } else { - add_tiles_to_cb(cb_sum, cb_ydy, cb_sum); - } - } - - // compute final result - for (uint32_t i = 0; i < dim_size; ++i) { - // dy - sum - sub_tiles_to_cb( - cb_dy, - cb_sum, - cb_dy_m_sum, - /*itile0=*/0, - /*itile1=*/0, - /*pop0=*/1, - /*pop1=*/0); - - #ifdef SOFTMAX - // (dy - sum) * y - mul_tiles_to_cb(cb_dy_m_sum, cb_y, cb_dx); - #else - // -(dy - sum) * y - mul_tiles_and_negative_to_cb(cb_dy_m_sum, cb_y, cb_dx); - #endif - } - cb_pop_front(cb_sum, onetile); - #endif - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h.cpp deleted file mode 100644 index 912316de588..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h.cpp +++ /dev/null @@ -1,104 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_COL - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { -void MAIN { - constexpr uint32_t onetile = 1; - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - constexpr auto cb_bcast_scaler = tt::CB::c_in2; - constexpr auto cb_mask = tt::CB::c_in3; - constexpr auto cb_dx = tt::CB::c_out0; - - constexpr auto cb_ydy = tt::CB::c_intermed0; // y * dy - constexpr auto cb_sum = tt::CB::c_intermed1; - constexpr auto cb_inter2 = tt::CB::c_intermed2; - - binary_op_init_common(cb_y, cb_bcast_scaler); - - uint32_t N = get_compile_time_arg_val(0); - uint32_t Ht = get_compile_time_arg_val(1); - - for (uint32_t n = 0; n < N; ++n) { - #ifdef LOG - // sum(dy) - if (Ht == 1) { - // apply mask - mask_tile_to_cb(cb_dy, cb_mask, cb_inter2, /*itile=*/0, /*mtile=*/0, /*pop=*/0, /*popm=*/0); - - reduce_tile_to_cb(cb_inter2, cb_bcast_scaler, cb_sum, 1, /*pop0=*/1, /*pop=1*/0); - } else { - constexpr auto cb_inter0 = tt::CB::c_intermed0; - reduce_tile_to_cb(cb_dy, cb_bcast_scaler, cb_inter0, Ht - 1, /*pop0=*/0, /*pop=1*/0); - - constexpr auto cb_inter1 = tt::CB::c_intermed1; - mask_tile_to_cb(cb_dy, cb_mask, cb_inter1, /*itile=*/Ht - 1, /*mtile=*/0, /*pop=*/0, /*popm=*/0); - - constexpr auto cb_inter2 = tt::CB::c_intermed2; - reduce_tile_to_cb(cb_inter1, cb_bcast_scaler, cb_inter2, 1, /*pop0=*/1, /*pop=1*/0); - - add_tiles_to_cb(cb_inter0, cb_inter2, cb_sum); - } - - - // dy - sum * exp(y) - constexpr auto cb_exp = tt::CB::c_intermed0; // y * dy - - for (uint32_t w = 0; w < Ht; w += onetile) { - // exp(y) - exp_tile_to_cb(cb_y, cb_exp, w, /*dst=*/0, /*pop=*/0); - - // sum * exp(y) - mul_tiles_bcast_rows_to_cb(cb_exp, cb_sum, cb_inter2, 0, 0, /*pop0=*/1, /*pop1=*/0); - - // dy - sum * exp(y) - sub_tiles_to_cb(cb_dy, cb_inter2, cb_dx, w, 0, /*pop0=*/0, /*pop1=*/1); - } - - cb_pop_front(cb_sum, onetile); - cb_pop_front(cb_y, Ht); - cb_pop_front(cb_dy, Ht); - #else - // step 1, compute y * dy - for (uint32_t h = 0; h < Ht; ++h) { - if (h == Ht - 1) { - mul_tiles_and_mask_tile_to_cb( - cb_y, cb_dy, cb_mask, cb_ydy, h, h, 0, /*pop0=*/0, /*pop1=*/0, /*popm=*/0); - } else { - mul_tiles_to_cb(cb_y, cb_dy, cb_ydy, h, h, /*pop0=*/0, /*pop1=*/0); - } - } - - // step 2, compute sum(y * dy) - reduce_tile_to_cb(cb_ydy, cb_bcast_scaler, cb_sum, Ht, /*pop0=*/Ht, /*pop=1*/0); - - // step 3, compute final result - for (uint32_t h = 0; h < Ht; ++h) { - // dy - sum - sub_tiles_bcast_rows_to_cb(cb_dy, cb_sum, cb_inter2, h, 0, /*pop0=*/0, /*pop1=*/0); - - #ifdef SOFTMAX - // (dy - sum) * y - mul_tiles_to_cb(cb_y, cb_inter2, cb_dx, h, 0, /*pop0=*/0, /*pop1=*/1); - #else - // -(dy - sum) * y - mul_tiles_and_negative_to_cb(cb_y, cb_inter2, cb_dx, h, 0, /*pop0=*/0, /*pop1=*/1); - #endif - } - - cb_pop_front(cb_sum, onetile); - cb_pop_front(cb_dy, Ht); - cb_pop_front(cb_y, Ht); - #endif - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h_large.cpp deleted file mode 100644 index e75de0b5cd9..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h_large.cpp +++ /dev/null @@ -1,110 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_COL - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { -void MAIN { - constexpr uint32_t onetile = 1; - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - constexpr auto cb_bcast_scaler = tt::CB::c_in2; - constexpr auto cb_mask = tt::CB::c_in3; - constexpr auto cb_dx = tt::CB::c_out0; - - constexpr auto cb_ydy = tt::CB::c_intermed0; // y * dy - constexpr auto cb_sum = tt::CB::c_intermed1; - constexpr auto cb_inter2 = tt::CB::c_intermed2; - constexpr auto cb_add = tt::CB::c_intermed3; - - binary_op_init_common(cb_y, cb_bcast_scaler); - - uint32_t N = get_compile_time_arg_val(0); - uint32_t Ht = get_compile_time_arg_val(1); - - for (uint32_t n = 0; n < N; ++n) { - - #ifdef LOG - // sum(dy) - for (uint32_t h = 0; h < Ht; ++h) { - if (h == Ht - 1) { - if (h == 0){ - mask_tile_to_cb(cb_dy, cb_mask, cb_add, /*itile=*/0, /*mtile=*/0, /*pop=*/1, /*popm=*/0); - } else { - constexpr auto cb_inter0 = tt::CB::c_intermed0; - mask_tile_to_cb(cb_dy, cb_mask, cb_inter0, /*itile=*/0, /*mtile=*/0, /*pop=*/1, /*popm=*/0); - - add_tiles_to_cb(cb_add, cb_inter0, cb_add); - } - } else { - if (h == 0) { - copy_tile_to_cb(cb_dy, cb_add); - } - else { - add_tiles_to_cb(cb_add, cb_dy, cb_add); - } - } - } - - reduce_tile_to_cb(cb_add, cb_bcast_scaler, cb_sum, 1, /*pop0=*/1, /*pop1=*/0); - - for (uint32_t h = 0; h < Ht; ++h) { - // exp(y) - constexpr auto cb_exp = tt::CB::c_intermed0; - exp_tile_to_cb(cb_y, cb_exp, 0); - - // sum * exp(y) - mul_tiles_bcast_rows_to_cb(cb_exp, cb_sum, cb_inter2, 0, 0, /*pop0=*/1, /*pop1=*/0); - - // dy - sum * exp(y) - sub_tiles_to_cb(cb_dy, cb_inter2, cb_dx); - } - - cb_pop_front(cb_sum, onetile); - #else - - // step 1, compute y * dy - for (uint32_t h = 0; h < Ht; ++h) { - if (h == Ht - 1) { - mul_tiles_and_mask_tile_to_cb( - cb_y, cb_dy, cb_mask, cb_ydy, 0, 0, 0, /*pop0=*/1, /*pop1=*/1, /*popm=*/0); - } else { - mul_tiles_to_cb(cb_y, cb_dy, cb_ydy); - } - - if (h == 0) { - copy_tile_to_cb(cb_ydy, cb_add); - } else { - add_tiles_to_cb(cb_add, cb_ydy, cb_add); - } - } - - // step 2, compute sum(y * dy) - reduce_tile_to_cb(cb_add, cb_bcast_scaler, cb_sum, /*size=*/1, /*pop0=*/1, /*pop1=*/0); - - // step 3, compute final result - for (uint32_t h = 0; h < Ht; ++h) { - // dy - sum - sub_tiles_bcast_rows_to_cb(cb_dy, cb_sum, cb_inter2, 0, 0, /*pop0=*/1, /*pop1=*/0); - - #ifdef SOFTMAX - // (dy - sum) * y - mul_tiles_to_cb(cb_y, cb_inter2, cb_dx); - #else - // -(dy - sum) * y - mul_tiles_and_negative_to_cb(cb_y, cb_inter2, cb_dx); - #endif - } - - cb_pop_front(cb_sum, onetile); - #endif - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w.cpp deleted file mode 100644 index 55933075ae5..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w.cpp +++ /dev/null @@ -1,104 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_ROW - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { -void MAIN { - constexpr uint32_t onetile = 1; - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - constexpr auto cb_bcast_scaler = tt::CB::c_in2; - constexpr auto cb_mask = tt::CB::c_in3; - constexpr auto cb_dx = tt::CB::c_out0; - - constexpr auto cb_ydy = tt::CB::c_intermed0; // y * dy - constexpr auto cb_sum = tt::CB::c_intermed1; - constexpr auto cb_inter2 = tt::CB::c_intermed2; - - binary_op_init_common(cb_y, cb_bcast_scaler); - - uint32_t N = get_compile_time_arg_val(0); - uint32_t Wt = get_compile_time_arg_val(1); - - for (uint32_t n = 0; n < N; ++n) { - - #ifdef LOG - // sum(dy) - if (Wt == 1) { - // apply mask - mask_tile_to_cb(cb_dy, cb_mask, cb_inter2, /*itile=*/0, /*mtile=*/0, /*pop=*/0, /*popm=*/0); - - reduce_tile_to_cb(cb_inter2, cb_bcast_scaler, cb_sum, 1, /*pop0=*/1, /*pop=1*/0); - } else { - constexpr auto cb_inter0 = tt::CB::c_intermed0; - reduce_tile_to_cb(cb_dy, cb_bcast_scaler, cb_inter0, Wt - 1, /*pop0=*/0, /*pop=1*/0); - - constexpr auto cb_inter1 = tt::CB::c_intermed1; - mask_tile_to_cb(cb_dy, cb_mask, cb_inter1, /*itile=*/Wt - 1, /*mtile=*/0, /*pop=*/0, /*popm=*/0); - - constexpr auto cb_inter2 = tt::CB::c_intermed2; - reduce_tile_to_cb(cb_inter1, cb_bcast_scaler, cb_inter2, 1, /*pop0=*/1, /*pop=1*/0); - - add_tiles_to_cb(cb_inter0, cb_inter2, cb_sum); - } - - // dy - sum * exp(y) - constexpr auto cb_exp = tt::CB::c_intermed0; // y * dy - - for (uint32_t w = 0; w < Wt; w += onetile) { - // exp(y) - exp_tile_to_cb(cb_y, cb_exp, w, /*dst=*/0, /*pop=*/0); - - // sum * exp(y) - mul_tiles_bcast_cols_to_cb(cb_exp, cb_sum, cb_inter2, 0, 0, /*pop0=*/1, /*pop1=*/0); - - // dy - sum * exp(y) - sub_tiles_to_cb(cb_dy, cb_inter2, cb_dx, w, 0, /*pop0=*/0, /*pop1=*/1); - } - - cb_pop_front(cb_sum, onetile); - cb_pop_front(cb_y, Wt); - cb_pop_front(cb_dy, Wt); - #else - // step 1, compute y * dy - for (uint32_t w = 0; w < Wt; ++w) { - if (w == Wt - 1) { - mul_tiles_and_mask_tile_to_cb( - cb_y, cb_dy, cb_mask, cb_ydy, w, w, 0, /*pop0=*/0, /*pop1=*/0, /*popm=*/0); - } else { - mul_tiles_to_cb(cb_y, cb_dy, cb_ydy, w, w, /*pop0=*/0, /*pop1=*/0); - } - } - - // step 2, compute sum(y * dy) - reduce_tile_to_cb(cb_ydy, cb_bcast_scaler, cb_sum, Wt, /*pop0=*/Wt, /*pop=1*/0); - - // step 3, compute final result - for (uint32_t w = 0; w < Wt; w += onetile) { - // dy - sum - sub_tiles_bcast_cols_to_cb(cb_dy, cb_sum, cb_inter2, w, 0, /*pop0=*/0, /*pop1=*/0); - - #ifdef SOFTMAX - // (dy - sum) * y - mul_tiles_to_cb(cb_y, cb_inter2, cb_dx, w, 0, /*pop0=*/0, /*pop1=*/1); - #else - // -(dy - sum) * y - mul_tiles_and_negative_to_cb(cb_y, cb_inter2, cb_dx, w, 0, /*pop0=*/0, /*pop1=*/1); - #endif - } - - cb_pop_front(cb_sum, onetile); - cb_pop_front(cb_dy, Wt); - cb_pop_front(cb_y, Wt); - #endif - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w_large.cpp deleted file mode 100644 index b3faa81623e..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w_large.cpp +++ /dev/null @@ -1,108 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include - -#define REDUCE_OP PoolType::SUM -#define REDUCE_DIM ReduceDim::REDUCE_ROW - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" - -namespace NAMESPACE { -void MAIN { - constexpr uint32_t onetile = 1; - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - constexpr auto cb_bcast_scaler = tt::CB::c_in2; - constexpr auto cb_mask = tt::CB::c_in3; - constexpr auto cb_dx = tt::CB::c_out0; - - constexpr auto cb_ydy = tt::CB::c_intermed0; // y * dy - constexpr auto cb_sum = tt::CB::c_intermed1; - constexpr auto cb_inter2 = tt::CB::c_intermed2; - constexpr auto cb_add = tt::CB::c_intermed3; - - binary_op_init_common(cb_y, cb_bcast_scaler); - - uint32_t N = get_compile_time_arg_val(0); - uint32_t Wt = get_compile_time_arg_val(1); - - for (uint32_t n = 0; n < N; ++n) { - - #ifdef LOG - // sum(dy) - for (uint32_t w = 0; w < Wt; ++w) { - if (w == Wt - 1) { - if (w == 0){ - mask_tile_to_cb(cb_dy, cb_mask, cb_add, /*itile=*/0, /*mtile=*/0, /*pop=*/1, /*popm=*/0); - } else { - constexpr auto cb_inter0 = tt::CB::c_intermed0; - mask_tile_to_cb(cb_dy, cb_mask, cb_inter0, /*itile=*/0, /*mtile=*/0, /*pop=*/1, /*popm=*/0); - - add_tiles_to_cb(cb_add, cb_inter0, cb_add); - } - } else { - if (w == 0) { - copy_tile_to_cb(cb_dy, cb_add); - } - else { - add_tiles_to_cb(cb_add, cb_dy, cb_add); - } - } - } - - reduce_tile_to_cb(cb_add, cb_bcast_scaler, cb_sum, 1, /*pop0=*/1, /*pop1=*/0); - - for (uint32_t w = 0; w < Wt; w += onetile) { - // exp(y) - constexpr auto cb_exp = tt::CB::c_intermed0; - exp_tile_to_cb(cb_y, cb_exp, 0); - // sum * exp(y) - mul_tiles_bcast_cols_to_cb(cb_exp, cb_sum, cb_inter2, 0, 0, /*pop0=*/1, /*pop1=*/0); - - // dy - sum * exp(y) - sub_tiles_to_cb(cb_dy, cb_inter2, cb_dx); - } - - cb_pop_front(cb_sum, onetile); - #else - // step 1, compute y * dy - for (uint32_t w = 0; w < Wt; ++w) { - if (w == Wt - 1) { - mul_tiles_and_mask_tile_to_cb( - cb_y, cb_dy, cb_mask, cb_ydy, 0, 0, 0, /*pop0=*/1, /*pop1=*/1, /*popm=*/0); - } else { - mul_tiles_to_cb(cb_y, cb_dy, cb_ydy); - } - - if (w == 0) { - copy_tile_to_cb(cb_ydy, cb_add); - } else { - add_tiles_to_cb(cb_add, cb_ydy, cb_add); - } - } - - // step 2, compute sum(y * dy) - reduce_tile_to_cb(cb_add, cb_bcast_scaler, cb_sum, 1, /*pop0=*/1, /*pop1=*/0); - - // step 3, compute final result - for (uint32_t w = 0; w < Wt; w += onetile) { - // dy - sum - sub_tiles_bcast_cols_to_cb(cb_dy, cb_sum, cb_inter2, 0, 0, /*pop0=*/1, /*pop1=*/0); - - #ifdef SOFTMAX - // (dy - sum) * y - mul_tiles_to_cb(cb_y, cb_inter2, cb_dx); - #else - // -(dy - sum) * y - mul_tiles_and_negative_to_cb(cb_y, cb_inter2, cb_dx); - #endif - } - - cb_pop_front(cb_sum, onetile); - #endif - } -} -} // namespace NAMESPACE diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_c.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_c.cpp deleted file mode 100644 index 9f6ddaf6ad4..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_c.cpp +++ /dev/null @@ -1,82 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t y_addr = get_arg_val(0); - uint32_t dy_addr = get_arg_val(1); - - uint32_t num_tiles = get_arg_val(2); - uint32_t tile_offset = get_arg_val(3); - uint32_t outer_stride = get_arg_val(4); - uint32_t inner_size = get_arg_val(5); - uint32_t dim_size = get_arg_val(6); - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - - uint32_t y_tile_bytes = get_tile_size(cb_y); - const DataFormat y_data_format = get_dataformat(cb_y); - - uint32_t dy_tile_bytes = get_tile_size(cb_dy); - const DataFormat dy_data_format = get_dataformat(cb_dy); - - constexpr bool y_is_dram = get_compile_time_arg_val(0) == 1; - constexpr bool dy_is_dram = get_compile_time_arg_val(1) == 1; - - const InterleavedAddrGenFast y_in = { - .bank_base_address = y_addr, .page_size = y_tile_bytes, .data_format = y_data_format}; - - const InterleavedAddrGenFast dy_in = { - .bank_base_address = dy_addr, .page_size = dy_tile_bytes, .data_format = dy_data_format}; - - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < num_tiles; i += onetile) { - uint32_t outer_idx = curr_tile / (inner_size); - uint32_t inner_idx = curr_tile % inner_size; - uint32_t tile_idx = outer_idx * outer_stride + inner_idx; - - uint32_t dim_stride = inner_size; - for (uint32_t d = 0; d < dim_size; d++) { - #ifndef LOG - cb_reserve_back(cb_y, onetile); - l1_write_addr_in = get_write_ptr(cb_y); - noc_async_read_tile(tile_idx, y_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_y, onetile); - #endif - - cb_reserve_back(cb_dy, onetile); - l1_write_addr_in = get_write_ptr(cb_dy); - noc_async_read_tile(tile_idx, dy_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_dy, onetile); - tile_idx += dim_stride; - } - - tile_idx = outer_idx * outer_stride + inner_idx; - for (uint32_t d = 0; d < dim_size; d++) { - cb_reserve_back(cb_dy, onetile); - l1_write_addr_in = get_write_ptr(cb_dy); - noc_async_read_tile(tile_idx, dy_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_dy, onetile); - - cb_reserve_back(cb_y, onetile); - l1_write_addr_in = get_write_ptr(cb_y); - noc_async_read_tile(tile_idx, y_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_y, onetile); - - tile_idx += dim_stride; - } - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h.cpp deleted file mode 100644 index 962a788dfc8..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" - -void kernel_main() { - uint32_t y_addr = get_arg_val(0); - uint32_t dy_addr = get_arg_val(1); - - uint32_t N = get_arg_val(2); - uint32_t tile_offset = get_arg_val(3); - uint32_t Ht = get_arg_val(4); - uint32_t Wt = get_arg_val(5); - - uint32_t scaler = get_arg_val(6); - uint32_t mask_h = get_arg_val(7); - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - constexpr auto cb_scaler = tt::CB::c_in2; - constexpr auto cb_mask = tt::CB::c_in3; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t y_tile_bytes = get_tile_size(cb_y); - const DataFormat y_data_format = get_dataformat(cb_y); - - uint32_t dy_tile_bytes = get_tile_size(cb_dy); - const DataFormat dy_data_format = get_dataformat(cb_dy); - - constexpr bool y_is_dram = get_compile_time_arg_val(0) == 1; - constexpr bool dy_is_dram = get_compile_time_arg_val(1) == 1; - - const InterleavedAddrGenFast y_in = { - .bank_base_address = y_addr, .page_size = y_tile_bytes, .data_format = y_data_format}; - - const InterleavedAddrGenFast dy_in = { - .bank_base_address = dy_addr, .page_size = dy_tile_bytes, .data_format = dy_data_format}; - - // TODO(AP): cleanup, probably with named args/param pack/reflection. - generate_bcast_scaler(cb_scaler, scaler); - generate_mask_h(cb_mask, mask_h); - - // read ublocks from src0 to CB0, then push ublocks to compute (unpacker) - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i += onetile) { - uint32_t w_idx = curr_tile % Wt; - uint32_t nc_idx = curr_tile / Wt; - uint32_t tile_idx = nc_idx * Ht * Wt + w_idx; - for (uint32_t h = 0; h < Ht; h++) { - // read y - cb_reserve_back(cb_y, onetile); - l1_write_addr_in = get_write_ptr(cb_y); - noc_async_read_tile(tile_idx, y_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_y, onetile); - - // read dy - cb_reserve_back(cb_dy, onetile); - l1_write_addr_in = get_write_ptr(cb_dy); - noc_async_read_tile(tile_idx, dy_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_dy, onetile); - - tile_idx += Wt; - } - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h_large.cpp deleted file mode 100644 index 31a1100ba33..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h_large.cpp +++ /dev/null @@ -1,95 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" - -void kernel_main() { - uint32_t y_addr = get_arg_val(0); - uint32_t dy_addr = get_arg_val(1); - - uint32_t N = get_arg_val(2); - uint32_t tile_offset = get_arg_val(3); - uint32_t Ht = get_arg_val(4); - uint32_t Wt = get_arg_val(5); - - uint32_t scaler = get_arg_val(6); - uint32_t mask_h = get_arg_val(7); - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - constexpr auto cb_scaler = tt::CB::c_in2; - constexpr auto cb_mask = tt::CB::c_in3; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t y_tile_bytes = get_tile_size(cb_y); - const DataFormat y_data_format = get_dataformat(cb_y); - - uint32_t dy_tile_bytes = get_tile_size(cb_dy); - const DataFormat dy_data_format = get_dataformat(cb_dy); - - constexpr bool y_is_dram = get_compile_time_arg_val(0) == 1; - constexpr bool dy_is_dram = get_compile_time_arg_val(1) == 1; - - const InterleavedAddrGenFast y_in = { - .bank_base_address = y_addr, .page_size = y_tile_bytes, .data_format = y_data_format}; - - const InterleavedAddrGenFast dy_in = { - .bank_base_address = dy_addr, .page_size = dy_tile_bytes, .data_format = dy_data_format}; - - // TODO(AP): cleanup, probably with named args/param pack/reflection. - generate_bcast_scaler(cb_scaler, scaler); - generate_mask_h(cb_mask, mask_h); - - // read ublocks from src0 to CB0, then push ublocks to compute (unpacker) - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i += onetile) { - uint32_t w_idx = curr_tile % Wt; - uint32_t nc_idx = curr_tile / Wt; - uint32_t tile_idx = nc_idx * Ht * Wt + w_idx; - for (uint32_t h = 0; h < Ht; h++) { - #ifndef LOG - // read y - cb_reserve_back(cb_y, onetile); - l1_write_addr_in = get_write_ptr(cb_y); - noc_async_read_tile(tile_idx, y_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_y, onetile); - #endif - - // read dy - cb_reserve_back(cb_dy, onetile); - l1_write_addr_in = get_write_ptr(cb_dy); - noc_async_read_tile(tile_idx, dy_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_dy, onetile); - - tile_idx += Wt; - } - - w_idx = curr_tile % Wt; - nc_idx = curr_tile / Wt; - tile_idx = nc_idx * Ht * Wt + w_idx; - for (uint32_t h = 0; h < Ht; h++) { - // read y - cb_reserve_back(cb_y, onetile); - l1_write_addr_in = get_write_ptr(cb_y); - noc_async_read_tile(tile_idx, y_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_y, onetile); - - // read dy - cb_reserve_back(cb_dy, onetile); - l1_write_addr_in = get_write_ptr(cb_dy); - noc_async_read_tile(tile_idx, dy_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_dy, onetile); - - tile_idx += Wt; - } - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w.cpp deleted file mode 100644 index 40d2deb0e17..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w.cpp +++ /dev/null @@ -1,67 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" - -void kernel_main() { - uint32_t y_addr = get_arg_val(0); - uint32_t dy_addr = get_arg_val(1); - - uint32_t N = get_arg_val(2); - uint32_t tile_offset = get_arg_val(3); - uint32_t Wt = get_arg_val(4); - - uint32_t scaler = get_arg_val(5); - uint32_t mask_w = get_arg_val(6); - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - constexpr auto cb_scaler = tt::CB::c_in2; - constexpr auto cb_mask = tt::CB::c_in3; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t y_tile_bytes = get_tile_size(cb_y); - const DataFormat y_data_format = get_dataformat(cb_y); - - uint32_t dy_tile_bytes = get_tile_size(cb_dy); - const DataFormat dy_data_format = get_dataformat(cb_dy); - - constexpr bool y_is_dram = get_compile_time_arg_val(0) == 1; - constexpr bool dy_is_dram = get_compile_time_arg_val(1) == 1; - - const InterleavedAddrGenFast y_in = { - .bank_base_address = y_addr, .page_size = y_tile_bytes, .data_format = y_data_format}; - - const InterleavedAddrGenFast dy_in = { - .bank_base_address = dy_addr, .page_size = dy_tile_bytes, .data_format = dy_data_format}; - - // TODO(AP): cleanup, probably with named args/param pack/reflection. - generate_bcast_scaler(cb_scaler, scaler); - generate_mask_w(cb_mask, mask_w); - - // read ublocks from src0 to CB0, then push ublocks to compute (unpacker) - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i += onetile) { - for (uint32_t w = 0; w < Wt; w++) { - // read y - cb_reserve_back(cb_y, onetile); - l1_write_addr_in = get_write_ptr(cb_y); - noc_async_read_tile(curr_tile, y_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_y, onetile); - - // read dy - cb_reserve_back(cb_dy, onetile); - l1_write_addr_in = get_write_ptr(cb_dy); - noc_async_read_tile(curr_tile, dy_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_dy, onetile); - - curr_tile++; - } - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w_large.cpp deleted file mode 100644 index 04346969041..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w_large.cpp +++ /dev/null @@ -1,89 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" - -void kernel_main() { - uint32_t y_addr = get_arg_val(0); - uint32_t dy_addr = get_arg_val(1); - - uint32_t N = get_arg_val(2); - uint32_t tile_offset = get_arg_val(3); - uint32_t Wt = get_arg_val(4); - - uint32_t scaler = get_arg_val(5); - uint32_t mask_w = get_arg_val(6); - - constexpr auto cb_y = tt::CB::c_in0; - constexpr auto cb_dy = tt::CB::c_in1; - constexpr auto cb_scaler = tt::CB::c_in2; - constexpr auto cb_mask = tt::CB::c_in3; - - uint32_t l1_write_addr_in; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t y_tile_bytes = get_tile_size(cb_y); - const DataFormat y_data_format = get_dataformat(cb_y); - - uint32_t dy_tile_bytes = get_tile_size(cb_dy); - const DataFormat dy_data_format = get_dataformat(cb_dy); - - constexpr bool y_is_dram = get_compile_time_arg_val(0) == 1; - constexpr bool dy_is_dram = get_compile_time_arg_val(1) == 1; - - const InterleavedAddrGenFast y_in = { - .bank_base_address = y_addr, .page_size = y_tile_bytes, .data_format = y_data_format}; - - const InterleavedAddrGenFast dy_in = { - .bank_base_address = dy_addr, .page_size = dy_tile_bytes, .data_format = dy_data_format}; - - // TODO(AP): cleanup, probably with named args/param pack/reflection. - generate_bcast_scaler(cb_scaler, scaler); - generate_mask_w(cb_mask, mask_w); - - // read ublocks from src0 to CB0, then push ublocks to compute (unpacker) - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i += onetile) { - uint32_t curr_offset_i = curr_tile; - for (uint32_t w = 0; w < Wt; w++) { - #ifndef LOG - // read y - cb_reserve_back(cb_y, onetile); - l1_write_addr_in = get_write_ptr(cb_y); - noc_async_read_tile(curr_tile, y_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_y, onetile); - #endif - - // read dy - cb_reserve_back(cb_dy, onetile); - l1_write_addr_in = get_write_ptr(cb_dy); - noc_async_read_tile(curr_tile, dy_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_dy, onetile); - - curr_tile++; - } - - curr_tile = curr_offset_i; - for (uint32_t w = 0; w < Wt; w++) { - // read y - cb_reserve_back(cb_y, onetile); - l1_write_addr_in = get_write_ptr(cb_y); - noc_async_read_tile(curr_tile, y_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_y, onetile); - - // read dy - cb_reserve_back(cb_dy, onetile); - l1_write_addr_in = get_write_ptr(cb_dy); - noc_async_read_tile(curr_tile, dy_in, l1_write_addr_in); - noc_async_read_barrier(); - cb_push_back(cb_dy, onetile); - - curr_tile++; - } - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_c.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_c.cpp deleted file mode 100644 index fcec659ae37..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_c.cpp +++ /dev/null @@ -1,44 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t num_tiles = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t outer_stride = get_arg_val(3); - uint32_t inner_size = get_arg_val(4); - uint32_t dim_size = get_arg_val(5); - - constexpr auto cb_out = tt::CB::c_out0; - - // ublocks size defined in tiles - constexpr uint32_t onetile = 1; - uint32_t dst_out_tile_bytes = get_tile_size(cb_out); - const DataFormat dst_out_data_format = get_dataformat(cb_out); - - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast dst_out = { - .bank_base_address = dst_addr, .page_size = dst_out_tile_bytes, .data_format = dst_out_data_format}; - - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < num_tiles; i += onetile) { - uint32_t outer_idx = curr_tile / (inner_size); - uint32_t inner_idx = curr_tile % inner_size; - uint32_t tile_idx = outer_idx * outer_stride + inner_idx; - - uint32_t dim_stride = inner_size; - for (uint32_t d = 0; d < dim_size; d++) { - cb_wait_front(cb_out, onetile); - uint32_t l1_read_addr = get_read_ptr(cb_out); - noc_async_write_tile(tile_idx, dst_out, l1_read_addr); - noc_async_write_barrier(); - cb_pop_front(cb_out, onetile); - tile_idx += dim_stride; - } - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_h.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_h.cpp deleted file mode 100644 index bbad6708df3..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_h.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Ht = get_arg_val(3); - uint32_t Wt = get_arg_val(4); - - constexpr uint32_t cb_id_out = tt::CB::c_out0; - constexpr uint32_t onetile = 1; - uint32_t tile_bytes = get_tile_size(cb_id_out); - - const DataFormat data_format = get_dataformat(cb_id_out); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast s = { - .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; - - uint32_t blk = 1; - - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i++) { - uint32_t w_idx = curr_tile % Wt; - uint32_t nc_idx = curr_tile / Wt; - uint32_t tile_idx = nc_idx * Ht * Wt + w_idx; - for (uint32_t h = 0; h < Ht; h++) { - cb_wait_front(cb_id_out, blk); - uint32_t l1_read_addr = get_read_ptr(cb_id_out); - noc_async_write_tile(tile_idx, s, l1_read_addr); - noc_async_write_barrier(); - cb_pop_front(cb_id_out, blk); - tile_idx += Wt; - } - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_w.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_w.cpp deleted file mode 100644 index 883562f2c8f..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_w.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Wt = get_arg_val(3); - - constexpr uint32_t cb_id_out = tt::CB::c_out0; - constexpr uint32_t onetile = 1; - uint32_t tile_bytes = get_tile_size(cb_id_out); - - const DataFormat data_format = get_dataformat(cb_id_out); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast s = { - .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; - - uint32_t blk = 1; - - uint32_t tile_id = tile_offset; - for (uint32_t i = 0; i < N; i++) { - for (uint32_t w = 0; w < Wt; w++) { - cb_wait_front(cb_id_out, blk); - uint32_t l1_read_addr = get_read_ptr(cb_id_out); - noc_async_write_tile(tile_id, s, l1_read_addr); - noc_async_write_barrier(); - cb_pop_front(cb_id_out, blk); - tile_id++; - } - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_h.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_h.cpp deleted file mode 100644 index bbad6708df3..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_h.cpp +++ /dev/null @@ -1,41 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Ht = get_arg_val(3); - uint32_t Wt = get_arg_val(4); - - constexpr uint32_t cb_id_out = tt::CB::c_out0; - constexpr uint32_t onetile = 1; - uint32_t tile_bytes = get_tile_size(cb_id_out); - - const DataFormat data_format = get_dataformat(cb_id_out); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast s = { - .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; - - uint32_t blk = 1; - - uint32_t curr_tile = tile_offset; - for (uint32_t i = 0; i < N; i++) { - uint32_t w_idx = curr_tile % Wt; - uint32_t nc_idx = curr_tile / Wt; - uint32_t tile_idx = nc_idx * Ht * Wt + w_idx; - for (uint32_t h = 0; h < Ht; h++) { - cb_wait_front(cb_id_out, blk); - uint32_t l1_read_addr = get_read_ptr(cb_id_out); - noc_async_write_tile(tile_idx, s, l1_read_addr); - noc_async_write_barrier(); - cb_pop_front(cb_id_out, blk); - tile_idx += Wt; - } - curr_tile += 1; - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_w.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_w.cpp deleted file mode 100644 index 883562f2c8f..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_w.cpp +++ /dev/null @@ -1,36 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "dataflow_api.h" - -void kernel_main() { - uint32_t dst_addr = get_arg_val(0); - uint32_t N = get_arg_val(1); - uint32_t tile_offset = get_arg_val(2); - uint32_t Wt = get_arg_val(3); - - constexpr uint32_t cb_id_out = tt::CB::c_out0; - constexpr uint32_t onetile = 1; - uint32_t tile_bytes = get_tile_size(cb_id_out); - - const DataFormat data_format = get_dataformat(cb_id_out); - constexpr bool dst_is_dram = get_compile_time_arg_val(0) == 1; - - const InterleavedAddrGenFast s = { - .bank_base_address = dst_addr, .page_size = tile_bytes, .data_format = data_format}; - - uint32_t blk = 1; - - uint32_t tile_id = tile_offset; - for (uint32_t i = 0; i < N; i++) { - for (uint32_t w = 0; w < Wt; w++) { - cb_wait_front(cb_id_out, blk); - uint32_t l1_read_addr = get_read_ptr(cb_id_out); - noc_async_write_tile(tile_id, s, l1_read_addr); - noc_async_write_barrier(); - cb_pop_front(cb_id_out, blk); - tile_id++; - } - } -} diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.cpp deleted file mode 100644 index 82ca9ceba8d..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.cpp +++ /dev/null @@ -1,283 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp" - -#include "ttnn/run_operation.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/host_api.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -void MorehSoftmaxBackward::validate_with_output_tensors( - const std::vector& input_tensors, const std::vector>& output_tensors) const { - // validate input tensor - auto& output_tensor = input_tensors.at(0); - auto& output_grad_tensor = input_tensors.at(1); - - TT_ASSERT(output_tensor.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!"); - TT_ASSERT(output_grad_tensor.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!"); - TT_ASSERT(output_tensor.buffer() != nullptr, "Operands to softmax need to be allocated in buffers on device!"); - TT_ASSERT(output_grad_tensor.buffer() != nullptr, "Operands to softmax need to be allocated in buffers on device!"); - TT_ASSERT((output_tensor.get_layout() == Layout::TILE), "Output to softmax must be tilized"); - TT_ASSERT((output_grad_tensor.get_layout() == Layout::TILE), "Output_grad to softmax must be tilized"); - TT_ASSERT(output_tensor.get_dtype() == DataType::BFLOAT16 || output_tensor.get_dtype() == DataType::BFLOAT8_B); - TT_ASSERT( - output_grad_tensor.get_dtype() == DataType::BFLOAT16 || output_grad_tensor.get_dtype() == DataType::BFLOAT8_B); - - // validate parameters - auto rank = output_tensor.get_legacy_shape().rank(); - - TT_ASSERT( - this->dim >= 0 && this->dim < rank, - "dim {} should be less than output tensor rank {}", this->dim, rank); - - if (output_tensors.empty() || !output_tensors.at(0).has_value()) { - // If the user decided to not use any optional output tensors, then this would be empty or would be a nullptr. - return; - } - TT_ASSERT(input_tensors.size() == 2, "Must have 2 input tensors"); - TT_ASSERT(output_tensors.size() == 1, "Must have 1 output tensors"); -} - -std::vector MorehSoftmaxBackward::compute_output_shapes(const std::vector& input_tensors) const { - return {input_tensors.at(0).get_legacy_shape()}; -} - -std::vector MorehSoftmaxBackward::create_output_tensors( - const std::vector& input_tensors, const std::vector>& output_tensors) const { - if (!output_tensors.empty() && output_tensors.at(0).has_value()) { - return {output_tensors.at(0).value()}; - } - const auto& output_shape = input_tensors.at(0).get_legacy_shape(); - - return {operation::generic_create_output_tensors( - *this, input_tensors, input_tensors.at(0).get_dtype(), Layout::TILE, this->output_mem_config)}; -} - -operation::ProgramWithCallbacks MorehSoftmaxBackward::create_program( - const std::vector& input_tensors, std::vector& output_tensors) const { - auto& output = input_tensors.at(0); - auto& output_grad = input_tensors.at(1); - auto& input_grad = output_tensors.at(0); - - auto parallelization_strategy = this->get_parallelization_strategy(input_tensors); - - switch (parallelization_strategy) { - case MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_W: - return {moreh_softmax_backward_w_small( - output, output_grad, input_grad, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_H: - return {moreh_softmax_backward_h_small( - output, output_grad, input_grad, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_W: - return {moreh_softmax_backward_w_large( - output, output_grad, input_grad, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_H: - return {moreh_softmax_backward_h_large( - output, output_grad, input_grad, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_C: - return {moreh_softmax_backward_c_large( - output, output_grad, input_grad, this->dim, this->core_range, this->op, this->compute_kernel_config)}; - case MorehSoftmaxBackwardOpParallelizationStrategy::NONE: - default: break; - } - - return {moreh_softmax_backward_h_large( - output, output_grad, input_grad, this->core_range, this->op, this->compute_kernel_config)}; -} - -MorehSoftmaxBackwardOpParallelizationStrategy MorehSoftmaxBackward::get_parallelization_strategy( - const std::vector& input_tensors) const { - auto& output = input_tensors.at(0); - - auto rank = output.get_legacy_shape().rank(); - - if (this->strategy == MorehSoftmaxBackwardOpParallelizationStrategy::NONE) { - if (rank - 1 == this->dim) { - if (is_moreh_softmax_backward_w_small_available(output)) { - return MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_W; - } - return MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_W; - } - if (rank - 2 == this->dim) { - if (is_moreh_softmax_backward_h_small_available(output)) { - return MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_H; - } - return MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_H; - } - return MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_C; - } - - if (rank - 2 == this->dim) { - TT_ASSERT( - this->strategy == MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_H || - this->strategy == MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_H, - "Invalid parallelization strategy. {} is not for dim H", this->strategy); - - if (this->strategy == MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_H) { - TT_ASSERT( - is_moreh_softmax_backward_h_small_available(output), - "not enough circular buffer memory for {}", this->strategy); - } - } else if (rank - 1 == this->dim) { - TT_ASSERT( - this->strategy == MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_W || - this->strategy == MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_W, - "Invalid parallelization strategy. {} is not for dim W", this->strategy); - - if (this->strategy == MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_W) { - TT_ASSERT( - is_moreh_softmax_backward_w_small_available(output), - "not enough circular buffer memory for {}", this->strategy); - } - } else { - TT_ASSERT( - this->strategy == MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_C, - "Invalid parallelization strategy. large c is for dim 0 - (rank - 3)"); - } - - return this->strategy; -} - -Tensor moreh_softmax_backward( - const Tensor& output_tensor, - const Tensor& output_grad_tensor, - uint32_t dim, - std::optional input_grad_tensor, - const MorehSoftmaxBackwardOpParallelizationStrategy strategy, - const MemoryConfig& output_mem_config, - std::optional compute_kernel_config) { - auto device = output_grad_tensor.device(); - auto grid_coord = device->compute_with_storage_grid_size(); - const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); - - auto kernel_config_val = - init_device_compute_kernel_config(device->arch(), compute_kernel_config, MathFidelity::HiFi4); - - std::vector output_tensors = { - Tensor(operation::get_workers_for_op_output({output_tensor, output_grad_tensor}))}; - - operation::launch_op( - [dim, all_cores, strategy, output_mem_config, kernel_config_val]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) mutable -> std::vector { - return operation::run( - MorehSoftmaxBackward{ - .dim = dim, - .core_range = all_cores, - .op = MorehSoftmaxBackwardOp::SOFTMAX, - .strategy = strategy, - .output_mem_config = output_mem_config, - .compute_kernel_config = kernel_config_val}, - input_tensors, - optional_input_tensors, - optional_output_tensors); - }, - {output_tensor, output_grad_tensor}, - output_tensors, - {}, - {input_grad_tensor}); - - return output_tensors.at(0); -} - -Tensor moreh_softmin_backward( - const Tensor& output_tensor, - const Tensor& output_grad_tensor, - uint32_t dim, - std::optional input_grad_tensor, - const MorehSoftmaxBackwardOpParallelizationStrategy strategy, - const MemoryConfig& output_mem_config, - std::optional compute_kernel_config) { - auto device = output_grad_tensor.device(); - auto grid_coord = device->compute_with_storage_grid_size(); - const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); - - auto kernel_config_val = - init_device_compute_kernel_config(device->arch(), compute_kernel_config, MathFidelity::HiFi4); - - std::vector output_tensors = { - Tensor(operation::get_workers_for_op_output({output_tensor, output_grad_tensor}))}; - - operation::launch_op( - [dim, all_cores, strategy, output_mem_config, kernel_config_val]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) mutable -> std::vector { - return operation::run( - MorehSoftmaxBackward{ - .dim = dim, - .core_range = all_cores, - .op = MorehSoftmaxBackwardOp::SOFTMIN, - .strategy = strategy, - .output_mem_config = output_mem_config, - .compute_kernel_config = kernel_config_val}, - input_tensors, - optional_input_tensors, - optional_output_tensors); - }, - {output_tensor, output_grad_tensor}, - output_tensors, - {}, - {input_grad_tensor}); - - return output_tensors.at(0); -} - -Tensor moreh_logsoftmax_backward( - const Tensor& output_tensor, - const Tensor& output_grad_tensor, - uint32_t dim, - std::optional input_grad_tensor, - const MorehSoftmaxBackwardOpParallelizationStrategy strategy, - const MemoryConfig& output_mem_config, - std::optional compute_kernel_config) { - auto device = output_grad_tensor.device(); - auto grid_coord = device->compute_with_storage_grid_size(); - const CoreRange all_cores({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); - - auto kernel_config_val = - init_device_compute_kernel_config(device->arch(), compute_kernel_config, MathFidelity::HiFi4); - - std::vector output_tensors = { - Tensor(operation::get_workers_for_op_output({output_tensor, output_grad_tensor}))}; - - operation::launch_op( - [dim, all_cores, strategy, output_mem_config, kernel_config_val]( - const std::vector& input_tensors, - const std::vector>& optional_input_tensors, - const std::vector>& optional_output_tensors) mutable -> std::vector { - return operation::run( - MorehSoftmaxBackward{ - .dim = dim, - .core_range = all_cores, - .op = MorehSoftmaxBackwardOp::LOGSOFTMAX, - .strategy = strategy, - .output_mem_config = output_mem_config, - .compute_kernel_config = kernel_config_val}, - input_tensors, - optional_input_tensors, - optional_output_tensors); - }, - {output_tensor, output_grad_tensor}, - output_tensors, - {}, - {input_grad_tensor}); - - return output_tensors.at(0); -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp deleted file mode 100644 index ab80e275a24..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp +++ /dev/null @@ -1,125 +0,0 @@ -/* - * SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - * - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include "ttnn/operation.hpp" -#include "ttnn/tensor/tensor.hpp" -#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" - -namespace tt { -namespace operations { -namespace primary { - -using namespace tt_metal; - -enum class MorehSoftmaxBackwardOpParallelizationStrategy { - NONE, - SMALL_W, - SMALL_H, - LARGE_W, - LARGE_H, - LARGE_C -}; - -enum class MorehSoftmaxBackwardOp { - SOFTMAX, - SOFTMIN, - LOGSOFTMAX, -}; - -bool is_moreh_softmax_backward_w_small_available(const Tensor &tensor); -bool is_moreh_softmax_backward_h_small_available(const Tensor &tensor); - -operation::ProgramWithCallbacks moreh_softmax_backward_w_small( - const Tensor &output, - const Tensor &output_grad, - const Tensor &input_grad, - const CoreRange core_range, - const MorehSoftmaxBackwardOp op, - const ttnn::DeviceComputeKernelConfig compute_kernel_config); -operation::ProgramWithCallbacks moreh_softmax_backward_w_large( - const Tensor &output, - const Tensor &output_grad, - const Tensor &input_grad, - const CoreRange core_range, - const MorehSoftmaxBackwardOp op, - const ttnn::DeviceComputeKernelConfig compute_kernel_config); -operation::ProgramWithCallbacks moreh_softmax_backward_h_small( - const Tensor &output, - const Tensor &output_grad, - const Tensor &input_grad, - const CoreRange core_range, - const MorehSoftmaxBackwardOp op, - const ttnn::DeviceComputeKernelConfig compute_kernel_config); -operation::ProgramWithCallbacks moreh_softmax_backward_h_large( - const Tensor &output, - const Tensor &output_grad, - const Tensor &input_grad, - const CoreRange core_range, - const MorehSoftmaxBackwardOp op, - const ttnn::DeviceComputeKernelConfig compute_kernel_config); -operation::ProgramWithCallbacks moreh_softmax_backward_c_large( - const Tensor &output, - const Tensor &output_grad, - const Tensor &input_grad, - uint32_t dim, - const CoreRange core_range, - const MorehSoftmaxBackwardOp op, - const ttnn::DeviceComputeKernelConfig compute_kernel_config); - -struct MorehSoftmaxBackward { - const uint32_t dim; - const CoreRange core_range; // unused for now - const MorehSoftmaxBackwardOp op; - const MorehSoftmaxBackwardOpParallelizationStrategy strategy; - const MemoryConfig output_mem_config; - const ttnn::DeviceComputeKernelConfig compute_kernel_config; - - void validate_with_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; - operation::ProgramWithCallbacks create_program( - const std::vector &input_tensors, std::vector &output_tensors) const; - MorehSoftmaxBackwardOpParallelizationStrategy get_parallelization_strategy( - const std::vector &input_tensors) const; - static constexpr auto attribute_names = std::make_tuple("dim", "op", "strategy", "output_mem_config", "compute_kernel_config"); - const auto attribute_values() const { - return std::make_tuple(std::cref(this->dim), std::cref(this->op), std::cref(this->strategy), std::cref(this->output_mem_config), std::cref(this->compute_kernel_config)); - } -}; - -// const ref prevents -Tensor moreh_softmax_backward( - const Tensor &output_tensor, - const Tensor &output_grad_tensor, - uint32_t dim, - std::optional input_grad_tensor = std::nullopt, - const MorehSoftmaxBackwardOpParallelizationStrategy strategy = MorehSoftmaxBackwardOpParallelizationStrategy::NONE, - const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - std::optional compute_kernel_config = std::nullopt); - -Tensor moreh_softmin_backward( - const Tensor &output_tensor, - const Tensor &output_grad_tensor, - uint32_t dim, - std::optional input_grad_tensor = std::nullopt, - const MorehSoftmaxBackwardOpParallelizationStrategy strategy = MorehSoftmaxBackwardOpParallelizationStrategy::NONE, - const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - std::optional compute_kernel_config = std::nullopt); - -Tensor moreh_logsoftmax_backward( - const Tensor &output_tensor, - const Tensor &output_grad_tensor, - uint32_t dim, - std::optional input_grad_tensor = std::nullopt, - const MorehSoftmaxBackwardOpParallelizationStrategy strategy = MorehSoftmaxBackwardOpParallelizationStrategy::NONE, - const MemoryConfig &output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - std::optional compute_kernel_config = std::nullopt); - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp deleted file mode 100644 index 632f5e8d3c4..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp +++ /dev/null @@ -1,146 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -operation::ProgramWithCallbacks moreh_softmax_backward_c_large(const Tensor &output, const Tensor &output_grad, const Tensor &input_grad, uint32_t dim, const CoreRange core_range, const MorehSoftmaxBackwardOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Large tensor algorithm selected"); - - // split work - auto shape = input_grad.get_legacy_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - uint32_t num_tiles = input_grad.volume() / shape[dim] / H / W * Ht * Wt; - - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_tiles); - - auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(input_grad.get_dtype()); - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, 2}, // y - {CB::c_in1, 2}, // dy - {CB::c_out0, 2}, // dx - {CB::c_intermed0, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // y * dy - {CB::c_intermed1, 2, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // sum(y * dy) - {CB::c_intermed2, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // dy - sum - }); - - // create read/wrtie kernel - bool y_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dy_is_dram = output_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dx_is_dram = input_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - - std::map compute_defines; - if (op == MorehSoftmaxBackwardOp::SOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxBackwardOp::LOGSOFTMAX) { - compute_defines["LOG"] = 1; - reader_defines["LOG"] = 1; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_c.cpp", all_cores, {y_is_dram, dy_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_backward_c.cpp", all_cores, {dx_is_dram}, writer_defines); - - auto outer_stride = Ht * Wt; - for(int i = dim ; i < shape.rank() - 2; i++ ) { - outer_stride *= shape[i]; - } - auto dim_size = shape[dim]; - auto inner_size = outer_stride / dim_size; - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_c_large.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, dim_size}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, dim_size}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - vector reader_args = { - output.buffer()->address(), - output_grad.buffer()->address(), - num_tiles_per_core, tile_offset, - outer_stride, inner_size, - dim_size}; - - vector writer_args = {input_grad.buffer()->address(), num_tiles_per_core, tile_offset, - outer_stride, inner_size, - dim_size}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp deleted file mode 100644 index 77dbaf76338..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp +++ /dev/null @@ -1,141 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -operation::ProgramWithCallbacks moreh_softmax_backward_h_large(const Tensor &output, const Tensor &output_grad, const Tensor &input_grad, const CoreRange core_range, const MorehSoftmaxBackwardOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Large tensor algorithm selected"); - - // split work - auto shape = input_grad.get_padded_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - auto num = input_grad.volume() / H / W; - - uint32_t num_cols_tiles = num * Wt; - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_cols_tiles); - - auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(input_grad.get_dtype()); - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, 2}, // output - {CB::c_in1, 2}, // output_grad - {CB::c_in2, 1}, // scaler - {CB::c_in3, 1}, // mask - {CB::c_out0, 2}, // input_grad - {CB::c_intermed0, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // output * output_grad - {CB::c_intermed1, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // reduce - {CB::c_intermed2, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // dy - sum - {CB::c_intermed3, 2, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // add(output * output_grad) - }); - - // create read/wrtie kernel - bool y_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dy_is_dram = output_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dx_is_dram = input_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - std::map compute_defines; - if (op == MorehSoftmaxBackwardOp::SOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxBackwardOp::LOGSOFTMAX) { - compute_defines["LOG"] = 1; - reader_defines["LOG"] = 1; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h_large.cpp", all_cores, {y_is_dram, dy_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_h.cpp", all_cores, {dx_is_dram}, writer_defines); - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h_large.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, Ht}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, Ht}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - float scaler = 1.0f; - uint32_t mask_h = input_grad.get_logical_shape()[-2] % TILE_HEIGHT; - if(mask_h == 0) mask_h = TILE_HEIGHT; - vector reader_args = { - output.buffer()->address(), - output_grad.buffer()->address(), - num_tiles_per_core, tile_offset, Ht, Wt, *reinterpret_cast(&scaler), mask_h}; - - vector writer_args = {input_grad.buffer()->address(), num_tiles_per_core, tile_offset, Ht, Wt}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp deleted file mode 100644 index cc64991b0e5..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp +++ /dev/null @@ -1,163 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -#define L1_512KB (512 * 1024) - -bool is_moreh_softmax_backward_h_small_available(const Tensor &tensor) { - auto h = tensor.get_padded_shape()[-2]; - int32_t Ht = (h + TILE_HEIGHT - 1) / TILE_HEIGHT; - - tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); - - auto tile_size = tt_metal::detail::TileSize(data_format); - - int32_t cb_usage = 0; // bytes - cb_usage += Ht * tile_size; // output - cb_usage += Ht * tile_size; // output_grad - cb_usage += 1 * tile_size; // scaler - cb_usage += 1 * tile_size; // mask - cb_usage += 2 * tile_size; // input_grad - cb_usage += Ht * tile_size; // output * output_grad - cb_usage += 1 * tile_size; // reduce - cb_usage += 1 * tile_size; // dy - sum - - return (tensor.device()->get_base_allocator_addr(HalMemType::L1) + cb_usage <= L1_512KB); -} - -operation::ProgramWithCallbacks moreh_softmax_backward_h_small(const Tensor &output, const Tensor &output_grad, const Tensor &input_grad, const CoreRange core_range, const MorehSoftmaxBackwardOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Small tensor algorithm selected"); - // split work - auto shape = input_grad.get_padded_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - auto num = input_grad.volume() / H / W; - - uint32_t num_cols_tiles = num * Wt; - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_cols_tiles); - - auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(input_grad.get_dtype()); - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, Ht}, // output - {CB::c_in1, Ht}, // output_grad - {CB::c_in2, 1}, // scaler - {CB::c_in3, 1}, // mask - {CB::c_out0, 2}, // input_grad - {CB::c_intermed0, Ht, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // output * output_grad - {CB::c_intermed1, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // reduce - {CB::c_intermed2, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // dy - sum - }); - - // create read/wrtie kernel - bool y_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dy_is_dram = output_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dx_is_dram = input_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_h.cpp", all_cores, {y_is_dram, dy_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_h.cpp", all_cores, {dx_is_dram}, writer_defines); - - std::map compute_defines; - if (op == MorehSoftmaxBackwardOp::SOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxBackwardOp::LOGSOFTMAX) { - compute_defines["LOG"] = 1; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_h.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, Ht}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, Ht}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - float scaler = 1.0f; - uint32_t mask_h = input_grad.get_logical_shape()[-2] % TILE_HEIGHT; - if(mask_h == 0) mask_h = TILE_HEIGHT; - vector reader_args = { - output.buffer()->address(), - output_grad.buffer()->address(), - num_tiles_per_core, tile_offset, Ht, Wt, *reinterpret_cast(&scaler), mask_h}; - - vector writer_args = {input_grad.buffer()->address(), num_tiles_per_core, tile_offset, Ht, Wt}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp deleted file mode 100644 index 5904de7044b..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp +++ /dev/null @@ -1,141 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -operation::ProgramWithCallbacks moreh_softmax_backward_w_large(const Tensor &output, const Tensor &output_grad, const Tensor &input_grad, const CoreRange core_range, const MorehSoftmaxBackwardOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Large tensor algorithm selected"); - // split work - auto shape = input_grad.get_padded_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - auto num = input_grad.volume() / H / W; - - uint32_t num_kernel_rows = num * Ht; - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_kernel_rows); - - auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(input_grad.get_dtype()); - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, 2}, // output - {CB::c_in1, 2}, // output_grad - {CB::c_in2, 1}, // scaler - {CB::c_in3, 1}, // mask - {CB::c_out0, 2}, // input_grad - {CB::c_intermed0, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // output * output_grad - {CB::c_intermed1, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // reduce - {CB::c_intermed2, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // dy - sum - {CB::c_intermed3, 2, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // add(output * output_grad) - }); - - // create read/wrtie kernel - bool y_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dy_is_dram = output_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dx_is_dram = input_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - std::map compute_defines; - if (op == MorehSoftmaxBackwardOp::SOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxBackwardOp::LOGSOFTMAX) { - compute_defines["LOG"] = 1; - reader_defines["LOG"] = 1; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w_large.cpp", all_cores, {y_is_dram, dy_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_w.cpp", all_cores, {dx_is_dram}, writer_defines); - - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w_large.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, Wt}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, Wt}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - float scaler = 1.0f; - uint32_t mask_w = input_grad.get_logical_shape()[-1] % TILE_WIDTH; - if(mask_w == 0) mask_w = TILE_WIDTH; - vector reader_args = { - output.buffer()->address(), - output_grad.buffer()->address(), - num_tiles_per_core, tile_offset, Wt, *reinterpret_cast(&scaler), mask_w}; - - vector writer_args = {input_grad.buffer()->address(), num_tiles_per_core, tile_offset, Wt}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core * Wt; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp deleted file mode 100644 index d0b07328121..00000000000 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp +++ /dev/null @@ -1,164 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_helper_functions.hpp" -#include "tt_metal/common/work_split.hpp" -#include "ttnn/run_operation.hpp" - -#include "tt_metal/host_api.hpp" -#include "tt_metal/common/constants.hpp" -#include "tt_metal/detail/util.hpp" - -using namespace tt::constants; -using namespace std; -using namespace tt::tt_metal; - -namespace tt { -namespace operations { -namespace primary { - -#define L1_512KB (512 * 1024) - -bool is_moreh_softmax_backward_w_small_available(const Tensor &tensor) { - auto w = tensor.get_legacy_shape()[-1]; - int32_t Wt = (w + TILE_WIDTH - 1) / TILE_WIDTH; - - tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); - - auto tile_size = tt_metal::detail::TileSize(data_format); - - int32_t cb_usage = 0; // bytes - cb_usage += Wt * tile_size; // output - cb_usage += Wt * tile_size; // output_grad - cb_usage += 1 * tile_size; // scaler - cb_usage += 1 * tile_size; // mask - cb_usage += 2 * tile_size; // input_grad - cb_usage += Wt * tile_size; // output * output_grad - cb_usage += 1 * tile_size; // reduce - cb_usage += 1 * tile_size; // dy - sum - - return (tensor.device()->get_base_allocator_addr(HalMemType::L1) + cb_usage <= L1_512KB); -} - -operation::ProgramWithCallbacks moreh_softmax_backward_w_small(const Tensor &output, const Tensor &output_grad, const Tensor &input_grad, const CoreRange core_range, const MorehSoftmaxBackwardOp op, const ttnn::DeviceComputeKernelConfig compute_kernel_config) { - log_info(LogTest, "Small tensor algorithm selected"); - - // split work - auto shape = input_grad.get_padded_shape(); - auto H = shape[-2]; - auto W = shape[-1]; - auto Ht = H / TILE_HEIGHT; - auto Wt = W / TILE_WIDTH; - - auto num = input_grad.volume() / H / W; - - uint32_t num_kernel_rows = num * Ht; - uint32_t core_w = core_range.end_coord.x - core_range.start_coord.x + 1; - uint32_t core_h = core_range.end_coord.y - core_range.start_coord.y + 1; - - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = - split_work_to_cores(core_range, num_kernel_rows); - - auto arch = input_grad.device()->arch(); - auto [math_fidelity, math_approx_mode, fp32_dest_acc_en, packer_l1_acc, dst_full_sync_en] = get_compute_kernel_config_args(arch, compute_kernel_config); - - Program program = Program(); - - // create circular buffers - tt::DataFormat data_format = tt_metal::datatype_to_dataformat_converter(input_grad.get_dtype()); - - CreateCircularBuffer( - program, - all_cores, - data_format, - { - {CB::c_in0, Wt}, // output - {CB::c_in1, Wt}, // output_grad - {CB::c_in2, 1}, // scaler - {CB::c_in3, 1}, // mask - {CB::c_out0, 2}, // input_grad - {CB::c_intermed0, Wt, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // output * output_grad - {CB::c_intermed1, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // reduce - {CB::c_intermed2, 1, fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format}, // dy - sum - }); - - // create read/wrtie kernel - bool y_is_dram = output.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dy_is_dram = output_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - bool dx_is_dram = input_grad.buffer()->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; - - std::map reader_defines; - std::map writer_defines; - - auto reader_kernel_id = CreateReadKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/reader_moreh_softmax_backward_w.cpp", all_cores, {y_is_dram, dy_is_dram}, reader_defines); - auto writer_kernel_id = CreateWriteKernel( - program, "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/writer_moreh_softmax_w.cpp", all_cores, {dx_is_dram}, writer_defines); - - std::map compute_defines; - if (op == MorehSoftmaxBackwardOp::SOFTMAX) compute_defines["SOFTMAX"] = "1"; - else compute_defines["SOFTMIN"] = "1"; - - if (op == MorehSoftmaxBackwardOp::LOGSOFTMAX) { - compute_defines["LOG"] = 1; - } - - if (fp32_dest_acc_en) { - compute_defines["FP32_DEST_ACC_EN"] = "1"; - } - - // create compute kernel - CreateComputeKernel( - program, - "ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/kernels/moreh_softmax_backward_w.cpp", - { - {core_group_1, num_tiles_per_core_group_1, {num_tiles_per_core_group_1, Wt}}, - {core_group_2, num_tiles_per_core_group_2, {num_tiles_per_core_group_2, Wt}}, - }, - compute_defines, - math_fidelity, - fp32_dest_acc_en, - math_approx_mode); - - // Set Runtime Args - auto core_x_offset = core_range.start_coord.x; - auto core_y_offset = core_range.start_coord.y; - - for (uint32_t i = 0, tile_offset = 0; i < num_cores; i++) { - CoreCoord core = {i / core_h + core_x_offset, i % core_h + core_y_offset}; - uint32_t num_tiles_per_core; - if (core_group_1.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_1; - } else if (core_group_2.core_coord_in_core_ranges(core)) { - num_tiles_per_core = num_tiles_per_core_group_2; - } else { - TT_THROW("Core not in specified core ranges"); - } - - float scaler = 1.0f; - uint32_t mask_w = input_grad.get_logical_shape()[-1] % TILE_WIDTH; - if(mask_w == 0) mask_w = TILE_WIDTH; - vector reader_args = { - output.buffer()->address(), - output_grad.buffer()->address(), - num_tiles_per_core, tile_offset, Wt, *reinterpret_cast(&scaler), mask_w}; - - vector writer_args = {input_grad.buffer()->address(), num_tiles_per_core, tile_offset, Wt}; - - SetRuntimeArgs(program, reader_kernel_id, core, reader_args); - SetRuntimeArgs(program, writer_kernel_id, core, writer_args); - - tile_offset += num_tiles_per_core * Wt; - } - - return { - .program = std::move(program), - .override_runtime_arguments_callback = - create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; -} - -} // namespace primary -} // namespace operations -} // namespace tt diff --git a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp index e003ec42341..62bc763e29d 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_lib/csrc/operations/primary/module.hpp @@ -14,8 +14,6 @@ #include "ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_matmul_backward/moreh_matmul_backward_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax_backward/moreh_softmax_backward_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_sum/moreh_sum_op.hpp" #include "ttnn/deprecated/tt_dnn/op_library/moreh_sum_backward/moreh_sum_backward_op.hpp" @@ -25,23 +23,6 @@ namespace tt { namespace operations { namespace primary { -void py_module_types(py::module& m_primary) { - py::enum_(m_primary, "MorehSoftmaxOpParallelizationStrategy") - .value("NONE", MorehSoftmaxOpParallelizationStrategy::NONE) - .value("SMALL_W", MorehSoftmaxOpParallelizationStrategy::SMALL_W) - .value("SMALL_H", MorehSoftmaxOpParallelizationStrategy::SMALL_H) - .value("LARGE_W", MorehSoftmaxOpParallelizationStrategy::LARGE_W) - .value("LARGE_H", MorehSoftmaxOpParallelizationStrategy::LARGE_H) - .value("LARGE_C", MorehSoftmaxOpParallelizationStrategy::LARGE_C); - - py::enum_(m_primary, "MorehSoftmaxBackwardOpParallelizationStrategy") - .value("NONE", MorehSoftmaxBackwardOpParallelizationStrategy::NONE) - .value("SMALL_W", MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_W) - .value("SMALL_H", MorehSoftmaxBackwardOpParallelizationStrategy::SMALL_H) - .value("LARGE_W", MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_W) - .value("LARGE_H", MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_H) - .value("LARGE_C", MorehSoftmaxBackwardOpParallelizationStrategy::LARGE_C); -} void py_module(py::module& m_primary) { // moreh_clip_grad_norm @@ -151,71 +132,6 @@ void py_module(py::module& m_primary) { py::arg("compute_kernel_config").noconvert() = std::nullopt, "Performs a moreh_layernorm_backward operation."); - m_primary.def( - "moreh_softmax", - &moreh_softmax, - py::arg("input_tensor").noconvert(), - py::arg("dim").noconvert(), - py::arg("output_tensor").noconvert() = std::nullopt, - py::arg("strategy").noconvert() = MorehSoftmaxOpParallelizationStrategy::NONE, - py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("compute_kernel_config").noconvert() = std::nullopt, - "Performs a softmax operation. Returns an output tensor."); - m_primary.def( - "moreh_softmax_backward", - &moreh_softmax_backward, - py::arg("output_tensor").noconvert(), - py::arg("output_grad_tensor").noconvert(), - py::arg("dim").noconvert(), - py::arg("input_grad_tensor").noconvert() = std::nullopt, - py::arg("strategy").noconvert() = MorehSoftmaxBackwardOpParallelizationStrategy::NONE, - py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("compute_kernel_config").noconvert() = std::nullopt, - "Performs a softmax backward operation. Returns an input grad tensor."); - m_primary.def( - "moreh_softmin", - &moreh_softmin, - py::arg("input_tensor").noconvert(), - py::arg("dim").noconvert(), - py::arg("output_tensor").noconvert() = std::nullopt, - py::arg("strategy").noconvert() = MorehSoftmaxOpParallelizationStrategy::NONE, - py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("compute_kernel_config").noconvert() = std::nullopt, - "Performs a softmin operation. Returns an output tensor."); - m_primary.def( - "moreh_softmin_backward", - &moreh_softmin_backward, - py::arg("output_tensor").noconvert(), - py::arg("output_grad_tensor").noconvert(), - py::arg("dim").noconvert(), - py::arg("input_grad_tensor").noconvert() = std::nullopt, - py::arg("strategy").noconvert() = MorehSoftmaxBackwardOpParallelizationStrategy::NONE, - py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("compute_kernel_config").noconvert() = std::nullopt, - "Performs a softmin backward operation. Returns an input grad tensor."); - - m_primary.def( - "moreh_logsoftmax", - &moreh_logsoftmax, - py::arg("input_tensor").noconvert(), - py::arg("dim").noconvert(), - py::arg("output_tensor").noconvert() = std::nullopt, - py::arg("strategy").noconvert() = MorehSoftmaxOpParallelizationStrategy::NONE, - py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("compute_kernel_config").noconvert() = std::nullopt, - "Performs a logsoftmax operation. Returns an output tensor."); - - m_primary.def( - "moreh_logsoftmax_backward", - &moreh_logsoftmax_backward, - py::arg("output_tensor").noconvert(), - py::arg("output_grad_tensor").noconvert(), - py::arg("dim").noconvert(), - py::arg("input_grad_tensor").noconvert() = std::nullopt, - py::arg("strategy").noconvert() = MorehSoftmaxBackwardOpParallelizationStrategy::NONE, - py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, - py::arg("compute_kernel_config").noconvert() = std::nullopt, - "Performs a logsoftmax backward operation. Returns an input grad tensor."); m_primary.def( "moreh_sum", diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp index a2e23628d2f..201d6a7d1d9 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.cpp @@ -9,7 +9,7 @@ namespace ttnn::operations::moreh::moreh_softmax { #define L1_512KB (512 * 1024) bool is_moreh_softmax_w_small_available(const Tensor& tensor, const DeviceComputeKernelConfig& compute_kernel_config) { - auto w = tensor.get_legacy_shape()[-1]; + auto w = tensor.get_shape()[-1]; int32_t Wt = (w + tt::constants::TILE_WIDTH - 1) / tt::constants::TILE_WIDTH; auto arch = tensor.device()->arch(); @@ -39,7 +39,7 @@ bool is_moreh_softmax_w_small_available(const Tensor& tensor, const DeviceComput } bool is_moreh_softmax_h_small_available(const Tensor& tensor, const DeviceComputeKernelConfig& compute_kernel_config) { - auto h = tensor.get_legacy_shape()[-2]; + auto h = tensor.get_shape()[-2]; int32_t Ht = (h + tt::constants::TILE_HEIGHT - 1) / tt::constants::TILE_HEIGHT; auto arch = tensor.device()->arch(); @@ -83,15 +83,15 @@ MorehSoftmaxOperation::program_factory_t MorehSoftmaxOperation::select_program_f void MorehSoftmaxOperation::validate_inputs( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - const auto& input_tensor = tensor_args.input_tensor; - TT_FATAL(input_tensor.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!"); - TT_FATAL(input_tensor.buffer() != nullptr, "Operands to softmax need to be allocated in buffers on device!"); - TT_FATAL((input_tensor.get_layout() == Layout::TILE), "Inputs to softmax must be tilized"); + const auto& input = tensor_args.input; + TT_FATAL(input.storage_type() == StorageType::DEVICE, "Operands to softmax need to be on device!"); + TT_FATAL(input.buffer() != nullptr, "Operands to softmax need to be allocated in buffers on device!"); + TT_FATAL((input.get_layout() == Layout::TILE), "Inputs to softmax must be tilized"); TT_FATAL( - input_tensor.get_dtype() == DataType::BFLOAT16 || input_tensor.get_dtype() == DataType::BFLOAT8_B, + input.get_dtype() == DataType::BFLOAT16 || input.get_dtype() == DataType::BFLOAT8_B, "Inputs must be of bfloat16 or bfloat8_b type"); - const auto rank = input_tensor.get_legacy_shape().rank(); + const auto rank = input.get_shape().rank(); const auto dim = operation_attributes.dim; TT_FATAL(dim >= 0 && dim < rank, "dim {} should be less than output tensor rank {}", dim, rank); } @@ -108,30 +108,26 @@ void MorehSoftmaxOperation::validate_on_program_cache_hit( MorehSoftmaxOperation::shape_return_value_t MorehSoftmaxOperation::compute_output_shapes( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return tensor_args.input_tensor.get_shape(); + return tensor_args.input.get_shape(); } MorehSoftmaxOperation::tensor_return_value_t MorehSoftmaxOperation::create_output_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - const auto& output_tensor = tensor_args.output_tensor; - if (output_tensor.has_value()) - return output_tensor.value(); + const auto& output = tensor_args.output; + if (output.has_value()) + return output.value(); - const auto& input_tensor = tensor_args.input_tensor; - const auto& output_shape = input_tensor.get_legacy_shape(); + const auto& input = tensor_args.input; + const auto& output_shape = input.get_shape(); return create_device_tensor( - output_shape, - input_tensor.get_dtype(), - input_tensor.get_layout(), - input_tensor.device(), - operation_attributes.memory_config); + output_shape, input.get_dtype(), input.get_layout(), input.device(), operation_attributes.memory_config); } std::tuple MorehSoftmaxOperation::invoke( - const Tensor& input_tensor, + const Tensor& input, uint32_t dim, - const std::optional& output_tensor, + const std::optional& output, const MorehSoftmaxOp op, const MorehSoftmaxOpParallelizationStrategy strategy, const std::optional& memory_config, @@ -141,20 +137,19 @@ MorehSoftmaxOperation::invoke( dim, op, strategy, - memory_config.value_or(input_tensor.memory_config()), - init_device_compute_kernel_config( - input_tensor.device()->arch(), compute_kernel_config, MathFidelity::HiFi4)}, - tensor_args_t{input_tensor, output_tensor}}; + memory_config.value_or(input.memory_config()), + init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4)}, + tensor_args_t{input, output}}; } MorehSoftmaxOpParallelizationStrategy MorehSoftmaxOperation::get_parallelization_strategy( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - const auto& input = tensor_args.input_tensor; + const auto& input = tensor_args.input; const auto strategy = operation_attributes.strategy; const auto dim = operation_attributes.dim; const auto& compute_kernel_config = operation_attributes.compute_kernel_config; - auto rank = input.get_legacy_shape().rank(); + auto rank = input.get_shape().rank(); if (strategy == MorehSoftmaxOpParallelizationStrategy::NONE) { if (rank - 1 == dim) { if (is_moreh_softmax_w_small_available(input, compute_kernel_config)) { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.hpp index ff5698ffa1b..42ee2f45b6e 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.hpp @@ -35,8 +35,8 @@ struct MorehSoftmaxOperation { }; struct tensor_args_t { - const Tensor& input_tensor; - const std::optional& output_tensor; + const Tensor& input; + const std::optional& output; }; using shape_return_value_t = Shape; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp index 520854fb899..25dde8d6154 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_c_large/softmax_c_large.cpp @@ -13,7 +13,7 @@ MorehSoftmaxOperation::MorehSoftmaxCLargeFactory::create( const tensor_args_t& tensor_args, tensor_return_value_t& output) { log_info(tt::LogTest, "Large tensor algorithm selected"); - const auto& input = tensor_args.input_tensor; + const auto& input = tensor_args.input; const auto dim = operation_attributes.dim; const auto op = operation_attributes.op; const auto& compute_kernel_config = operation_attributes.compute_kernel_config; @@ -22,7 +22,7 @@ MorehSoftmaxOperation::MorehSoftmaxCLargeFactory::create( auto grid_coord = device->compute_with_storage_grid_size(); const CoreRange core_range({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); // split work - auto shape = input.get_legacy_shape(); + auto shape = input.get_shape().value; auto H = shape[-2]; auto W = shape[-1]; auto Ht = H / tt::constants::TILE_HEIGHT; @@ -153,12 +153,11 @@ void MorehSoftmaxOperation::MorehSoftmaxCLargeFactory::override_runtime_argument auto& writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; auto& num_cores = cached_program.shared_variables.num_cores; auto& num_cores_y = cached_program.shared_variables.num_cores_y; - for (uint32_t i = 0; i < num_cores; i++) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = tensor_args.input_tensor.buffer()->address(); + runtime_args[0] = tensor_args.input.buffer()->address(); } { auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp index 5e1a8339063..6160032be1c 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_large/softmax_h_large.cpp @@ -13,7 +13,7 @@ MorehSoftmaxOperation::MorehSoftmaxHLargeFactory::create( const tensor_args_t& tensor_args, tensor_return_value_t& output) { log_info(tt::LogTest, "Large tensor algorithm selected"); - const auto& input = tensor_args.input_tensor; + const auto& input = tensor_args.input; const auto op = operation_attributes.op; const auto& compute_kernel_config = operation_attributes.compute_kernel_config; @@ -157,12 +157,11 @@ void MorehSoftmaxOperation::MorehSoftmaxHLargeFactory::override_runtime_argument auto& writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; auto& num_cores = cached_program.shared_variables.num_cores; auto& num_cores_y = cached_program.shared_variables.num_cores_y; - for (uint32_t i = 0; i < num_cores; i++) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = tensor_args.input_tensor.buffer()->address(); + runtime_args[0] = tensor_args.input.buffer()->address(); } { auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp index 1820a9f6df6..bbd62257bf5 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_h_small/softmax_h_small.cpp @@ -13,7 +13,7 @@ MorehSoftmaxOperation::MorehSoftmaxHSmallFactory::create( const tensor_args_t& tensor_args, tensor_return_value_t& output) { log_info(tt::LogTest, "Large tensor algorithm selected"); - const auto& input = tensor_args.input_tensor; + const auto& input = tensor_args.input; const auto op = operation_attributes.op; const auto& compute_kernel_config = operation_attributes.compute_kernel_config; @@ -21,7 +21,7 @@ MorehSoftmaxOperation::MorehSoftmaxHSmallFactory::create( auto grid_coord = device->compute_with_storage_grid_size(); const CoreRange core_range({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); // split work - auto shape = input.get_legacy_shape(); + auto shape = input.get_shape().value; auto H = shape[-2]; auto W = shape[-1]; @@ -158,12 +158,11 @@ void MorehSoftmaxOperation::MorehSoftmaxHSmallFactory::override_runtime_argument auto& writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; auto& num_cores = cached_program.shared_variables.num_cores; auto& num_cores_y = cached_program.shared_variables.num_cores_y; - for (uint32_t i = 0; i < num_cores; i++) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = tensor_args.input_tensor.buffer()->address(); + runtime_args[0] = tensor_args.input.buffer()->address(); } { auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp index d2703e5a4c8..f74ba41b8fd 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_large/softmax_w_large.cpp @@ -13,7 +13,7 @@ MorehSoftmaxOperation::MorehSoftmaxWLargeFactory::create( const tensor_args_t& tensor_args, tensor_return_value_t& output) { log_info(tt::LogTest, "Large tensor algorithm selected"); - const auto& input = tensor_args.input_tensor; + const auto& input = tensor_args.input; const auto op = operation_attributes.op; const auto& compute_kernel_config = operation_attributes.compute_kernel_config; @@ -21,7 +21,7 @@ MorehSoftmaxOperation::MorehSoftmaxWLargeFactory::create( auto grid_coord = device->compute_with_storage_grid_size(); const CoreRange core_range({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); // split work - auto shape = input.get_legacy_shape(); + auto shape = input.get_shape().value; auto H = shape[-2]; auto W = shape[-1]; auto Ht = H / tt::constants::TILE_HEIGHT; @@ -157,12 +157,11 @@ void MorehSoftmaxOperation::MorehSoftmaxWLargeFactory::override_runtime_argument auto& writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; auto& num_cores = cached_program.shared_variables.num_cores; auto& num_cores_y = cached_program.shared_variables.num_cores_y; - for (uint32_t i = 0; i < num_cores; i++) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = tensor_args.input_tensor.buffer()->address(); + runtime_args[0] = tensor_args.input.buffer()->address(); } { auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp index 627f9a2f827..586816d6805 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/softmax_w_small/softmax_w_small.cpp @@ -13,7 +13,7 @@ MorehSoftmaxOperation::MorehSoftmaxWSmallFactory::create( const tensor_args_t& tensor_args, tensor_return_value_t& output) { log_info(tt::LogTest, "Large tensor algorithm selected"); - const auto& input = tensor_args.input_tensor; + const auto& input = tensor_args.input; const auto op = operation_attributes.op; const auto& compute_kernel_config = operation_attributes.compute_kernel_config; @@ -21,7 +21,7 @@ MorehSoftmaxOperation::MorehSoftmaxWSmallFactory::create( auto grid_coord = device->compute_with_storage_grid_size(); const CoreRange core_range({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); // split work - auto shape = input.get_legacy_shape(); + auto shape = input.get_shape().value; auto H = shape[-2]; auto W = shape[-1]; auto Ht = H / tt::constants::TILE_HEIGHT; @@ -156,12 +156,11 @@ void MorehSoftmaxOperation::MorehSoftmaxWSmallFactory::override_runtime_argument auto& writer_kernel_id = cached_program.shared_variables.unary_writer_kernel_id; auto& num_cores = cached_program.shared_variables.num_cores; auto& num_cores_y = cached_program.shared_variables.num_cores_y; - for (uint32_t i = 0; i < num_cores; i++) { CoreCoord core = {i / num_cores_y, i % num_cores_y}; { auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = tensor_args.input_tensor.buffer()->address(); + runtime_args[0] = tensor_args.input.buffer()->address(); } { auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/moreh_softmax_backward_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/moreh_softmax_backward_device_operation.cpp index b1239fad816..77bc76fa8ef 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/moreh_softmax_backward_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/moreh_softmax_backward_device_operation.cpp @@ -9,7 +9,7 @@ namespace ttnn::operations::moreh::moreh_softmax_backward { #define L1_512KB (512 * 1024) bool is_moreh_softmax_backward_w_small_available(const Tensor& tensor) { - auto w = tensor.get_legacy_shape()[-1]; + auto w = tensor.get_shape()[-1]; int32_t Wt = (w + tt::constants::TILE_WIDTH - 1) / tt::constants::TILE_WIDTH; tt::DataFormat data_format = tt::tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); @@ -30,7 +30,7 @@ bool is_moreh_softmax_backward_w_small_available(const Tensor& tensor) { } bool is_moreh_softmax_backward_h_small_available(const Tensor& tensor) { - auto h = tensor.get_legacy_shape()[-2]; + auto h = tensor.get_shape()[-2]; int32_t Ht = (h + tt::constants::TILE_HEIGHT - 1) / tt::constants::TILE_HEIGHT; tt::DataFormat data_format = tt::tt_metal::datatype_to_dataformat_converter(tensor.get_dtype()); @@ -80,7 +80,7 @@ void MorehSoftmaxBackwardOperation::validate_inputs( output_grad_tensor.get_dtype() == DataType::BFLOAT16 || output_grad_tensor.get_dtype() == DataType::BFLOAT8_B, "Output_tensor_grad dtype should be bfloat16 or bfloat8_b"); - const auto rank = output_tensor.get_legacy_shape().rank(); + const auto rank = output_tensor.get_shape().rank(); const auto dim = operation_attributes.dim; TT_FATAL(dim >= 0 && dim < rank, "dim {} should be less than output tensor rank {}", dim, rank); } @@ -107,7 +107,7 @@ MorehSoftmaxBackwardOperation::tensor_return_value_t MorehSoftmaxBackwardOperati return input_grad_tensor.value(); const auto& output_tensor = tensor_args.output_tensor; - const auto& input_grad_shape = output_tensor.get_legacy_shape(); + const auto& input_grad_shape = output_tensor.get_shape(); return create_device_tensor( input_grad_shape, output_tensor.get_dtype(), @@ -144,7 +144,7 @@ MorehSoftmaxBackwardOpParallelizationStrategy MorehSoftmaxBackwardOperation::get const auto dim = operation_attributes.dim; const auto& compute_kernel_config = operation_attributes.compute_kernel_config; - auto rank = output.get_legacy_shape().rank(); + auto rank = output.get_shape().rank(); if (strategy == MorehSoftmaxBackwardOpParallelizationStrategy::NONE) { if (rank - 1 == dim) { if (is_moreh_softmax_backward_w_small_available(output)) { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp index 68c89a6cc4f..c5c69ee8eed 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_c_large/softmax_backward_c_large.cpp @@ -23,7 +23,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardCLargeFactory::create( auto grid_coord = device->compute_with_storage_grid_size(); const CoreRange core_range({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); // split work - auto shape = input_grad.get_legacy_shape(); + auto shape = input_grad.get_shape().value; auto H = shape[-2]; auto W = shape[-1]; auto Ht = H / tt::constants::TILE_HEIGHT; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp index fc042e650bc..98cc38c9534 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_large/softmax_backward_h_large.cpp @@ -22,7 +22,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardHLargeFactory::create( auto grid_coord = device->compute_with_storage_grid_size(); const CoreRange core_range({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); // split work - auto shape = input_grad.get_legacy_shape(); + auto shape = input_grad.get_shape().value; auto H = shape[-2]; auto W = shape[-1]; auto Ht = H / tt::constants::TILE_HEIGHT; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_small/softmax_backward_h_small.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_small/softmax_backward_h_small.cpp index a6b228f795a..34ffa241b4f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_small/softmax_backward_h_small.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_h_small/softmax_backward_h_small.cpp @@ -22,7 +22,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardHSmallFactory::create( auto grid_coord = device->compute_with_storage_grid_size(); const CoreRange core_range({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); // split work - auto shape = input_grad.get_legacy_shape(); + auto shape = input_grad.get_shape().value; auto H = shape[-2]; auto W = shape[-1]; auto Ht = H / tt::constants::TILE_HEIGHT; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_large/softmax_backward_w_large.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_large/softmax_backward_w_large.cpp index fb3db5e4216..53e68b2453f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_large/softmax_backward_w_large.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_large/softmax_backward_w_large.cpp @@ -22,7 +22,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardWLargeFactory::create( auto grid_coord = device->compute_with_storage_grid_size(); const CoreRange core_range({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); // split work - auto shape = input_grad.get_legacy_shape(); + auto shape = input_grad.get_shape().value; auto H = shape[-2]; auto W = shape[-1]; auto Ht = H / tt::constants::TILE_HEIGHT; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_small/softmax_backward_w_small.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_small/softmax_backward_w_small.cpp index f5af6071e38..cd17c998616 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_small/softmax_backward_w_small.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_softmax_backward/device/softmax_backward_w_small/softmax_backward_w_small.cpp @@ -22,7 +22,7 @@ MorehSoftmaxBackwardOperation::MorehSoftmaxBackwardWSmallFactory::create( auto grid_coord = device->compute_with_storage_grid_size(); const CoreRange core_range({0, 0}, {grid_coord.x - 1, grid_coord.y - 1}); // split work - auto shape = input_grad.get_legacy_shape(); + auto shape = input_grad.get_shape().value; auto H = shape[-2]; auto W = shape[-1]; auto Ht = H / tt::constants::TILE_HEIGHT; diff --git a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.cpp b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.cpp index e4774606a41..7097643dff8 100644 --- a/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.cpp +++ b/ttnn/cpp/ttnn/operations/normalization/softmax/softmax.cpp @@ -4,12 +4,14 @@ #include "softmax.hpp" -#include "ttnn/deprecated/tt_dnn/op_library/moreh_softmax/moreh_softmax_op.hpp" +#include "ttnn/cpp/ttnn/operations/moreh/moreh_softmax/device/moreh_softmax_device_operation.hpp" #include "device/softmax_op.hpp" #include "ttnn/operations/core/core.hpp" namespace ttnn::operations::normalization { +using namespace moreh::moreh_softmax; + ttnn::Tensor ExecuteSoftmax::invoke( const ttnn::Tensor& input_tensor, const int dim_arg, @@ -31,7 +33,7 @@ ttnn::Tensor ExecuteSoftmax::invoke( return ttnn::reshape(output_tensor, input_shape); } else { auto dim_4D = dim + 4 - rank; - auto output_tensor = tt::operations::primary::moreh_softmax(input_tensor_4D, dim_4D); + auto output_tensor = ttnn::prim::moreh_softmax(input_tensor_4D, dim_4D, std::nullopt, MorehSoftmaxOp::SOFTMAX, MorehSoftmaxOpParallelizationStrategy::NONE, memory_config.value_or(input_tensor.memory_config()), compute_kernel_config); return ttnn::reshape(output_tensor, input_shape); } } From 2d9853b73eda29706a310034cfbbcb6aabdd3e7a Mon Sep 17 00:00:00 2001 From: Aditya Saigal <129097327+tt-asaigal@users.noreply.github.com> Date: Wed, 9 Oct 2024 14:17:15 -0400 Subject: [PATCH 45/58] #0: Move input and output queue data structures in vc_packet_router to stack (#13641) #0: Move input and output queue data structures in vc_packet_rotuer to stack - Were previously created as unintialized global variables, which put them in .bss - Moved to stack, since .bss region was overflowing --- tt_metal/impl/dispatch/kernels/vc_packet_router.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tt_metal/impl/dispatch/kernels/vc_packet_router.cpp b/tt_metal/impl/dispatch/kernels/vc_packet_router.cpp index 5460b44857a..6c77ea28a02 100644 --- a/tt_metal/impl/dispatch/kernels/vc_packet_router.cpp +++ b/tt_metal/impl/dispatch/kernels/vc_packet_router.cpp @@ -7,9 +7,6 @@ #include "tt_metal/impl/dispatch/kernels/packet_queue.hpp" #include "tt_metal/impl/dispatch/kernels/cq_helpers.hpp" -packet_input_queue_state_t input_queues[MAX_SWITCH_FAN_IN]; -packet_output_queue_state_t output_queues[MAX_SWITCH_FAN_OUT]; - constexpr uint32_t rx_queue_start_addr_words = get_compile_time_arg_val(1); constexpr uint32_t rx_queue_size_words = get_compile_time_arg_val(2); constexpr uint32_t rx_queue_size_bytes = rx_queue_size_words*PACKET_WORD_SIZE_BYTES; @@ -207,6 +204,8 @@ constexpr uint8_t input_packetize_dest_endpoint[MAX_SWITCH_FAN_IN] = }; void kernel_main() { + packet_input_queue_state_t input_queues[MAX_SWITCH_FAN_IN]; + packet_output_queue_state_t output_queues[MAX_SWITCH_FAN_OUT]; write_kernel_status(kernel_status, PQ_TEST_STATUS_INDEX, PACKET_QUEUE_TEST_STARTED); write_kernel_status(kernel_status, PQ_TEST_MISC_INDEX, 0xff000000); From 0a84342e641bcb2e5a3768405cc80c53f638d9fd Mon Sep 17 00:00:00 2001 From: Michael Chiou Date: Tue, 8 Oct 2024 17:46:21 -0400 Subject: [PATCH 46/58] #0: Weka is required for t3k unit test - mixtral test --- .github/workflows/t3000-unit-tests-impl.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/t3000-unit-tests-impl.yaml b/.github/workflows/t3000-unit-tests-impl.yaml index a84c55120d0..c76fe4d20cc 100644 --- a/.github/workflows/t3000-unit-tests-impl.yaml +++ b/.github/workflows/t3000-unit-tests-impl.yaml @@ -30,6 +30,11 @@ jobs: - name: Set up dynamic env vars for build run: | echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV + - name: Ensure weka mount is active + run: | + sudo systemctl restart mnt-MLPerf.mount + sudo /etc/rc.local + ls -al /mnt/MLPerf/bit_error_tests - uses: actions/download-artifact@v4 with: name: TTMetal_build_${{ matrix.test-group.arch }} From 75390d9e1e5ba261a4515db4ba3e1eec3ffdf05b Mon Sep 17 00:00:00 2001 From: Michael Chiou Date: Tue, 8 Oct 2024 17:55:51 -0400 Subject: [PATCH 47/58] #0: replace with github action for weka check --- .../actions/ensure-active-weka-mount/action.yml | 17 +++++++++++++++++ .github/workflows/perf-models-impl.yaml | 6 +----- .github/workflows/t3000-demo-tests-impl.yaml | 6 +----- .../workflows/t3000-frequent-tests-impl.yaml | 6 +----- .../workflows/t3000-model-perf-tests-impl.yaml | 6 +----- .github/workflows/t3000-unit-tests-impl.yaml | 6 +----- .github/workflows/tg-model-perf-tests-impl.yaml | 6 +----- .../workflows/tgg-model-perf-tests-impl.yaml | 6 +----- .github/workflows/ttnn-post-commit-wrapper.yaml | 2 +- 9 files changed, 25 insertions(+), 36 deletions(-) create mode 100644 .github/actions/ensure-active-weka-mount/action.yml diff --git a/.github/actions/ensure-active-weka-mount/action.yml b/.github/actions/ensure-active-weka-mount/action.yml new file mode 100644 index 00000000000..8a80a8ef436 --- /dev/null +++ b/.github/actions/ensure-active-weka-mount/action.yml @@ -0,0 +1,17 @@ +name: "Ensure Active Weka Mount" +description: "Make sure weka mount is active" + +inputs: + os: + description: 'Runner OS' + required: true + +runs: + using: "composite" + steps: + - name: Ensure active weka mount + shell: bash + run: | + sudo systemctl restart mnt-MLPerf.mount + sudo /etc/rc.local + ls -al /mnt/MLPerf/bit_error_tests diff --git a/.github/workflows/perf-models-impl.yaml b/.github/workflows/perf-models-impl.yaml index d44cc99e715..0fb59e1add7 100644 --- a/.github/workflows/perf-models-impl.yaml +++ b/.github/workflows/perf-models-impl.yaml @@ -28,11 +28,7 @@ jobs: - name: Enable Performance mode run: | sudo cpupower frequency-set -g performance - - name: Ensure weka mount is active - run: | - sudo systemctl restart mnt-MLPerf.mount - sudo /etc/rc.local - ls -al /mnt/MLPerf/bit_error_tests + - uses: ./.github/actions/ensure-active-weka-mount - name: Set up dynamic env vars for build run: | echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV diff --git a/.github/workflows/t3000-demo-tests-impl.yaml b/.github/workflows/t3000-demo-tests-impl.yaml index ce5f82ac000..defc6b3d2b1 100644 --- a/.github/workflows/t3000-demo-tests-impl.yaml +++ b/.github/workflows/t3000-demo-tests-impl.yaml @@ -28,11 +28,7 @@ jobs: - name: Enable performance mode run: | sudo cpupower frequency-set -g performance - - name: Ensure weka mount is active - run: | - sudo systemctl restart mnt-MLPerf.mount - sudo /etc/rc.local - ls -al /mnt/MLPerf/bit_error_tests + - uses: ./.github/actions/ensure-active-weka-mount - name: Set up dynamic env vars for build run: | echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV diff --git a/.github/workflows/t3000-frequent-tests-impl.yaml b/.github/workflows/t3000-frequent-tests-impl.yaml index a0edf468a68..2df18fbea23 100644 --- a/.github/workflows/t3000-frequent-tests-impl.yaml +++ b/.github/workflows/t3000-frequent-tests-impl.yaml @@ -27,11 +27,7 @@ jobs: runs-on: ["arch-wormhole_b0", "config-t3000", "in-service", "pipeline-functional"] steps: - uses: tenstorrent-metal/metal-workflows/.github/actions/checkout-with-submodule-lfs@v2.0.0 - - name: Ensure weka mount is active - run: | - sudo systemctl restart mnt-MLPerf.mount - sudo /etc/rc.local - ls -al /mnt/MLPerf/bit_error_tests + - uses: ./.github/actions/ensure-active-weka-mount - name: Set up dynamic env vars for build run: | echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV diff --git a/.github/workflows/t3000-model-perf-tests-impl.yaml b/.github/workflows/t3000-model-perf-tests-impl.yaml index 379cdf2f284..366ec614e74 100644 --- a/.github/workflows/t3000-model-perf-tests-impl.yaml +++ b/.github/workflows/t3000-model-perf-tests-impl.yaml @@ -30,11 +30,7 @@ jobs: - name: Enable performance mode run: | sudo cpupower frequency-set -g performance - - name: Ensure weka mount is active - run: | - sudo systemctl restart mnt-MLPerf.mount - sudo /etc/rc.local - ls -al /mnt/MLPerf/bit_error_tests + - uses: ./.github/actions/ensure-active-weka-mount - name: Set up dynamic env vars for build run: | echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV diff --git a/.github/workflows/t3000-unit-tests-impl.yaml b/.github/workflows/t3000-unit-tests-impl.yaml index c76fe4d20cc..6634dc9cfd5 100644 --- a/.github/workflows/t3000-unit-tests-impl.yaml +++ b/.github/workflows/t3000-unit-tests-impl.yaml @@ -30,11 +30,7 @@ jobs: - name: Set up dynamic env vars for build run: | echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV - - name: Ensure weka mount is active - run: | - sudo systemctl restart mnt-MLPerf.mount - sudo /etc/rc.local - ls -al /mnt/MLPerf/bit_error_tests + - uses: ./.github/actions/ensure-active-weka-mount - uses: actions/download-artifact@v4 with: name: TTMetal_build_${{ matrix.test-group.arch }} diff --git a/.github/workflows/tg-model-perf-tests-impl.yaml b/.github/workflows/tg-model-perf-tests-impl.yaml index dd10b6109a9..255a934a423 100644 --- a/.github/workflows/tg-model-perf-tests-impl.yaml +++ b/.github/workflows/tg-model-perf-tests-impl.yaml @@ -37,11 +37,7 @@ jobs: - name: Enable performance mode run: | sudo cpupower frequency-set -g performance - - name: Ensure weka mount is active - run: | - sudo systemctl restart mnt-MLPerf.mount - sudo /etc/rc.local - ls -al /mnt/MLPerf/bit_error_tests + - uses: ./.github/actions/ensure-active-weka-mount - name: Set up dynamic env vars for build run: | echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV diff --git a/.github/workflows/tgg-model-perf-tests-impl.yaml b/.github/workflows/tgg-model-perf-tests-impl.yaml index f3d44f2e2ba..df523469fbe 100644 --- a/.github/workflows/tgg-model-perf-tests-impl.yaml +++ b/.github/workflows/tgg-model-perf-tests-impl.yaml @@ -37,11 +37,7 @@ jobs: - name: Enable performance mode run: | sudo cpupower frequency-set -g performance - - name: Ensure weka mount is active - run: | - sudo systemctl restart mnt-MLPerf.mount - sudo /etc/rc.local - ls -al /mnt/MLPerf/bit_error_tests + - uses: ./.github/actions/ensure-active-weka-mount - name: Set up dynamic env vars for build run: | echo "TT_METAL_HOME=$(pwd)" >> $GITHUB_ENV diff --git a/.github/workflows/ttnn-post-commit-wrapper.yaml b/.github/workflows/ttnn-post-commit-wrapper.yaml index b94be9dd0e3..cd134bb9785 100644 --- a/.github/workflows/ttnn-post-commit-wrapper.yaml +++ b/.github/workflows/ttnn-post-commit-wrapper.yaml @@ -25,4 +25,4 @@ jobs: uses: ./.github/workflows/ttnn-post-commit.yaml with: arch: ${{ matrix.test-group.arch}} - runner-label: ${{ matrix.test-group.runner-label}} + runner-label: ${{ matrix.test-group.runner-label }} From 53b60f494e5e9540b7149aa5d6e924e4ac09f3df Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Wed, 9 Oct 2024 14:57:58 +0000 Subject: [PATCH 48/58] #13588: Move cached program commands to be local to program to avoid memory leak Fix HostMemDeviceCommand move constructor to properly move host vector instead of deepcopy --- tt_metal/impl/dispatch/command_queue.cpp | 1282 ++++++++--------- tt_metal/impl/dispatch/command_queue.hpp | 28 +- tt_metal/impl/dispatch/device_command.hpp | 17 +- .../dispatch/program_command_sequence.hpp | 38 + tt_metal/impl/program/program.hpp | 4 + 5 files changed, 705 insertions(+), 664 deletions(-) create mode 100644 tt_metal/impl/dispatch/program_command_sequence.hpp diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp index f73d27b8e23..192c37b4d41 100644 --- a/tt_metal/impl/dispatch/command_queue.cpp +++ b/tt_metal/impl/dispatch/command_queue.cpp @@ -64,10 +64,6 @@ enum DispatchWriteOffsets { DISPATCH_WRITE_OFFSET_ETH_L1_CONFIG_BASE = 2, }; -// TODO: Delete entries when programs are deleted to save memory -thread_local std::unordered_map - EnqueueProgramCommand::cached_program_command_sequences = {}; - // EnqueueReadBufferCommandSection EnqueueReadBufferCommand::EnqueueReadBufferCommand( @@ -339,26 +335,26 @@ EnqueueProgramCommand::EnqueueProgramCommand( this->unicast_cores_launch_message_wptr = unicast_cores_launch_message_wptr; } -void EnqueueProgramCommand::assemble_preamble_commands(std::vector& kernel_config_addrs) { +void EnqueueProgramCommand::assemble_preamble_commands(ProgramCommandSequence& program_command_sequence, std::vector& kernel_config_addrs) { constexpr uint32_t uncached_cmd_sequence_sizeB = CQ_PREFETCH_CMD_BARE_MIN_SIZE; // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_SET_WRITE_OFFSET - this->cached_program_command_sequences[program.id].preamble_command_sequence = + program_command_sequence.preamble_command_sequence = HostMemDeviceCommand(uncached_cmd_sequence_sizeB); // Send write offsets if (hal.get_programmable_core_type_count() >= 2) { - this->cached_program_command_sequences[program.id].preamble_command_sequence.add_dispatch_set_write_offsets( + program_command_sequence.preamble_command_sequence.add_dispatch_set_write_offsets( 0, kernel_config_addrs[hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX)].addr, kernel_config_addrs[hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH)].addr); } else { - this->cached_program_command_sequences[program.id].preamble_command_sequence.add_dispatch_set_write_offsets( + program_command_sequence.preamble_command_sequence.add_dispatch_set_write_offsets( 0, kernel_config_addrs[hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX)].addr, 0); } } -void EnqueueProgramCommand::assemble_stall_commands(bool prefetch_stall) { +void EnqueueProgramCommand::assemble_stall_commands(ProgramCommandSequence& program_command_sequence, bool prefetch_stall) { if (prefetch_stall) { // Wait command so previous program finishes // Wait command with barrier for binaries to commit to DRAM @@ -367,7 +363,7 @@ void EnqueueProgramCommand::assemble_stall_commands(bool prefetch_stall) { CQ_PREFETCH_CMD_BARE_MIN_SIZE + // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT CQ_PREFETCH_CMD_BARE_MIN_SIZE; // CQ_PREFETCH_CMD_STALL - this->cached_program_command_sequences[program.id].stall_command_sequence = + program_command_sequence.stall_command_sequence = HostMemDeviceCommand(uncached_cmd_sequence_sizeB); // Wait for Noc Write Barrier @@ -375,16 +371,16 @@ void EnqueueProgramCommand::assemble_stall_commands(bool prefetch_stall) { // Wait Noc Write Barrier, wait for binaries to be written to worker cores // Stall to allow binaries to commit to DRAM first // TODO: this can be removed for all but the first program run - this->cached_program_command_sequences[program.id].stall_command_sequence.add_dispatch_wait_with_prefetch_stall( + program_command_sequence.stall_command_sequence.add_dispatch_wait_with_prefetch_stall( true, this->dispatch_message_addr, this->expected_num_workers_completed); } else { // Wait command so previous program finishes constexpr uint32_t cached_cmd_sequence_sizeB = CQ_PREFETCH_CMD_BARE_MIN_SIZE; // CQ_PREFETCH_CMD_RELAY_INLINE + CQ_DISPATCH_CMD_WAIT - this->cached_program_command_sequences[program.id].stall_command_sequence = + program_command_sequence.stall_command_sequence = HostMemDeviceCommand(cached_cmd_sequence_sizeB); - this->cached_program_command_sequences[program.id].stall_command_sequence.add_dispatch_wait( + program_command_sequence.stall_command_sequence.add_dispatch_wait( false, this->dispatch_message_addr, this->expected_num_workers_completed); } } @@ -528,7 +524,7 @@ void generate_runtime_args_cmds( } // Generate command sequence for unique (unicast) and common (multicast) runtime args -void EnqueueProgramCommand::assemble_runtime_args_commands() { +void EnqueueProgramCommand::assemble_runtime_args_commands(ProgramCommandSequence& program_command_sequence) { CoreType dispatch_core_type = dispatch_core_manager::instance().get_dispatch_core_type(this->device->id()); const uint32_t max_prefetch_command_size = dispatch_constants::get(dispatch_core_type).max_prefetch_command_size(); @@ -543,7 +539,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { std::vector>> common_rt_data_and_sizes; std::vector>> common_rt_args_data; - this->cached_program_command_sequences[program.id].runtime_args_command_sequences = {}; + program_command_sequence.runtime_args_command_sequences = {}; uint32_t command_count = 0; for (uint32_t programmable_core_type_index = 0; @@ -563,7 +559,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { } } - this->cached_program_command_sequences[program.id].runtime_args_command_sequences.reserve(command_count); + program_command_sequence.runtime_args_command_sequences.reserve(command_count); // Unique Runtime Args (Unicast) for (uint32_t index = 0; index < hal.get_programmable_core_type_count(); index++) { if (hal.get_programmable_core_type(index) == HalProgrammableCoreType::IDLE_ETH) { @@ -608,7 +604,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { } uint32_t rta_offset = program.get_program_config(index).rta_offset; generate_runtime_args_cmds( - this->cached_program_command_sequences[program.id].runtime_args_command_sequences, + program_command_sequence.runtime_args_command_sequences, rta_offset, unique_sub_cmds, unique_rt_data_and_sizes, @@ -691,7 +687,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { std::visit( [&](auto&& sub_cmds) { generate_runtime_args_cmds( - this->cached_program_command_sequences[program.id].runtime_args_command_sequences, + program_command_sequence.runtime_args_command_sequences, crta_offset, sub_cmds, common_rt_data_and_sizes, @@ -718,691 +714,621 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() { } uint32_t runtime_args_fetch_size_bytes = 0; - for (const auto& cmds : this->cached_program_command_sequences[program.id].runtime_args_command_sequences) { + for (const auto& cmds : program_command_sequence.runtime_args_command_sequences) { // BRISC, NCRISC, TRISC... runtime_args_fetch_size_bytes += cmds.size_bytes(); } - this->cached_program_command_sequences[program.id].runtime_args_fetch_size_bytes = runtime_args_fetch_size_bytes; -} - -void EnqueueProgramCommand::assemble_device_commands( - bool is_cached, std::vector& kernel_config_addrs) { - auto& cached_program_command_sequence = this->cached_program_command_sequences[this->program.id]; - if (not is_cached) { - // Calculate size of command and fill program indices of data to update - // TODO: Would be nice if we could pull this out of program - uint32_t cmd_sequence_sizeB = 0; - const uint32_t max_prefetch_command_size = - dispatch_constants::get(dispatch_core_type).max_prefetch_command_size(); - - // Multicast Semaphore Cmd - uint32_t num_multicast_semaphores = program.program_transfer_info.multicast_semaphores.size(); - std::vector> multicast_sem_sub_cmds(num_multicast_semaphores); - std::vector>> multicast_sem_data(num_multicast_semaphores); - std::vector>> multicast_sem_payload(num_multicast_semaphores); - std::vector> multicast_sem_dst_size; - multicast_sem_dst_size.reserve(num_multicast_semaphores); - if (num_multicast_semaphores > 0) { - uint32_t i = 0; - for (const auto& [dst, transfer_info_vec] : program.program_transfer_info.multicast_semaphores) { - // TODO: loop over things inside transfer_info[i] - uint32_t write_packed_len = transfer_info_vec[0].data.size(); - multicast_sem_dst_size.emplace_back(std::make_pair(dst, write_packed_len * sizeof(uint32_t))); - - for (const auto& transfer_info : transfer_info_vec) { - for (const auto& dst_noc_info : transfer_info.dst_noc_info) { - TT_ASSERT( - transfer_info.data.size() == write_packed_len, - "Not all data vectors in write packed semaphore cmd equal in len"); - multicast_sem_sub_cmds[i].emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = this->device->get_noc_multicast_encoding( - this->noc_index, std::get(dst_noc_info.first)), - .num_mcast_dests = dst_noc_info.second}); - multicast_sem_data[i].emplace_back( - transfer_info.data.data(), transfer_info.data.size() * sizeof(uint32_t)); - } + program_command_sequence.runtime_args_fetch_size_bytes = runtime_args_fetch_size_bytes; +} + +void EnqueueProgramCommand::assemble_device_commands(ProgramCommandSequence& program_command_sequence, std::vector& kernel_config_addrs) { + // Calculate size of command and fill program indices of data to update + // TODO: Would be nice if we could pull this out of program + uint32_t cmd_sequence_sizeB = 0; + const uint32_t max_prefetch_command_size = + dispatch_constants::get(dispatch_core_type).max_prefetch_command_size(); + + // Multicast Semaphore Cmd + uint32_t num_multicast_semaphores = program.program_transfer_info.multicast_semaphores.size(); + std::vector> multicast_sem_sub_cmds(num_multicast_semaphores); + std::vector>> multicast_sem_data(num_multicast_semaphores); + std::vector>> multicast_sem_payload(num_multicast_semaphores); + std::vector> multicast_sem_dst_size; + multicast_sem_dst_size.reserve(num_multicast_semaphores); + if (num_multicast_semaphores > 0) { + uint32_t i = 0; + for (const auto& [dst, transfer_info_vec] : program.program_transfer_info.multicast_semaphores) { + // TODO: loop over things inside transfer_info[i] + uint32_t write_packed_len = transfer_info_vec[0].data.size(); + multicast_sem_dst_size.emplace_back(std::make_pair(dst, write_packed_len * sizeof(uint32_t))); + + for (const auto& transfer_info : transfer_info_vec) { + for (const auto& dst_noc_info : transfer_info.dst_noc_info) { + TT_ASSERT( + transfer_info.data.size() == write_packed_len, + "Not all data vectors in write packed semaphore cmd equal in len"); + multicast_sem_sub_cmds[i].emplace_back(CQDispatchWritePackedMulticastSubCmd{ + .noc_xy_addr = this->device->get_noc_multicast_encoding( + this->noc_index, std::get(dst_noc_info.first)), + .num_mcast_dests = dst_noc_info.second}); + multicast_sem_data[i].emplace_back( + transfer_info.data.data(), transfer_info.data.size() * sizeof(uint32_t)); } - cmd_sequence_sizeB += insert_write_packed_payloads( - multicast_sem_sub_cmds[i].size(), - multicast_sem_dst_size.back().second, - max_prefetch_command_size, - this->packed_write_max_unicast_sub_cmds, - multicast_sem_payload[i]); - i++; } + cmd_sequence_sizeB += insert_write_packed_payloads( + multicast_sem_sub_cmds[i].size(), + multicast_sem_dst_size.back().second, + max_prefetch_command_size, + this->packed_write_max_unicast_sub_cmds, + multicast_sem_payload[i]); + i++; } + } - // Unicast Semaphore Cmd - uint32_t num_unicast_semaphores = program.program_transfer_info.unicast_semaphores.size(); - std::vector> unicast_sem_sub_cmds(num_unicast_semaphores); - std::vector>> unicast_sem_data(num_unicast_semaphores); - std::vector>> unicast_sem_payload(num_unicast_semaphores); - std::vector> unicast_sem_dst_size; - unicast_sem_dst_size.reserve(num_unicast_semaphores); - if (num_unicast_semaphores > 0) { - uint32_t i = 0; - for (const auto& [dst, transfer_info_vec] : program.program_transfer_info.unicast_semaphores) { - // TODO: loop over things inside transfer_info[i] - uint32_t write_packed_len = transfer_info_vec[0].data.size(); - unicast_sem_dst_size.emplace_back(std::make_pair(dst, write_packed_len * sizeof(uint32_t))); - - for (const auto& transfer_info : transfer_info_vec) { - for (const auto& dst_noc_info : transfer_info.dst_noc_info) { - TT_ASSERT( - transfer_info.data.size() == write_packed_len, - "Not all data vectors in write packed semaphore cmd equal in len"); - unicast_sem_sub_cmds[i].emplace_back(CQDispatchWritePackedUnicastSubCmd{ - .noc_xy_addr = this->device->get_noc_unicast_encoding( - this->noc_index, std::get(dst_noc_info.first))}); - unicast_sem_data[i].emplace_back( - transfer_info.data.data(), transfer_info.data.size() * sizeof(uint32_t)); - } + // Unicast Semaphore Cmd + uint32_t num_unicast_semaphores = program.program_transfer_info.unicast_semaphores.size(); + std::vector> unicast_sem_sub_cmds(num_unicast_semaphores); + std::vector>> unicast_sem_data(num_unicast_semaphores); + std::vector>> unicast_sem_payload(num_unicast_semaphores); + std::vector> unicast_sem_dst_size; + unicast_sem_dst_size.reserve(num_unicast_semaphores); + if (num_unicast_semaphores > 0) { + uint32_t i = 0; + for (const auto& [dst, transfer_info_vec] : program.program_transfer_info.unicast_semaphores) { + // TODO: loop over things inside transfer_info[i] + uint32_t write_packed_len = transfer_info_vec[0].data.size(); + unicast_sem_dst_size.emplace_back(std::make_pair(dst, write_packed_len * sizeof(uint32_t))); + + for (const auto& transfer_info : transfer_info_vec) { + for (const auto& dst_noc_info : transfer_info.dst_noc_info) { + TT_ASSERT( + transfer_info.data.size() == write_packed_len, + "Not all data vectors in write packed semaphore cmd equal in len"); + unicast_sem_sub_cmds[i].emplace_back(CQDispatchWritePackedUnicastSubCmd{ + .noc_xy_addr = this->device->get_noc_unicast_encoding( + this->noc_index, std::get(dst_noc_info.first))}); + unicast_sem_data[i].emplace_back( + transfer_info.data.data(), transfer_info.data.size() * sizeof(uint32_t)); } - cmd_sequence_sizeB += insert_write_packed_payloads( - unicast_sem_sub_cmds[i].size(), - unicast_sem_dst_size.back().second, - max_prefetch_command_size, - this->packed_write_max_unicast_sub_cmds, - unicast_sem_payload[i]); - i++; } + cmd_sequence_sizeB += insert_write_packed_payloads( + unicast_sem_sub_cmds[i].size(), + unicast_sem_dst_size.back().second, + max_prefetch_command_size, + this->packed_write_max_unicast_sub_cmds, + unicast_sem_payload[i]); + i++; } + } - const auto& circular_buffers_unique_coreranges = program.circular_buffers_unique_coreranges(); - const uint16_t num_multicast_cb_sub_cmds = circular_buffers_unique_coreranges.size(); - std::vector> mcast_cb_payload; - uint16_t cb_config_size_bytes = 0; - uint32_t aligned_cb_config_size_bytes = 0; - std::vector> cb_config_payloads( - num_multicast_cb_sub_cmds, - std::vector(UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG * NUM_CIRCULAR_BUFFERS, 0)); - std::vector multicast_cb_config_sub_cmds; - std::vector> multicast_cb_config_data; - if (num_multicast_cb_sub_cmds > 0) { - multicast_cb_config_sub_cmds.reserve(num_multicast_cb_sub_cmds); - multicast_cb_config_data.reserve(num_multicast_cb_sub_cmds); - cached_program_command_sequence.circular_buffers_on_core_ranges.resize(num_multicast_cb_sub_cmds); - uint32_t i = 0; - uint32_t max_overall_base_index = 0; - for (const CoreRange& core_range : circular_buffers_unique_coreranges) { - const CoreCoord physical_start = device->worker_core_from_logical_core(core_range.start_coord); - const CoreCoord physical_end = device->worker_core_from_logical_core(core_range.end_coord); - - const uint32_t num_receivers = core_range.size(); - auto& cb_config_payload = cb_config_payloads[i]; - uint32_t max_base_index = 0; - const auto& circular_buffers_on_corerange = program.circular_buffers_on_corerange(core_range); - cached_program_command_sequence.circular_buffers_on_core_ranges[i].reserve( - circular_buffers_on_corerange.size()); - for (const shared_ptr& cb : circular_buffers_on_corerange) { - cached_program_command_sequence.circular_buffers_on_core_ranges[i].emplace_back(cb); - const uint32_t cb_address = cb->address() >> 4; - const uint32_t cb_size = cb->size() >> 4; - for (const auto& buffer_index : cb->buffer_indices()) { - // 1 cmd for all 32 buffer indices, populate with real data for specified indices - - // cb config payload - const uint32_t base_index = UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG * (uint32_t)buffer_index; - cb_config_payload[base_index] = cb_address; - cb_config_payload[base_index + 1] = cb_size; - cb_config_payload[base_index + 2] = cb->num_pages(buffer_index); - cb_config_payload[base_index + 3] = cb->page_size(buffer_index) >> 4; - max_base_index = std::max(max_base_index, base_index); - } + const auto& circular_buffers_unique_coreranges = program.circular_buffers_unique_coreranges(); + const uint16_t num_multicast_cb_sub_cmds = circular_buffers_unique_coreranges.size(); + std::vector> mcast_cb_payload; + uint16_t cb_config_size_bytes = 0; + uint32_t aligned_cb_config_size_bytes = 0; + std::vector> cb_config_payloads( + num_multicast_cb_sub_cmds, + std::vector(UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG * NUM_CIRCULAR_BUFFERS, 0)); + std::vector multicast_cb_config_sub_cmds; + std::vector> multicast_cb_config_data; + if (num_multicast_cb_sub_cmds > 0) { + multicast_cb_config_sub_cmds.reserve(num_multicast_cb_sub_cmds); + multicast_cb_config_data.reserve(num_multicast_cb_sub_cmds); + program_command_sequence.circular_buffers_on_core_ranges.resize(num_multicast_cb_sub_cmds); + uint32_t i = 0; + uint32_t max_overall_base_index = 0; + for (const CoreRange& core_range : circular_buffers_unique_coreranges) { + const CoreCoord physical_start = device->worker_core_from_logical_core(core_range.start_coord); + const CoreCoord physical_end = device->worker_core_from_logical_core(core_range.end_coord); + + const uint32_t num_receivers = core_range.size(); + auto& cb_config_payload = cb_config_payloads[i]; + uint32_t max_base_index = 0; + const auto& circular_buffers_on_corerange = program.circular_buffers_on_corerange(core_range); + program_command_sequence.circular_buffers_on_core_ranges[i].reserve( + circular_buffers_on_corerange.size()); + for (const shared_ptr& cb : circular_buffers_on_corerange) { + program_command_sequence.circular_buffers_on_core_ranges[i].emplace_back(cb); + const uint32_t cb_address = cb->address() >> 4; + const uint32_t cb_size = cb->size() >> 4; + for (const auto& buffer_index : cb->buffer_indices()) { + // 1 cmd for all 32 buffer indices, populate with real data for specified indices + + // cb config payload + const uint32_t base_index = UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG * (uint32_t)buffer_index; + cb_config_payload[base_index] = cb_address; + cb_config_payload[base_index + 1] = cb_size; + cb_config_payload[base_index + 2] = cb->num_pages(buffer_index); + cb_config_payload[base_index + 3] = cb->page_size(buffer_index) >> 4; + max_base_index = std::max(max_base_index, base_index); } - multicast_cb_config_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = this->device->get_noc_multicast_encoding( - this->noc_index, CoreRange(physical_start, physical_end)), - .num_mcast_dests = (uint32_t)core_range.size()}); - multicast_cb_config_data.emplace_back( - cb_config_payload.data(), - (max_base_index + UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG) * sizeof(uint32_t)); - max_overall_base_index = std::max(max_overall_base_index, max_base_index); - i++; } - uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); - cb_config_size_bytes = - (max_overall_base_index + UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG) * sizeof(uint32_t); - aligned_cb_config_size_bytes = align(cb_config_size_bytes, l1_alignment); - cmd_sequence_sizeB += insert_write_packed_payloads( - num_multicast_cb_sub_cmds, - cb_config_size_bytes, - max_prefetch_command_size, - this->packed_write_max_unicast_sub_cmds, - mcast_cb_payload); + multicast_cb_config_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ + .noc_xy_addr = this->device->get_noc_multicast_encoding( + this->noc_index, CoreRange(physical_start, physical_end)), + .num_mcast_dests = (uint32_t)core_range.size()}); + multicast_cb_config_data.emplace_back( + cb_config_payload.data(), + (max_base_index + UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG) * sizeof(uint32_t)); + max_overall_base_index = std::max(max_overall_base_index, max_base_index); + i++; } + uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); + cb_config_size_bytes = + (max_overall_base_index + UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG) * sizeof(uint32_t); + aligned_cb_config_size_bytes = align(cb_config_size_bytes, l1_alignment); + cmd_sequence_sizeB += insert_write_packed_payloads( + num_multicast_cb_sub_cmds, + cb_config_size_bytes, + max_prefetch_command_size, + this->packed_write_max_unicast_sub_cmds, + mcast_cb_payload); + } + + // Program Binaries and Go Signals + // Get launch msg data while getting size of cmds + std::vector> kernel_bins_prefetch_subcmds; + std::vector> kernel_bins_dispatch_subcmds; + std::vector kernel_bins_write_packed_large_data_aligned_sizeB; + std::vector kernel_bins_unicast_cmds; + const uint32_t max_length_per_sub_cmd = dispatch_constants::get(this->dispatch_core_type).scratch_db_size() / 2; + const uint32_t max_paged_length_per_sub_cmd = + max_length_per_sub_cmd / HostMemDeviceCommand::PROGRAM_PAGE_SIZE * HostMemDeviceCommand::PROGRAM_PAGE_SIZE; + for (const auto& [cores, num_mcast_dests, kg_transfer_info] : program.program_transfer_info.kernel_bins) { + bool write_linear; + uint32_t noc_encoding; + std::visit( + [&](auto&& cores) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + noc_encoding = this->device->get_noc_multicast_encoding(this->noc_index, cores); + write_linear = false; + } else { + noc_encoding = this->device->get_noc_unicast_encoding(this->noc_index, cores); + write_linear = true; + } + }, + cores); + for (uint32_t kernel_idx = 0; kernel_idx < kg_transfer_info.dst_base_addrs.size(); kernel_idx++) { + if (write_linear) { + kernel_bins_unicast_cmds.emplace_back(2 * CQ_PREFETCH_CMD_BARE_MIN_SIZE); + cmd_sequence_sizeB += 2 * CQ_PREFETCH_CMD_BARE_MIN_SIZE; + kernel_bins_unicast_cmds.back().add_dispatch_write_linear( + false, // flush_prefetch + num_mcast_dests, // num_mcast_dests + noc_encoding, // noc_xy_addr + kg_transfer_info.dst_base_addrs[kernel_idx], + kg_transfer_info.lengths[kernel_idx]); + RecordDispatchData( + program, + DISPATCH_DATA_BINARY, + kg_transfer_info.lengths[kernel_idx], + kg_transfer_info.riscvs[kernel_idx]); + // Difference between prefetch total relayed pages and dispatch write linear + uint32_t relayed_bytes = + align(kg_transfer_info.lengths[kernel_idx], HostMemDeviceCommand::PROGRAM_PAGE_SIZE); + uint16_t length_adjust = uint16_t(relayed_bytes - kg_transfer_info.lengths[kernel_idx]); + + uint32_t base_address, page_offset; + if (kg_transfer_info.page_offsets[kernel_idx] > CQ_PREFETCH_RELAY_PAGED_START_PAGE_MASK) { + const uint32_t num_banks = this->device->num_banks(this->program.kernels_buffer->buffer_type()); + page_offset = kg_transfer_info.page_offsets[kernel_idx] % num_banks; + uint32_t num_full_pages_written_per_bank = + kg_transfer_info.page_offsets[kernel_idx] / num_banks; + base_address = this->program.kernels_buffer->address() + + num_full_pages_written_per_bank * this->program.kernels_buffer->page_size(); + } else { + base_address = this->program.kernels_buffer->address(); + page_offset = kg_transfer_info.page_offsets[kernel_idx]; + } - // Program Binaries and Go Signals - // Get launch msg data while getting size of cmds - std::vector> kernel_bins_prefetch_subcmds; - std::vector> kernel_bins_dispatch_subcmds; - std::vector kernel_bins_write_packed_large_data_aligned_sizeB; - std::vector kernel_bins_unicast_cmds; - const uint32_t max_length_per_sub_cmd = dispatch_constants::get(this->dispatch_core_type).scratch_db_size() / 2; - const uint32_t max_paged_length_per_sub_cmd = - max_length_per_sub_cmd / HostMemDeviceCommand::PROGRAM_PAGE_SIZE * HostMemDeviceCommand::PROGRAM_PAGE_SIZE; - for (const auto& [cores, num_mcast_dests, kg_transfer_info] : program.program_transfer_info.kernel_bins) { - bool write_linear; - uint32_t noc_encoding; - std::visit( - [&](auto&& cores) { - using T = std::decay_t; - if constexpr (std::is_same_v) { - noc_encoding = this->device->get_noc_multicast_encoding(this->noc_index, cores); - write_linear = false; - } else { - noc_encoding = this->device->get_noc_unicast_encoding(this->noc_index, cores); - write_linear = true; + kernel_bins_unicast_cmds.back().add_prefetch_relay_paged( + true, // is_dram + page_offset, + base_address, + this->program.kernels_buffer->page_size(), + relayed_bytes / this->program.kernels_buffer->page_size(), + length_adjust); + } else { + uint32_t base_address = this->program.kernels_buffer->address(); + uint32_t page_offset = kg_transfer_info.page_offsets[kernel_idx]; + uint32_t dst_addr = kg_transfer_info.dst_base_addrs[kernel_idx]; + uint32_t aligned_length = align(kg_transfer_info.lengths[kernel_idx], hal.get_alignment(HalMemType::DRAM)); + uint32_t padding = aligned_length - kg_transfer_info.lengths[kernel_idx]; + while (aligned_length != 0) { + if (kernel_bins_dispatch_subcmds.empty() || + kernel_bins_dispatch_subcmds.back().size() == + CQ_DISPATCH_CMD_PACKED_WRITE_LARGE_MAX_SUB_CMDS) { + kernel_bins_dispatch_subcmds.push_back({}); + kernel_bins_prefetch_subcmds.push_back({}); + kernel_bins_write_packed_large_data_aligned_sizeB.push_back(0); } - }, - cores); - for (uint32_t kernel_idx = 0; kernel_idx < kg_transfer_info.dst_base_addrs.size(); kernel_idx++) { - if (write_linear) { - kernel_bins_unicast_cmds.emplace_back(2 * CQ_PREFETCH_CMD_BARE_MIN_SIZE); - cmd_sequence_sizeB += 2 * CQ_PREFETCH_CMD_BARE_MIN_SIZE; - kernel_bins_unicast_cmds.back().add_dispatch_write_linear( - false, // flush_prefetch - num_mcast_dests, // num_mcast_dests - noc_encoding, // noc_xy_addr - kg_transfer_info.dst_base_addrs[kernel_idx], - kg_transfer_info.lengths[kernel_idx]); - RecordDispatchData( - program, - DISPATCH_DATA_BINARY, - kg_transfer_info.lengths[kernel_idx], - kg_transfer_info.riscvs[kernel_idx]); - // Difference between prefetch total relayed pages and dispatch write linear - uint32_t relayed_bytes = - align(kg_transfer_info.lengths[kernel_idx], HostMemDeviceCommand::PROGRAM_PAGE_SIZE); - uint16_t length_adjust = uint16_t(relayed_bytes - kg_transfer_info.lengths[kernel_idx]); - - uint32_t base_address, page_offset; - if (kg_transfer_info.page_offsets[kernel_idx] > CQ_PREFETCH_RELAY_PAGED_START_PAGE_MASK) { - const uint32_t num_banks = this->device->num_banks(this->program.kernels_buffer->buffer_type()); - page_offset = kg_transfer_info.page_offsets[kernel_idx] % num_banks; - uint32_t num_full_pages_written_per_bank = - kg_transfer_info.page_offsets[kernel_idx] / num_banks; - base_address = this->program.kernels_buffer->address() + - num_full_pages_written_per_bank * this->program.kernels_buffer->page_size(); + uint32_t write_length, read_length; + if (aligned_length <= max_length_per_sub_cmd) { + read_length = aligned_length; + write_length = read_length - padding; } else { - base_address = this->program.kernels_buffer->address(); - page_offset = kg_transfer_info.page_offsets[kernel_idx]; - } - - kernel_bins_unicast_cmds.back().add_prefetch_relay_paged( - true, // is_dram - page_offset, - base_address, - this->program.kernels_buffer->page_size(), - relayed_bytes / this->program.kernels_buffer->page_size(), - length_adjust); - } else { - uint32_t base_address = this->program.kernels_buffer->address(); - uint32_t page_offset = kg_transfer_info.page_offsets[kernel_idx]; - uint32_t dst_addr = kg_transfer_info.dst_base_addrs[kernel_idx]; - uint32_t aligned_length = align(kg_transfer_info.lengths[kernel_idx], hal.get_alignment(HalMemType::DRAM)); - uint32_t padding = aligned_length - kg_transfer_info.lengths[kernel_idx]; - while (aligned_length != 0) { - if (kernel_bins_dispatch_subcmds.empty() || - kernel_bins_dispatch_subcmds.back().size() == - CQ_DISPATCH_CMD_PACKED_WRITE_LARGE_MAX_SUB_CMDS) { - kernel_bins_dispatch_subcmds.push_back({}); - kernel_bins_prefetch_subcmds.push_back({}); - kernel_bins_write_packed_large_data_aligned_sizeB.push_back(0); - } - uint32_t write_length, read_length; - if (aligned_length <= max_length_per_sub_cmd) { - read_length = aligned_length; - write_length = read_length - padding; - } else { - read_length = max_paged_length_per_sub_cmd; - write_length = read_length; - } - kernel_bins_dispatch_subcmds.back().emplace_back(CQDispatchWritePackedLargeSubCmd{ - .noc_xy_addr = noc_encoding, - .addr = dst_addr, - .length = (uint16_t)write_length, - .num_mcast_dests = (uint8_t)num_mcast_dests, - .flags = CQ_DISPATCH_CMD_PACKED_WRITE_LARGE_FLAG_NONE}); - RecordDispatchData( - program, DISPATCH_DATA_BINARY, write_length, kg_transfer_info.riscvs[kernel_idx]); - dst_addr += write_length; - - kernel_bins_prefetch_subcmds.back().emplace_back(CQPrefetchRelayPagedPackedSubCmd{ - .start_page = (uint16_t)page_offset, - .log_page_size = (uint16_t)HostMemDeviceCommand::LOG2_PROGRAM_PAGE_SIZE, - .base_addr = base_address, - .length = read_length}); - page_offset += read_length / HostMemDeviceCommand::PROGRAM_PAGE_SIZE; - aligned_length -= read_length; - kernel_bins_write_packed_large_data_aligned_sizeB.back() += read_length; + read_length = max_paged_length_per_sub_cmd; + write_length = read_length; } + kernel_bins_dispatch_subcmds.back().emplace_back(CQDispatchWritePackedLargeSubCmd{ + .noc_xy_addr = noc_encoding, + .addr = dst_addr, + .length = (uint16_t)write_length, + .num_mcast_dests = (uint8_t)num_mcast_dests, + .flags = CQ_DISPATCH_CMD_PACKED_WRITE_LARGE_FLAG_NONE}); + RecordDispatchData( + program, DISPATCH_DATA_BINARY, write_length, kg_transfer_info.riscvs[kernel_idx]); + dst_addr += write_length; + + kernel_bins_prefetch_subcmds.back().emplace_back(CQPrefetchRelayPagedPackedSubCmd{ + .start_page = (uint16_t)page_offset, + .log_page_size = (uint16_t)HostMemDeviceCommand::LOG2_PROGRAM_PAGE_SIZE, + .base_addr = base_address, + .length = read_length}); + page_offset += read_length / HostMemDeviceCommand::PROGRAM_PAGE_SIZE; + aligned_length -= read_length; + kernel_bins_write_packed_large_data_aligned_sizeB.back() += read_length; } } - // Unlink the last subcmd of the current core range - if (!write_linear) { - kernel_bins_dispatch_subcmds.back().back().flags |= CQ_DISPATCH_CMD_PACKED_WRITE_LARGE_FLAG_UNLINK; - } } - uint32_t pcie_alignment = hal.get_alignment(HalMemType::HOST); - for (uint32_t i = 0; i < kernel_bins_dispatch_subcmds.size(); ++i) { - cmd_sequence_sizeB += align( - ((sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd))) + - kernel_bins_dispatch_subcmds[i].size() * sizeof(CQDispatchWritePackedLargeSubCmd), - pcie_alignment); - cmd_sequence_sizeB += align( - kernel_bins_prefetch_subcmds[i].size() * sizeof(CQPrefetchRelayPagedPackedSubCmd) + - sizeof(CQPrefetchCmd), - pcie_alignment); - } - - // Wait Cmd - if (program.program_transfer_info.num_active_cores > 0) { - cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE; + // Unlink the last subcmd of the current core range + if (!write_linear) { + kernel_bins_dispatch_subcmds.back().back().flags |= CQ_DISPATCH_CMD_PACKED_WRITE_LARGE_FLAG_UNLINK; } + } + uint32_t pcie_alignment = hal.get_alignment(HalMemType::HOST); + for (uint32_t i = 0; i < kernel_bins_dispatch_subcmds.size(); ++i) { + cmd_sequence_sizeB += align( + ((sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd))) + + kernel_bins_dispatch_subcmds[i].size() * sizeof(CQDispatchWritePackedLargeSubCmd), + pcie_alignment); + cmd_sequence_sizeB += align( + kernel_bins_prefetch_subcmds[i].size() * sizeof(CQPrefetchRelayPagedPackedSubCmd) + + sizeof(CQPrefetchCmd), + pcie_alignment); + } + + // Wait Cmd + if (program.program_transfer_info.num_active_cores > 0) { + cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE; + } - std::vector> multicast_go_signal_data; - std::vector> unicast_go_signal_data; - std::vector multicast_go_signal_sub_cmds; - std::vector unicast_go_signal_sub_cmds; - std::vector> multicast_go_signals_payload; - std::vector> unicast_go_signals_payload; - constexpr uint32_t go_signal_sizeB = sizeof(launch_msg_t); - uint32_t aligned_go_signal_sizeB = align(go_signal_sizeB, hal.get_alignment(HalMemType::L1)); - uint32_t go_signal_size_words = aligned_go_signal_sizeB / sizeof(uint32_t); - - // TODO: eventually the code below could be structured to loop over programmable_indices - // and check for mcast/unicast - uint32_t programmable_core_index = hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX); + std::vector> multicast_go_signal_data; + std::vector> unicast_go_signal_data; + std::vector multicast_go_signal_sub_cmds; + std::vector unicast_go_signal_sub_cmds; + std::vector> multicast_go_signals_payload; + std::vector> unicast_go_signals_payload; + constexpr uint32_t go_signal_sizeB = sizeof(launch_msg_t); + uint32_t aligned_go_signal_sizeB = align(go_signal_sizeB, hal.get_alignment(HalMemType::L1)); + uint32_t go_signal_size_words = aligned_go_signal_sizeB / sizeof(uint32_t); + + // TODO: eventually the code below could be structured to loop over programmable_indices + // and check for mcast/unicast + uint32_t programmable_core_index = hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX); + for (KernelGroup& kernel_group : program.get_kernel_groups(programmable_core_index)) { + kernel_group.launch_msg.kernel_config.mode = DISPATCH_MODE_DEV; + for (uint32_t i = 0; i < kernel_config_addrs.size(); i++) { + kernel_group.launch_msg.kernel_config.kernel_config_base[i] = kernel_config_addrs[i].addr; + } + kernel_group.launch_msg.kernel_config.host_assigned_id = program.get_runtime_id(); + const void* launch_message_data = (const void*)(&kernel_group.launch_msg); + for (const CoreRange& core_range : kernel_group.core_ranges.ranges()) { + CoreCoord physical_start = + device->physical_core_from_logical_core(core_range.start_coord, kernel_group.get_core_type()); + CoreCoord physical_end = + device->physical_core_from_logical_core(core_range.end_coord, kernel_group.get_core_type()); + + multicast_go_signal_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ + .noc_xy_addr = this->device->get_noc_multicast_encoding( + this->noc_index, CoreRange(physical_start, physical_end)), + .num_mcast_dests = (uint32_t)core_range.size()}); + multicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB); + } + } + if (multicast_go_signal_sub_cmds.size() > 0) { + cmd_sequence_sizeB += insert_write_packed_payloads( + multicast_go_signal_sub_cmds.size(), + go_signal_sizeB, + max_prefetch_command_size, + this->packed_write_max_unicast_sub_cmds, + multicast_go_signals_payload); + } + + programmable_core_index = hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH); + // TODO: ugly, can be fixed by looping over indices w/ some work + if (programmable_core_index != -1) { for (KernelGroup& kernel_group : program.get_kernel_groups(programmable_core_index)) { kernel_group.launch_msg.kernel_config.mode = DISPATCH_MODE_DEV; for (uint32_t i = 0; i < kernel_config_addrs.size(); i++) { kernel_group.launch_msg.kernel_config.kernel_config_base[i] = kernel_config_addrs[i].addr; } kernel_group.launch_msg.kernel_config.host_assigned_id = program.get_runtime_id(); - const void* launch_message_data = (const void*)(&kernel_group.launch_msg); + const void* launch_message_data = (const launch_msg_t*)(&kernel_group.launch_msg); for (const CoreRange& core_range : kernel_group.core_ranges.ranges()) { - CoreCoord physical_start = - device->physical_core_from_logical_core(core_range.start_coord, kernel_group.get_core_type()); - CoreCoord physical_end = - device->physical_core_from_logical_core(core_range.end_coord, kernel_group.get_core_type()); - - multicast_go_signal_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{ - .noc_xy_addr = this->device->get_noc_multicast_encoding( - this->noc_index, CoreRange(physical_start, physical_end)), - .num_mcast_dests = (uint32_t)core_range.size()}); - multicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB); - } - } - if (multicast_go_signal_sub_cmds.size() > 0) { - cmd_sequence_sizeB += insert_write_packed_payloads( - multicast_go_signal_sub_cmds.size(), - go_signal_sizeB, - max_prefetch_command_size, - this->packed_write_max_unicast_sub_cmds, - multicast_go_signals_payload); - } - - programmable_core_index = hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH); - // TODO: ugly, can be fixed by looping over indices w/ some work - if (programmable_core_index != -1) { - for (KernelGroup& kernel_group : program.get_kernel_groups(programmable_core_index)) { - kernel_group.launch_msg.kernel_config.mode = DISPATCH_MODE_DEV; - for (uint32_t i = 0; i < kernel_config_addrs.size(); i++) { - kernel_group.launch_msg.kernel_config.kernel_config_base[i] = kernel_config_addrs[i].addr; - } - kernel_group.launch_msg.kernel_config.host_assigned_id = program.get_runtime_id(); - const void* launch_message_data = (const launch_msg_t*)(&kernel_group.launch_msg); - for (const CoreRange& core_range : kernel_group.core_ranges.ranges()) { - for (auto x = core_range.start_coord.x; x <= core_range.end_coord.x; x++) { - for (auto y = core_range.start_coord.y; y <= core_range.end_coord.y; y++) { - CoreCoord physical_coord = device->physical_core_from_logical_core( - CoreCoord({x, y}), kernel_group.get_core_type()); - unicast_go_signal_sub_cmds.emplace_back(CQDispatchWritePackedUnicastSubCmd{ - .noc_xy_addr = - this->device->get_noc_unicast_encoding(this->noc_index, physical_coord)}); - unicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB); - } + for (auto x = core_range.start_coord.x; x <= core_range.end_coord.x; x++) { + for (auto y = core_range.start_coord.y; y <= core_range.end_coord.y; y++) { + CoreCoord physical_coord = device->physical_core_from_logical_core( + CoreCoord({x, y}), kernel_group.get_core_type()); + unicast_go_signal_sub_cmds.emplace_back(CQDispatchWritePackedUnicastSubCmd{ + .noc_xy_addr = + this->device->get_noc_unicast_encoding(this->noc_index, physical_coord)}); + unicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB); } } } } + } - if (unicast_go_signal_sub_cmds.size() > 0) { - cmd_sequence_sizeB += insert_write_packed_payloads( - unicast_go_signal_sub_cmds.size(), - go_signal_sizeB, - max_prefetch_command_size, - this->packed_write_max_unicast_sub_cmds, - unicast_go_signals_payload); - } - // If dispatch_s is enabled, have dispatch_d send a semaphore update to dispatch_s - // Either dispatch_d or dispatch_s will send the go signal - cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE + this->device->dispatch_s_enabled() * CQ_PREFETCH_CMD_BARE_MIN_SIZE; + if (unicast_go_signal_sub_cmds.size() > 0) { + cmd_sequence_sizeB += insert_write_packed_payloads( + unicast_go_signal_sub_cmds.size(), + go_signal_sizeB, + max_prefetch_command_size, + this->packed_write_max_unicast_sub_cmds, + unicast_go_signals_payload); + } + // If dispatch_s is enabled, have dispatch_d send a semaphore update to dispatch_s + // Either dispatch_d or dispatch_s will send the go signal + cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE + this->device->dispatch_s_enabled() * CQ_PREFETCH_CMD_BARE_MIN_SIZE; - cached_program_command_sequence.program_command_sequence = HostMemDeviceCommand(cmd_sequence_sizeB); + program_command_sequence.device_command_sequence = HostMemDeviceCommand(cmd_sequence_sizeB); - auto& program_command_sequence = cached_program_command_sequence.program_command_sequence; + auto& device_command_sequence = program_command_sequence.device_command_sequence; - uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); + uint32_t l1_alignment = hal.get_alignment(HalMemType::L1); - // Semaphores - // Multicast Semaphore Cmd - uint32_t index = hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX); - for (uint32_t i = 0; i < num_multicast_semaphores; ++i) { - uint32_t curr_sub_cmd_idx = 0; - for (const auto& [num_sub_cmds_in_cmd, multicast_sem_payload_sizeB] : multicast_sem_payload[i]) { - program_command_sequence.add_dispatch_write_packed( - num_sub_cmds_in_cmd, - multicast_sem_dst_size[i].first + program.get_program_config(index).sem_offset, - multicast_sem_dst_size[i].second, - multicast_sem_payload_sizeB, - multicast_sem_sub_cmds[i], - multicast_sem_data[i], - this->packed_write_max_unicast_sub_cmds, - curr_sub_cmd_idx, - false, - DISPATCH_WRITE_OFFSET_TENSIX_L1_CONFIG_BASE); - curr_sub_cmd_idx += num_sub_cmds_in_cmd; - for (auto& data_and_size : multicast_sem_data[i]) { - RecordDispatchData(program, DISPATCH_DATA_SEMAPHORE, data_and_size.second); - } + // Semaphores + // Multicast Semaphore Cmd + uint32_t index = hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX); + for (uint32_t i = 0; i < num_multicast_semaphores; ++i) { + uint32_t curr_sub_cmd_idx = 0; + for (const auto& [num_sub_cmds_in_cmd, multicast_sem_payload_sizeB] : multicast_sem_payload[i]) { + device_command_sequence.add_dispatch_write_packed( + num_sub_cmds_in_cmd, + multicast_sem_dst_size[i].first + program.get_program_config(index).sem_offset, + multicast_sem_dst_size[i].second, + multicast_sem_payload_sizeB, + multicast_sem_sub_cmds[i], + multicast_sem_data[i], + this->packed_write_max_unicast_sub_cmds, + curr_sub_cmd_idx, + false, + DISPATCH_WRITE_OFFSET_TENSIX_L1_CONFIG_BASE); + curr_sub_cmd_idx += num_sub_cmds_in_cmd; + for (auto& data_and_size : multicast_sem_data[i]) { + RecordDispatchData(program, DISPATCH_DATA_SEMAPHORE, data_and_size.second); } } + } - // Unicast Semaphore Cmd - index = hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH); - for (uint32_t i = 0; i < num_unicast_semaphores; ++i) { - uint32_t curr_sub_cmd_idx = 0; - for (const auto& [num_sub_cmds_in_cmd, unicast_sem_payload_sizeB] : unicast_sem_payload[i]) { - program_command_sequence.add_dispatch_write_packed( - num_sub_cmds_in_cmd, - unicast_sem_dst_size[i].first + program.get_program_config(index).sem_offset, - unicast_sem_dst_size[i].second, - unicast_sem_payload_sizeB, - unicast_sem_sub_cmds[i], - unicast_sem_data[i], - this->packed_write_max_unicast_sub_cmds, - curr_sub_cmd_idx, - false, - DISPATCH_WRITE_OFFSET_ETH_L1_CONFIG_BASE); - curr_sub_cmd_idx += num_sub_cmds_in_cmd; - for (auto& data_and_size : unicast_sem_data[i]) { - RecordDispatchData(program, DISPATCH_DATA_SEMAPHORE, data_and_size.second); - } + // Unicast Semaphore Cmd + index = hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH); + for (uint32_t i = 0; i < num_unicast_semaphores; ++i) { + uint32_t curr_sub_cmd_idx = 0; + for (const auto& [num_sub_cmds_in_cmd, unicast_sem_payload_sizeB] : unicast_sem_payload[i]) { + device_command_sequence.add_dispatch_write_packed( + num_sub_cmds_in_cmd, + unicast_sem_dst_size[i].first + program.get_program_config(index).sem_offset, + unicast_sem_dst_size[i].second, + unicast_sem_payload_sizeB, + unicast_sem_sub_cmds[i], + unicast_sem_data[i], + this->packed_write_max_unicast_sub_cmds, + curr_sub_cmd_idx, + false, + DISPATCH_WRITE_OFFSET_ETH_L1_CONFIG_BASE); + curr_sub_cmd_idx += num_sub_cmds_in_cmd; + for (auto& data_and_size : unicast_sem_data[i]) { + RecordDispatchData(program, DISPATCH_DATA_SEMAPHORE, data_and_size.second); } } + } - // CB Configs commands - index = hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX); - if (num_multicast_cb_sub_cmds > 0) { - uint32_t curr_sub_cmd_idx = 0; - cached_program_command_sequence.cb_configs_payloads.reserve(num_multicast_cb_sub_cmds); - const uint32_t cb_config_size_words = aligned_cb_config_size_bytes / sizeof(uint32_t); - for (const auto& [num_sub_cmds_in_cmd, mcast_cb_payload_sizeB] : mcast_cb_payload) { - uint32_t write_offset_bytes = program_command_sequence.write_offset_bytes(); - program_command_sequence.add_dispatch_write_packed( - num_sub_cmds_in_cmd, - program.get_program_config(index).cb_offset, - cb_config_size_bytes, - mcast_cb_payload_sizeB, - multicast_cb_config_sub_cmds, - multicast_cb_config_data, - this->packed_write_max_unicast_sub_cmds, - curr_sub_cmd_idx, - false, - DISPATCH_WRITE_OFFSET_TENSIX_L1_CONFIG_BASE); - for (auto& data_and_size : multicast_cb_config_data) { - RecordDispatchData(program, DISPATCH_DATA_CB_CONFIG, data_and_size.second); - } - curr_sub_cmd_idx += num_sub_cmds_in_cmd; - RecordDispatchData(program, DISPATCH_DATA_CB_CONFIG, mcast_cb_payload_sizeB); - uint32_t curr_sub_cmd_data_offset_words = - (write_offset_bytes + (sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd)) + - align(num_sub_cmds_in_cmd * sizeof(CQDispatchWritePackedMulticastSubCmd), l1_alignment)) / - sizeof(uint32_t); - for (uint32_t i = 0; i < num_sub_cmds_in_cmd; ++i) { - cached_program_command_sequence.cb_configs_payloads.push_back( - (uint32_t*)program_command_sequence.data() + curr_sub_cmd_data_offset_words); - curr_sub_cmd_data_offset_words += cb_config_size_words; - } + // CB Configs commands + index = hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX); + if (num_multicast_cb_sub_cmds > 0) { + uint32_t curr_sub_cmd_idx = 0; + program_command_sequence.cb_configs_payloads.reserve(num_multicast_cb_sub_cmds); + const uint32_t cb_config_size_words = aligned_cb_config_size_bytes / sizeof(uint32_t); + for (const auto& [num_sub_cmds_in_cmd, mcast_cb_payload_sizeB] : mcast_cb_payload) { + uint32_t write_offset_bytes = device_command_sequence.write_offset_bytes(); + device_command_sequence.add_dispatch_write_packed( + num_sub_cmds_in_cmd, + program.get_program_config(index).cb_offset, + cb_config_size_bytes, + mcast_cb_payload_sizeB, + multicast_cb_config_sub_cmds, + multicast_cb_config_data, + this->packed_write_max_unicast_sub_cmds, + curr_sub_cmd_idx, + false, + DISPATCH_WRITE_OFFSET_TENSIX_L1_CONFIG_BASE); + for (auto& data_and_size : multicast_cb_config_data) { + RecordDispatchData(program, DISPATCH_DATA_CB_CONFIG, data_and_size.second); + } + curr_sub_cmd_idx += num_sub_cmds_in_cmd; + RecordDispatchData(program, DISPATCH_DATA_CB_CONFIG, mcast_cb_payload_sizeB); + uint32_t curr_sub_cmd_data_offset_words = + (write_offset_bytes + (sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd)) + + align(num_sub_cmds_in_cmd * sizeof(CQDispatchWritePackedMulticastSubCmd), l1_alignment)) / + sizeof(uint32_t); + for (uint32_t i = 0; i < num_sub_cmds_in_cmd; ++i) { + program_command_sequence.cb_configs_payloads.push_back( + (uint32_t*)device_command_sequence.data() + curr_sub_cmd_data_offset_words); + curr_sub_cmd_data_offset_words += cb_config_size_words; } } + } - // All Previous Cmds Up to This Point Go Into the Kernel Config Buffer - cached_program_command_sequence.program_config_buffer_data_size_bytes = - program_command_sequence.write_offset_bytes(); - - // Program Binaries - for (const auto& kernel_bins_unicast_cmd : kernel_bins_unicast_cmds) { - program_command_sequence.add_data( - kernel_bins_unicast_cmd.data(), - kernel_bins_unicast_cmd.size_bytes(), - kernel_bins_unicast_cmd.size_bytes()); - } - uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM); - for (uint32_t i = 0; i < kernel_bins_dispatch_subcmds.size(); ++i) { - program_command_sequence.add_dispatch_write_packed_large( - dram_alignment, kernel_bins_dispatch_subcmds[i].size(), kernel_bins_dispatch_subcmds[i]); - program_command_sequence.add_prefetch_relay_paged_packed( - kernel_bins_write_packed_large_data_aligned_sizeB[i], - kernel_bins_prefetch_subcmds[i], - kernel_bins_prefetch_subcmds[i].size()); - } + // All Previous Cmds Up to This Point Go Into the Kernel Config Buffer + program_command_sequence.program_config_buffer_data_size_bytes = + device_command_sequence.write_offset_bytes(); - // Go Signals - cached_program_command_sequence.go_signals.reserve( - multicast_go_signal_sub_cmds.size() + unicast_go_signal_sub_cmds.size()); - - // Get the address for the slot this launch_message will be written to - uint32_t multicast_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::LAUNCH) + this->multicast_cores_launch_message_wptr * sizeof(launch_msg_t); - - uint8_t go_signal_mcast_flag = 0x0; - if (multicast_go_signal_sub_cmds.size() > 0) { - go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_MCAST; - uint32_t curr_sub_cmd_idx = 0; - for (const auto& [num_sub_cmds_in_cmd, multicast_go_signal_payload_sizeB] : multicast_go_signals_payload) { - uint32_t write_offset_bytes = program_command_sequence.write_offset_bytes(); - program_command_sequence.add_dispatch_write_packed( - num_sub_cmds_in_cmd, - multicast_launch_msg_addr, - go_signal_sizeB, - multicast_go_signal_payload_sizeB, - multicast_go_signal_sub_cmds, - multicast_go_signal_data, - this->packed_write_max_unicast_sub_cmds, - curr_sub_cmd_idx); - curr_sub_cmd_idx += num_sub_cmds_in_cmd; - cached_program_command_sequence.launch_msg_write_packed_cmd_ptrs.push_back(&((CQDispatchCmd*) ((uint32_t*)program_command_sequence.data() + (write_offset_bytes + sizeof(CQPrefetchCmd)) / sizeof(uint32_t)))->write_packed); - uint32_t curr_sub_cmd_data_offset_words = - (write_offset_bytes + (sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd)) + - align(num_sub_cmds_in_cmd * sizeof(CQDispatchWritePackedMulticastSubCmd), l1_alignment)) / - sizeof(uint32_t); - for (uint32_t i = 0; i < num_sub_cmds_in_cmd; ++i) { - cached_program_command_sequence.go_signals.push_back( - (launch_msg_t*)((uint32_t*)program_command_sequence.data() + curr_sub_cmd_data_offset_words)); - curr_sub_cmd_data_offset_words += go_signal_size_words; - } - } - } + // Program Binaries + for (const auto& kernel_bins_unicast_cmd : kernel_bins_unicast_cmds) { + device_command_sequence.add_data( + kernel_bins_unicast_cmd.data(), + kernel_bins_unicast_cmd.size_bytes(), + kernel_bins_unicast_cmd.size_bytes()); + } + uint32_t dram_alignment = hal.get_alignment(HalMemType::DRAM); + for (uint32_t i = 0; i < kernel_bins_dispatch_subcmds.size(); ++i) { + device_command_sequence.add_dispatch_write_packed_large( + dram_alignment, kernel_bins_dispatch_subcmds[i].size(), kernel_bins_dispatch_subcmds[i]); + device_command_sequence.add_prefetch_relay_paged_packed( + kernel_bins_write_packed_large_data_aligned_sizeB[i], + kernel_bins_prefetch_subcmds[i], + kernel_bins_prefetch_subcmds[i].size()); + } - if (unicast_go_signal_sub_cmds.size() > 0) { - uint32_t unicast_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::LAUNCH) + this->unicast_cores_launch_message_wptr * sizeof(launch_msg_t); - go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_UNICAST; - uint32_t curr_sub_cmd_idx = 0; - for (const auto& [num_sub_cmds_in_cmd, unicast_go_signal_payload_sizeB] : unicast_go_signals_payload) { - uint32_t write_offset_bytes = program_command_sequence.write_offset_bytes(); - program_command_sequence.add_dispatch_write_packed( - num_sub_cmds_in_cmd, - unicast_launch_msg_addr, - go_signal_sizeB, - unicast_go_signal_payload_sizeB, - unicast_go_signal_sub_cmds, - unicast_go_signal_data, - this->packed_write_max_unicast_sub_cmds, - curr_sub_cmd_idx); - curr_sub_cmd_idx += num_sub_cmds_in_cmd; - cached_program_command_sequence.unicast_launch_msg_write_packed_cmd_ptrs.push_back(&((CQDispatchCmd*) ((uint32_t*)program_command_sequence.data() + (write_offset_bytes + sizeof(CQPrefetchCmd)) / sizeof(uint32_t)))->write_packed); - uint32_t curr_sub_cmd_data_offset_words = - (write_offset_bytes + (sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd)) + - align(num_sub_cmds_in_cmd * sizeof(CQDispatchWritePackedUnicastSubCmd), l1_alignment)) / - sizeof(uint32_t); - for (uint32_t i = 0; i < num_sub_cmds_in_cmd; ++i) { - cached_program_command_sequence.go_signals.push_back( - (launch_msg_t*)((uint32_t*)program_command_sequence.data() + curr_sub_cmd_data_offset_words)); - curr_sub_cmd_data_offset_words += go_signal_size_words; - } - } - } + // Go Signals + program_command_sequence.go_signals.reserve( + multicast_go_signal_sub_cmds.size() + unicast_go_signal_sub_cmds.size()); - // Wait Noc Write Barrier, wait for binaries/configs and launch_msg to be written to worker cores - if (program.program_transfer_info.num_active_cores > 0) { - program_command_sequence.add_dispatch_wait(true, this->dispatch_message_addr, 0, 0, false, false); - } - DispatcherSelect dispatcher_for_go_signal = DispatcherSelect::DISPATCH_MASTER; - if (this->device->dispatch_s_enabled()) { - // dispatch_d signals dispatch_s that its safe to send the go signal after a barrier - program_command_sequence.add_notify_dispatch_s_go_signal_cmd(); - dispatcher_for_go_signal = DispatcherSelect::DISPATCH_SLAVE; - } - go_msg_t run_program_go_signal; - run_program_go_signal.signal = RUN_MSG_GO; - run_program_go_signal.master_x = (uint8_t)this->dispatch_core.x; - run_program_go_signal.master_y = (uint8_t)this->dispatch_core.y; - uint32_t write_offset_bytes = program_command_sequence.write_offset_bytes(); - program_command_sequence.add_dispatch_go_signal_mcast(this->expected_num_workers_completed, go_signal_mcast_flag, *reinterpret_cast(&run_program_go_signal), this->dispatch_message_addr, dispatcher_for_go_signal); - cached_program_command_sequence.mcast_go_signal_cmd_ptr = &((CQDispatchCmd*) ((uint32_t*)program_command_sequence.data() + (write_offset_bytes + sizeof(CQPrefetchCmd)) / sizeof(uint32_t)))->mcast; - } else { - uint32_t i = 0; - ZoneScopedN("program_loaded_on_device"); - for (const auto& cbs_on_core_range : cached_program_command_sequence.circular_buffers_on_core_ranges) { - uint32_t* cb_config_payload = cached_program_command_sequence.cb_configs_payloads[i]; - for (const shared_ptr& cb : cbs_on_core_range) { - const uint32_t cb_address = cb->address() >> 4; - const uint32_t cb_size = cb->size() >> 4; - for (const auto& buffer_index : cb->buffer_indices()) { - // 1 cmd for all 32 buffer indices, populate with real data for specified indices + // Get the address for the slot this launch_message will be written to + uint32_t multicast_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::LAUNCH) + this->multicast_cores_launch_message_wptr * sizeof(launch_msg_t); - // cb config payload - uint32_t base_index = UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG * (uint32_t)buffer_index; - cb_config_payload[base_index] = cb_address; - cb_config_payload[base_index + 1] = cb_size; - cb_config_payload[base_index + 2] = cb->num_pages(buffer_index); - cb_config_payload[base_index + 3] = cb->page_size(buffer_index) >> 4; - } - } - i++; - } - for (auto& go_signal : cached_program_command_sequence.go_signals) { - for (uint32_t i = 0; i < kernel_config_addrs.size(); i++) { - go_signal->kernel_config.kernel_config_base[i] = kernel_config_addrs[i].addr; + uint8_t go_signal_mcast_flag = 0x0; + if (multicast_go_signal_sub_cmds.size() > 0) { + go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_MCAST; + uint32_t curr_sub_cmd_idx = 0; + for (const auto& [num_sub_cmds_in_cmd, multicast_go_signal_payload_sizeB] : multicast_go_signals_payload) { + uint32_t write_offset_bytes = device_command_sequence.write_offset_bytes(); + device_command_sequence.add_dispatch_write_packed( + num_sub_cmds_in_cmd, + multicast_launch_msg_addr, + go_signal_sizeB, + multicast_go_signal_payload_sizeB, + multicast_go_signal_sub_cmds, + multicast_go_signal_data, + this->packed_write_max_unicast_sub_cmds, + curr_sub_cmd_idx); + curr_sub_cmd_idx += num_sub_cmds_in_cmd; + program_command_sequence.launch_msg_write_packed_cmd_ptrs.push_back(&((CQDispatchCmd*) ((uint32_t*)device_command_sequence.data() + (write_offset_bytes + sizeof(CQPrefetchCmd)) / sizeof(uint32_t)))->write_packed); + uint32_t curr_sub_cmd_data_offset_words = + (write_offset_bytes + (sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd)) + + align(num_sub_cmds_in_cmd * sizeof(CQDispatchWritePackedMulticastSubCmd), l1_alignment)) / + sizeof(uint32_t); + for (uint32_t i = 0; i < num_sub_cmds_in_cmd; ++i) { + program_command_sequence.go_signals.push_back( + (launch_msg_t*)((uint32_t*)device_command_sequence.data() + curr_sub_cmd_data_offset_words)); + curr_sub_cmd_data_offset_words += go_signal_size_words; } - go_signal->kernel_config.host_assigned_id = program.get_runtime_id(); } - // Update launch message addresses to reflect new launch_msg slot in ring buffer - uint32_t multicast_cores_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::LAUNCH) + this->multicast_cores_launch_message_wptr * sizeof(launch_msg_t); - for (auto launch_msg_cmd_ptr : cached_program_command_sequence.launch_msg_write_packed_cmd_ptrs) { - launch_msg_cmd_ptr->addr = multicast_cores_launch_msg_addr; - } - if (cached_program_command_sequence.unicast_launch_msg_write_packed_cmd_ptrs.size()) { - uint32_t unicast_cores_launch_message_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::LAUNCH) + this->unicast_cores_launch_message_wptr * sizeof(launch_msg_t); - for (auto launch_msg_cmd_ptr : cached_program_command_sequence.unicast_launch_msg_write_packed_cmd_ptrs) { - launch_msg_cmd_ptr->addr = unicast_cores_launch_message_addr; + } + + if (unicast_go_signal_sub_cmds.size() > 0) { + uint32_t unicast_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::LAUNCH) + this->unicast_cores_launch_message_wptr * sizeof(launch_msg_t); + go_signal_mcast_flag |= (uint8_t)GoSignalMcastSettings::SEND_UNICAST; + uint32_t curr_sub_cmd_idx = 0; + for (const auto& [num_sub_cmds_in_cmd, unicast_go_signal_payload_sizeB] : unicast_go_signals_payload) { + uint32_t write_offset_bytes = device_command_sequence.write_offset_bytes(); + device_command_sequence.add_dispatch_write_packed( + num_sub_cmds_in_cmd, + unicast_launch_msg_addr, + go_signal_sizeB, + unicast_go_signal_payload_sizeB, + unicast_go_signal_sub_cmds, + unicast_go_signal_data, + this->packed_write_max_unicast_sub_cmds, + curr_sub_cmd_idx); + curr_sub_cmd_idx += num_sub_cmds_in_cmd; + program_command_sequence.unicast_launch_msg_write_packed_cmd_ptrs.push_back(&((CQDispatchCmd*) ((uint32_t*)device_command_sequence.data() + (write_offset_bytes + sizeof(CQPrefetchCmd)) / sizeof(uint32_t)))->write_packed); + uint32_t curr_sub_cmd_data_offset_words = + (write_offset_bytes + (sizeof(CQPrefetchCmd) + sizeof(CQDispatchCmd)) + + align(num_sub_cmds_in_cmd * sizeof(CQDispatchWritePackedUnicastSubCmd), l1_alignment)) / + sizeof(uint32_t); + for (uint32_t i = 0; i < num_sub_cmds_in_cmd; ++i) { + program_command_sequence.go_signals.push_back( + (launch_msg_t*)((uint32_t*)device_command_sequence.data() + curr_sub_cmd_data_offset_words)); + curr_sub_cmd_data_offset_words += go_signal_size_words; } } - // Update go signal to reflect potentially modified dispatch core and new wait count - go_msg_t run_program_go_signal; - run_program_go_signal.signal = RUN_MSG_GO; - run_program_go_signal.master_x = (uint8_t)this->dispatch_core.x; - run_program_go_signal.master_y = (uint8_t)this->dispatch_core.y; - cached_program_command_sequence.mcast_go_signal_cmd_ptr->go_signal = *reinterpret_cast(&run_program_go_signal); - cached_program_command_sequence.mcast_go_signal_cmd_ptr->wait_count = this->expected_num_workers_completed; } -} -void EnqueueProgramCommand::process() { - bool is_cached = true; - if (not program.is_finalized()) { - program.finalize(); - is_cached = false; + // Wait Noc Write Barrier, wait for binaries/configs and launch_msg to be written to worker cores + if (program.program_transfer_info.num_active_cores > 0) { + device_command_sequence.add_dispatch_wait(true, this->dispatch_message_addr, 0, 0, false, false); } - - const std::pair&> reservation = - this->manager.get_config_buffer_mgr().reserve(program.program_config_sizes_); - bool stall_first = reservation.first.need_sync; - // Note: since present implementation always stalls, we always free up to "now" - this->manager.get_config_buffer_mgr().free(reservation.first.sync_count); - uint32_t num_workers = 0; - if (program.runs_on_noc_multicast_only_cores()) { - num_workers += device->num_worker_cores(); + DispatcherSelect dispatcher_for_go_signal = DispatcherSelect::DISPATCH_MASTER; + if (this->device->dispatch_s_enabled()) { + // dispatch_d signals dispatch_s that its safe to send the go signal after a barrier + device_command_sequence.add_notify_dispatch_s_go_signal_cmd(); + dispatcher_for_go_signal = DispatcherSelect::DISPATCH_SLAVE; } - if (program.runs_on_noc_unicast_only_cores()) { - num_workers += device->num_eth_worker_cores(); + go_msg_t run_program_go_signal; + run_program_go_signal.signal = RUN_MSG_GO; + run_program_go_signal.master_x = (uint8_t)this->dispatch_core.x; + run_program_go_signal.master_y = (uint8_t)this->dispatch_core.y; + uint32_t write_offset_bytes = device_command_sequence.write_offset_bytes(); + device_command_sequence.add_dispatch_go_signal_mcast(this->expected_num_workers_completed, go_signal_mcast_flag, *reinterpret_cast(&run_program_go_signal), this->dispatch_message_addr, dispatcher_for_go_signal); + program_command_sequence.mcast_go_signal_cmd_ptr = &((CQDispatchCmd*) ((uint32_t*)device_command_sequence.data() + (write_offset_bytes + sizeof(CQPrefetchCmd)) / sizeof(uint32_t)))->mcast; +} + +void EnqueueProgramCommand::update_device_commands(ProgramCommandSequence& cached_program_command_sequence, std::vector& kernel_config_addrs) { + uint32_t i = 0; + ZoneScopedN("program_loaded_on_device"); + for (const auto& cbs_on_core_range : cached_program_command_sequence.circular_buffers_on_core_ranges) { + uint32_t* cb_config_payload = cached_program_command_sequence.cb_configs_payloads[i]; + for (const shared_ptr& cb : cbs_on_core_range) { + const uint32_t cb_address = cb->address() >> 4; + const uint32_t cb_size = cb->size() >> 4; + for (const auto& buffer_index : cb->buffer_indices()) { + // 1 cmd for all 32 buffer indices, populate with real data for specified indices + + // cb config payload + uint32_t base_index = UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG * (uint32_t)buffer_index; + cb_config_payload[base_index] = cb_address; + cb_config_payload[base_index + 1] = cb_size; + cb_config_payload[base_index + 2] = cb->num_pages(buffer_index); + cb_config_payload[base_index + 3] = cb->page_size(buffer_index) >> 4; + } + } + i++; } - this->manager.get_config_buffer_mgr().alloc( - this->expected_num_workers_completed + num_workers); - - std::vector& kernel_config_addrs = reservation.second; - - // Calculate all commands size and determine how many fetch q entries to use - // Preamble, some waits and stalls - // can be written directly to the issue queue - if (not is_cached) { - this->assemble_preamble_commands(kernel_config_addrs); - this->assemble_stall_commands(true); - // Runtime Args Command Sequence - this->assemble_runtime_args_commands(); - - // Record kernel groups in this program, only need to do it once. - for (uint32_t index = 0; index < hal.get_programmable_core_type_count(); index++) { - CoreType core_type = hal.get_core_type(index); - RecordKernelGroups(program, core_type, program.get_kernel_groups(index)); + for (auto& go_signal : cached_program_command_sequence.go_signals) { + for (uint32_t i = 0; i < kernel_config_addrs.size(); i++) { + go_signal->kernel_config.kernel_config_base[i] = kernel_config_addrs[i].addr; } - } else { - static constexpr uint32_t wait_count_offset = (sizeof(CQPrefetchCmd) + offsetof(CQDispatchCmd, wait.count)); - static constexpr uint32_t tensix_l1_write_offset_offset = - (sizeof(CQPrefetchCmd) + offsetof(CQDispatchCmd, set_write_offset.offset1)); - static constexpr uint32_t eth_l1_write_offset_offset = - (sizeof(CQPrefetchCmd) + offsetof(CQDispatchCmd, set_write_offset.offset2)); - TT_ASSERT( - this->cached_program_command_sequences.find(program.id) != this->cached_program_command_sequences.end(), - "Program cache hit, but no stored command sequence"); - - this->cached_program_command_sequences[program.id].stall_command_sequence.update_cmd_sequence( - wait_count_offset, &this->expected_num_workers_completed, sizeof(uint32_t)); - - this->cached_program_command_sequences[program.id].preamble_command_sequence.update_cmd_sequence( - tensix_l1_write_offset_offset, - &kernel_config_addrs[hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX)], - sizeof(uint32_t)); - if (hal.get_programmable_core_type_count() >= 2) { - this->cached_program_command_sequences[program.id].preamble_command_sequence.update_cmd_sequence( - eth_l1_write_offset_offset, - &kernel_config_addrs[hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH)], - sizeof(uint32_t)); + go_signal->kernel_config.host_assigned_id = program.get_runtime_id(); + } + // Update launch message addresses to reflect new launch_msg slot in ring buffer + uint32_t multicast_cores_launch_msg_addr = hal.get_dev_addr(HalProgrammableCoreType::TENSIX, HalL1MemAddrType::LAUNCH) + this->multicast_cores_launch_message_wptr * sizeof(launch_msg_t); + for (auto launch_msg_cmd_ptr : cached_program_command_sequence.launch_msg_write_packed_cmd_ptrs) { + launch_msg_cmd_ptr->addr = multicast_cores_launch_msg_addr; + } + if (cached_program_command_sequence.unicast_launch_msg_write_packed_cmd_ptrs.size()) { + uint32_t unicast_cores_launch_message_addr = hal.get_dev_addr(HalProgrammableCoreType::ACTIVE_ETH, HalL1MemAddrType::LAUNCH) + this->unicast_cores_launch_message_wptr * sizeof(launch_msg_t); + for (auto launch_msg_cmd_ptr : cached_program_command_sequence.unicast_launch_msg_write_packed_cmd_ptrs) { + launch_msg_cmd_ptr->addr = unicast_cores_launch_message_addr; } } - RecordProgramRun(program); - - // Main Command Sequence - this->assemble_device_commands(is_cached, kernel_config_addrs); - - const auto& cached_program_command_sequence = this->cached_program_command_sequences[program.id]; + // Update go signal to reflect potentially modified dispatch core and new wait count + go_msg_t run_program_go_signal; + run_program_go_signal.signal = RUN_MSG_GO; + run_program_go_signal.master_x = (uint8_t)this->dispatch_core.x; + run_program_go_signal.master_y = (uint8_t)this->dispatch_core.y; + cached_program_command_sequence.mcast_go_signal_cmd_ptr->go_signal = *reinterpret_cast(&run_program_go_signal); + cached_program_command_sequence.mcast_go_signal_cmd_ptr->wait_count = this->expected_num_workers_completed; +} - uint32_t preamble_fetch_size_bytes = cached_program_command_sequence.preamble_command_sequence.size_bytes(); +void EnqueueProgramCommand::write_program_command_sequence(const ProgramCommandSequence& program_command_sequence, bool stall_first) { + uint32_t preamble_fetch_size_bytes = program_command_sequence.preamble_command_sequence.size_bytes(); - uint32_t stall_fetch_size_bytes = cached_program_command_sequence.stall_command_sequence.size_bytes(); + uint32_t stall_fetch_size_bytes = program_command_sequence.stall_command_sequence.size_bytes(); - uint32_t runtime_args_fetch_size_bytes = cached_program_command_sequence.runtime_args_fetch_size_bytes; + uint32_t runtime_args_fetch_size_bytes = program_command_sequence.runtime_args_fetch_size_bytes; - uint32_t program_fetch_size_bytes = cached_program_command_sequence.program_command_sequence.size_bytes(); + uint32_t program_fetch_size_bytes = program_command_sequence.device_command_sequence.size_bytes(); uint32_t program_config_buffer_data_size_bytes = - cached_program_command_sequence.program_config_buffer_data_size_bytes; + program_command_sequence.program_config_buffer_data_size_bytes; uint32_t program_rem_fetch_size_bytes = program_fetch_size_bytes - program_config_buffer_data_size_bytes; - uint8_t* program_command_sequence_data = (uint8_t*)cached_program_command_sequence.program_command_sequence.data(); + uint8_t* program_command_sequence_data = (uint8_t*)program_command_sequence.device_command_sequence.data(); uint32_t total_fetch_size_bytes = stall_fetch_size_bytes + preamble_fetch_size_bytes + runtime_args_fetch_size_bytes + program_fetch_size_bytes; @@ -1413,17 +1339,17 @@ void EnqueueProgramCommand::process() { uint32_t write_ptr = this->manager.get_issue_queue_write_ptr(this->command_queue_id); this->manager.cq_write( - cached_program_command_sequence.preamble_command_sequence.data(), preamble_fetch_size_bytes, write_ptr); + program_command_sequence.preamble_command_sequence.data(), preamble_fetch_size_bytes, write_ptr); write_ptr += preamble_fetch_size_bytes; if (stall_first) { // Must stall before writing runtime args this->manager.cq_write( - cached_program_command_sequence.stall_command_sequence.data(), stall_fetch_size_bytes, write_ptr); + program_command_sequence.stall_command_sequence.data(), stall_fetch_size_bytes, write_ptr); write_ptr += stall_fetch_size_bytes; } - for (const auto& cmds : cached_program_command_sequence.runtime_args_command_sequences) { + for (const auto& cmds : program_command_sequence.runtime_args_command_sequences) { this->manager.cq_write(cmds.data(), cmds.size_bytes(), write_ptr); write_ptr += cmds.size_bytes(); } @@ -1437,7 +1363,7 @@ void EnqueueProgramCommand::process() { // Didn't stall before kernel config data, stall before remaining commands this->manager.cq_write( - cached_program_command_sequence.stall_command_sequence.data(), stall_fetch_size_bytes, write_ptr); + program_command_sequence.stall_command_sequence.data(), stall_fetch_size_bytes, write_ptr); write_ptr += stall_fetch_size_bytes; this->manager.cq_write(program_command_sequence_data, program_rem_fetch_size_bytes, write_ptr); @@ -1456,7 +1382,7 @@ void EnqueueProgramCommand::process() { this->manager.issue_queue_reserve(preamble_fetch_size_bytes, this->command_queue_id); uint32_t write_ptr = this->manager.get_issue_queue_write_ptr(this->command_queue_id); this->manager.cq_write( - cached_program_command_sequence.preamble_command_sequence.data(), preamble_fetch_size_bytes, write_ptr); + program_command_sequence.preamble_command_sequence.data(), preamble_fetch_size_bytes, write_ptr); this->manager.issue_queue_push_back(preamble_fetch_size_bytes, this->command_queue_id); // One fetch queue entry for just the wait and stall, very inefficient this->manager.fetch_queue_reserve_back(this->command_queue_id); @@ -1467,7 +1393,7 @@ void EnqueueProgramCommand::process() { this->manager.issue_queue_reserve(stall_fetch_size_bytes, this->command_queue_id); write_ptr = this->manager.get_issue_queue_write_ptr(this->command_queue_id); this->manager.cq_write( - cached_program_command_sequence.stall_command_sequence.data(), stall_fetch_size_bytes, write_ptr); + program_command_sequence.stall_command_sequence.data(), stall_fetch_size_bytes, write_ptr); this->manager.issue_queue_push_back(stall_fetch_size_bytes, this->command_queue_id); // One fetch queue entry for just the wait and stall, very inefficient this->manager.fetch_queue_reserve_back(this->command_queue_id); @@ -1475,7 +1401,7 @@ void EnqueueProgramCommand::process() { } // TODO: We can pack multiple RT args into one fetch q entry - for (const auto& cmds : cached_program_command_sequence.runtime_args_command_sequences) { + for (const auto& cmds : program_command_sequence.runtime_args_command_sequences) { uint32_t fetch_size_bytes = cmds.size_bytes(); this->manager.issue_queue_reserve(fetch_size_bytes, this->command_queue_id); write_ptr = this->manager.get_issue_queue_write_ptr(this->command_queue_id); @@ -1503,7 +1429,7 @@ void EnqueueProgramCommand::process() { this->manager.issue_queue_reserve(stall_fetch_size_bytes, this->command_queue_id); write_ptr = this->manager.get_issue_queue_write_ptr(this->command_queue_id); this->manager.cq_write( - cached_program_command_sequence.stall_command_sequence.data(), stall_fetch_size_bytes, write_ptr); + program_command_sequence.stall_command_sequence.data(), stall_fetch_size_bytes, write_ptr); this->manager.issue_queue_push_back(stall_fetch_size_bytes, this->command_queue_id); // One fetch queue entry for just the wait and stall, very inefficient this->manager.fetch_queue_reserve_back(this->command_queue_id); @@ -1526,10 +1452,84 @@ void EnqueueProgramCommand::process() { this->manager.fetch_queue_write(program_fetch_size_bytes, this->command_queue_id); } } +} + +void EnqueueProgramCommand::process() { + + bool is_finalized = program.is_finalized(); + if (not is_finalized) { + program.finalize(); + } + + const std::pair&> reservation = + this->manager.get_config_buffer_mgr().reserve(program.program_config_sizes_); + bool stall_first = reservation.first.need_sync; + // Note: since present implementation always stalls, we always free up to "now" + this->manager.get_config_buffer_mgr().free(reservation.first.sync_count); + uint32_t num_workers = 0; + if (program.runs_on_noc_multicast_only_cores()) { + num_workers += device->num_worker_cores(); + } + if (program.runs_on_noc_unicast_only_cores()) { + num_workers += device->num_eth_worker_cores(); + } + this->manager.get_config_buffer_mgr().alloc( + this->expected_num_workers_completed + num_workers); + + std::vector& kernel_config_addrs = reservation.second; - // Front load generating and caching stall_commands without stall during program loading stage - if (not is_cached) { - this->assemble_stall_commands(false); + RecordProgramRun(program); + + // Cache is only usable if caching is enabled and program is finalized + // If cache has a program entry but the program is not finalized, then the cache is stale + // Currently this is mapped by device, but will be mapped by multiple values in the future + uint64_t command_hash = this->device->id(); + auto cached_cmd_iter = this->program.cached_program_command_sequences_.find(command_hash); + bool is_cached = is_finalized && cached_cmd_iter != this->program.cached_program_command_sequences_.end(); + + // Calculate all commands size and determine how many fetch q entries to use + // Preamble, some waits and stalls + // can be written directly to the issue queue + if (!is_cached) { + ProgramCommandSequence program_command_sequence; + this->assemble_preamble_commands(program_command_sequence, kernel_config_addrs); + this->assemble_stall_commands(program_command_sequence, true); + // Runtime Args Command Sequence + this->assemble_runtime_args_commands(program_command_sequence); + + // Record kernel groups in this program, only need to do it once. + for (uint32_t index = 0; index < hal.get_programmable_core_type_count(); index++) { + CoreType core_type = hal.get_core_type(index); + RecordKernelGroups(program, core_type, program.get_kernel_groups(index)); + } + this->assemble_device_commands(program_command_sequence, kernel_config_addrs); + this->write_program_command_sequence(program_command_sequence, stall_first); + this->assemble_stall_commands(program_command_sequence, false); + this->program.cached_program_command_sequences_.insert({command_hash, std::move(program_command_sequence)}); + } else { + static constexpr uint32_t wait_count_offset = (sizeof(CQPrefetchCmd) + offsetof(CQDispatchCmd, wait.count)); + static constexpr uint32_t tensix_l1_write_offset_offset = + (sizeof(CQPrefetchCmd) + offsetof(CQDispatchCmd, set_write_offset.offset1)); + static constexpr uint32_t eth_l1_write_offset_offset = + (sizeof(CQPrefetchCmd) + offsetof(CQDispatchCmd, set_write_offset.offset2)); + + auto& cached_program_command_sequence = cached_cmd_iter->second; + + cached_program_command_sequence.stall_command_sequence.update_cmd_sequence( + wait_count_offset, &this->expected_num_workers_completed, sizeof(uint32_t)); + + cached_program_command_sequence.preamble_command_sequence.update_cmd_sequence( + tensix_l1_write_offset_offset, + &kernel_config_addrs[hal.get_programmable_core_type_index(HalProgrammableCoreType::TENSIX)], + sizeof(uint32_t)); + if (hal.get_programmable_core_type_count() >= 2) { + cached_program_command_sequence.preamble_command_sequence.update_cmd_sequence( + eth_l1_write_offset_offset, + &kernel_config_addrs[hal.get_programmable_core_type_index(HalProgrammableCoreType::ACTIVE_ETH)], + sizeof(uint32_t)); + } + this->update_device_commands(cached_program_command_sequence, kernel_config_addrs); + this->write_program_command_sequence(cached_program_command_sequence, stall_first); } } diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp index cbaf4c9b089..821a806ff9c 100644 --- a/tt_metal/impl/dispatch/command_queue.hpp +++ b/tt_metal/impl/dispatch/command_queue.hpp @@ -11,9 +11,11 @@ #include #include #include +#include #include "common/env_lib.hpp" #include "tt_metal/common/base.hpp" +#include "tt_metal/impl/dispatch/program_command_sequence.hpp" #include "tt_metal/impl/dispatch/command_queue_interface.hpp" #include "tt_metal/impl/dispatch/device_command.hpp" #include "tt_metal/impl/dispatch/lock_free_queue.hpp" @@ -298,21 +300,6 @@ class EnqueueProgramCommand : public Command { uint32_t unicast_cores_launch_message_wptr = 0; public: - struct CachedProgramCommandSequence { - HostMemDeviceCommand preamble_command_sequence; - HostMemDeviceCommand stall_command_sequence; - std::vector runtime_args_command_sequences; - uint32_t runtime_args_fetch_size_bytes; - HostMemDeviceCommand program_command_sequence; - std::vector cb_configs_payloads; - std::vector>> circular_buffers_on_core_ranges; - std::vector go_signals; - uint32_t program_config_buffer_data_size_bytes; - std::vector launch_msg_write_packed_cmd_ptrs; - std::vector unicast_launch_msg_write_packed_cmd_ptrs; - CQDispatchGoSignalMcastCmd* mcast_go_signal_cmd_ptr; - }; - thread_local static std::unordered_map cached_program_command_sequences; EnqueueProgramCommand( uint32_t command_queue_id, @@ -325,10 +312,13 @@ class EnqueueProgramCommand : public Command { uint32_t multicast_cores_launch_message_wptr, uint32_t unicast_cores_launch_message_wptr); - void assemble_preamble_commands(std::vector& kernel_config_addrs); - void assemble_stall_commands(bool prefetch_stall); - void assemble_device_commands(bool is_cached, std::vector& kernel_config_addrs); - void assemble_runtime_args_commands(); + void assemble_preamble_commands(ProgramCommandSequence& program_command_sequence, std::vector& kernel_config_addrs); + void assemble_stall_commands(ProgramCommandSequence& program_command_sequence, bool prefetch_stall); + void assemble_runtime_args_commands(ProgramCommandSequence& program_command_sequence); + void assemble_device_commands(ProgramCommandSequence& program_command_sequence, std::vector& kernel_config_addrs); + void update_device_commands(ProgramCommandSequence& cached_program_command_sequence, std::vector& kernel_config_addrs); + + void write_program_command_sequence(const ProgramCommandSequence& program_command_sequence, bool stall_first); void process(); diff --git a/tt_metal/impl/dispatch/device_command.hpp b/tt_metal/impl/dispatch/device_command.hpp index 1ca49462801..141fa91be92 100644 --- a/tt_metal/impl/dispatch/device_command.hpp +++ b/tt_metal/impl/dispatch/device_command.hpp @@ -51,8 +51,13 @@ class DeviceCommand { DeviceCommand &operator=(DeviceCommand &&other) { this->cmd_sequence_sizeB = other.cmd_sequence_sizeB; this->cmd_write_offsetB = other.cmd_write_offsetB; - this->cmd_region_vector = other.cmd_region_vector; - this->deepcopy(other); + this->cmd_region_vector = std::move(other.cmd_region_vector); + if constexpr (hugepage_write) { + this->deepcopy(other); + } else { + this->cmd_region = this->cmd_region_vector.data(); + } + return *this; } DeviceCommand(const DeviceCommand &other) : @@ -64,8 +69,12 @@ class DeviceCommand { DeviceCommand(DeviceCommand &&other) : cmd_sequence_sizeB(other.cmd_sequence_sizeB), cmd_write_offsetB(other.cmd_write_offsetB), - cmd_region_vector(other.cmd_region_vector) { - this->deepcopy(other); + cmd_region_vector(std::move(other.cmd_region_vector)) { + if constexpr (hugepage_write) { + this->deepcopy(other); + } else { + this->cmd_region = this->cmd_region_vector.data(); + } } // Constants diff --git a/tt_metal/impl/dispatch/program_command_sequence.hpp b/tt_metal/impl/dispatch/program_command_sequence.hpp new file mode 100644 index 00000000000..c39a38c33d3 --- /dev/null +++ b/tt_metal/impl/dispatch/program_command_sequence.hpp @@ -0,0 +1,38 @@ +// SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include +#include + +#include "tt_metal/impl/dispatch/device_command.hpp" + +struct CQDispatchWritePackedCmd; +struct launch_msg_t; + +namespace tt::tt_metal { + +inline namespace v0 { + +class CircularBuffer; + +} // namespace v0 + +struct ProgramCommandSequence { + HostMemDeviceCommand preamble_command_sequence; + HostMemDeviceCommand stall_command_sequence; + std::vector runtime_args_command_sequences; + uint32_t runtime_args_fetch_size_bytes; + HostMemDeviceCommand device_command_sequence; + std::vector cb_configs_payloads; + std::vector>> circular_buffers_on_core_ranges; + std::vector go_signals; + uint32_t program_config_buffer_data_size_bytes; + std::vector launch_msg_write_packed_cmd_ptrs; + std::vector unicast_launch_msg_write_packed_cmd_ptrs; + CQDispatchGoSignalMcastCmd* mcast_go_signal_cmd_ptr; +}; + +} // namespace tt::tt_metal diff --git a/tt_metal/impl/program/program.hpp b/tt_metal/impl/program/program.hpp index b3eff1841ce..74b148a8544 100644 --- a/tt_metal/impl/program/program.hpp +++ b/tt_metal/impl/program/program.hpp @@ -11,6 +11,7 @@ #include "tt_metal/impl/kernels/kernel_types.hpp" #include "tt_metal/impl/buffers/circular_buffer_types.hpp" #include "tt_metal/impl/buffers/semaphore.hpp" +#include "tt_metal/impl/dispatch/program_command_sequence.hpp" #include "tt_metal/impl/program/program_device_map.hpp" #include "dev_msgs.h" @@ -224,6 +225,9 @@ class Program { std::vector program_configs_; std::vector program_config_sizes_; + + std::unordered_map cached_program_command_sequences_; + friend CBHandle CreateCircularBuffer(Program &program, const std::variant &core_spec, const CircularBufferConfig &config); friend std::shared_ptr detail::GetCircularBuffer(const Program &program, CBHandle id); friend void detail::ValidateCircularBufferRegion(const Program &program, const Device *device); From 68b08ae7d79b81bbfac258f65d3d29ea6c5ed5a7 Mon Sep 17 00:00:00 2001 From: Salar Hosseini <159165450+skhorasganiTT@users.noreply.github.com> Date: Wed, 9 Oct 2024 15:43:19 -0400 Subject: [PATCH 49/58] [skip ci] Update perf and latest features for llm models (Oct 7) (#13648) --- README.md | 16 ++++++++-------- models/MODEL_UPDATES.md | 7 +++++++ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 416d466ba23..62a6cd660aa 100644 --- a/README.md +++ b/README.md @@ -24,17 +24,17 @@ | Model | Batch | Hardware | ttft (s) | t/s/u | Target t/s/u | Release | |----------------------------------------------------------------------|-------|----------------------------------------------------------|------------|-------|--------------|---------------------------------------------------------------------------| | [Falcon7B-decode](./models/demos/ttnn_falcon7b) | 32 | [e150](https://tenstorrent.com/hardware/grayskull) | | 4.2 | 4.4 | | -| [Falcon7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.07 | 16.7 | 26 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Falcon7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.07 | 16.7 | 26 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | | [Mistral-7B](./models/demos/wormhole/mistral7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | | 9.9 | 25 | [v0.51.0-rc28](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc28) | | [Mamba-2.8B](./models/demos/wormhole/mamba) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.04 | 12.3 | 41 | [v0.51.0-rc26](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc26) | | [LLaMA-3.1-8B](./models/demos/wormhole/llama31_8b) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.20 | 21.4 | 23 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Falcon7B (data parallel)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.10 | 14.1 | 26 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [LLaMA-2-70B - (tensor parallel)](./models/demos/t3000/llama2_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [LLaMA-3.1-70B (tensor parallel)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Falcon40B (tensor parallel)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Mixtral7Bx8 (tensor parallel)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.23 | 14.2 | 33 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Falcon7B (data parallel)](./models/demos/tg/falcon7b) |1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 0.24 | 4.3 | 26 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -> **Last Update:** September 23, 2024 +| [Falcon7B (data parallel)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.10 | 14.4 | 26 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | +| [LLaMA-2-70B - (tensor parallel)](./models/demos/t3000/llama2_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | +| [LLaMA-3.1-70B (tensor parallel)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | +| [Falcon40B (tensor parallel)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | [v0.53.0-rc2](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc2) | +| [Mixtral7Bx8 (tensor parallel)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.23 | 14.2 | 33 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | +| [Falcon7B (data parallel)](./models/demos/tg/falcon7b) |1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 0.21 | 4.4 | 26 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | +> **Last Update:** October 7, 2024 > **Notes:** > - The reported LLM performance is for an input sequence length (number of rows filled in the KV cache) of 128 for all models except Mamba (which can accept any sequence length). diff --git a/models/MODEL_UPDATES.md b/models/MODEL_UPDATES.md index 1c3e02e2651..feaa61cc031 100644 --- a/models/MODEL_UPDATES.md +++ b/models/MODEL_UPDATES.md @@ -4,6 +4,13 @@ > > Please refer to the front-page [README](../README.md) for the latest verified release for each model. +## October 7, 2024 + +### [Llama 3.1 - 8B](demos/wormhole/llama31_8b) +- Added support for continuous batching +- Added paged caching support for PagedAttention +- Added a demo which runs with TT-NN tracing (23 t/s/u decode on main) + ## September 23, 2024 ### [Llama 3/3.1 - 70B](demos/t3000/llama3_70b) From df995f1846c40f45dfdc5b4ecf0bce07b1f7a222 Mon Sep 17 00:00:00 2001 From: Yu Gao <145494740+yugaoTT@users.noreply.github.com> Date: Wed, 9 Oct 2024 18:22:27 -0400 Subject: [PATCH 50/58] #0: remove always attr from cmd buffers (#13659) --- tt_metal/hw/inc/dataflow_api.h | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/tt_metal/hw/inc/dataflow_api.h b/tt_metal/hw/inc/dataflow_api.h index e38abdbd062..7c771096094 100644 --- a/tt_metal/hw/inc/dataflow_api.h +++ b/tt_metal/hw/inc/dataflow_api.h @@ -41,26 +41,26 @@ extern uint32_t tt_l1_ptr *sem_l1_base[]; #if defined(KERNEL_BUILD) #if defined(COMPILE_FOR_BRISC) -constexpr uint32_t read_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? BRISC_RD_CMD_BUF : DYNAMIC_NOC_BRISC_RD_CMD_BUF; -constexpr uint32_t write_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? BRISC_WR_CMD_BUF : DYNAMIC_NOC_BRISC_WR_CMD_BUF; -constexpr uint32_t write_reg_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? BRISC_WR_REG_CMD_BUF : DYNAMIC_NOC_BRISC_WR_REG_CMD_BUF; -constexpr uint32_t write_at_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? BRISC_AT_CMD_BUF : DYNAMIC_NOC_BRISC_AT_CMD_BUF; +constexpr uint32_t read_cmd_buf = NOC_MODE == DM_DEDICATED_NOC ? BRISC_RD_CMD_BUF : DYNAMIC_NOC_BRISC_RD_CMD_BUF; +constexpr uint32_t write_cmd_buf = NOC_MODE == DM_DEDICATED_NOC ? BRISC_WR_CMD_BUF : DYNAMIC_NOC_BRISC_WR_CMD_BUF; +constexpr uint32_t write_reg_cmd_buf = NOC_MODE == DM_DEDICATED_NOC ? BRISC_WR_REG_CMD_BUF : DYNAMIC_NOC_BRISC_WR_REG_CMD_BUF; +constexpr uint32_t write_at_cmd_buf = NOC_MODE == DM_DEDICATED_NOC ? BRISC_AT_CMD_BUF : DYNAMIC_NOC_BRISC_AT_CMD_BUF; #elif defined(COMPILE_FOR_NCRISC) -constexpr uint32_t read_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_RD_CMD_BUF : DYNAMIC_NOC_NCRISC_RD_CMD_BUF; -constexpr uint32_t write_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_WR_CMD_BUF : DYNAMIC_NOC_NCRISC_WR_CMD_BUF; -constexpr uint32_t write_reg_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_WR_REG_CMD_BUF : DYNAMIC_NOC_NCRISC_WR_REG_CMD_BUF; -constexpr uint32_t write_at_cmd_buf __attribute__((used)) = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_AT_CMD_BUF : DYNAMIC_NOC_NCRISC_AT_CMD_BUF; +constexpr uint32_t read_cmd_buf = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_RD_CMD_BUF : DYNAMIC_NOC_NCRISC_RD_CMD_BUF; +constexpr uint32_t write_cmd_buf = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_WR_CMD_BUF : DYNAMIC_NOC_NCRISC_WR_CMD_BUF; +constexpr uint32_t write_reg_cmd_buf = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_WR_REG_CMD_BUF : DYNAMIC_NOC_NCRISC_WR_REG_CMD_BUF; +constexpr uint32_t write_at_cmd_buf = NOC_MODE == DM_DEDICATED_NOC ? NCRISC_AT_CMD_BUF : DYNAMIC_NOC_NCRISC_AT_CMD_BUF; #else // use the default cmf buffers for compute/eth -constexpr uint32_t read_cmd_buf __attribute__((used)) = NCRISC_RD_CMD_BUF; -constexpr uint32_t write_cmd_buf __attribute__((used)) = NCRISC_WR_CMD_BUF; -constexpr uint32_t write_reg_cmd_buf __attribute__((used)) = NCRISC_WR_REG_CMD_BUF; -constexpr uint32_t write_at_cmd_buf __attribute__((used)) = NCRISC_AT_CMD_BUF; +constexpr uint32_t read_cmd_buf = NCRISC_RD_CMD_BUF; +constexpr uint32_t write_cmd_buf = NCRISC_WR_CMD_BUF; +constexpr uint32_t write_reg_cmd_buf = NCRISC_WR_REG_CMD_BUF; +constexpr uint32_t write_at_cmd_buf = NCRISC_AT_CMD_BUF; #endif #else // FW build -constexpr uint32_t read_cmd_buf __attribute__((used)) = NCRISC_RD_CMD_BUF; -constexpr uint32_t write_cmd_buf __attribute__((used)) = NCRISC_WR_CMD_BUF; -constexpr uint32_t write_reg_cmd_buf __attribute__((used)) = NCRISC_WR_REG_CMD_BUF; -constexpr uint32_t write_at_cmd_buf __attribute__((used)) = NCRISC_AT_CMD_BUF; +constexpr uint32_t read_cmd_buf = NCRISC_RD_CMD_BUF; +constexpr uint32_t write_cmd_buf = NCRISC_WR_CMD_BUF; +constexpr uint32_t write_reg_cmd_buf = NCRISC_WR_REG_CMD_BUF; +constexpr uint32_t write_at_cmd_buf = NCRISC_AT_CMD_BUF; #endif /** @file */ From fa69b0b81b9f9f452bfc51abca3180942d24446e Mon Sep 17 00:00:00 2001 From: Stanislav Minakov Date: Wed, 9 Oct 2024 16:56:29 -0700 Subject: [PATCH 51/58] #13127: Allow `compute_output_shapes` to use SimpleShape instead of LegacyShape, port some ops to SimpleShape (#13645) * #13127: Prototype of moving some operation to ttnn::SimpleShape * #13127: Port more ops * #13127: Infra support for SimpleShape * #13127: Revert convolution changes * #13127: Refactor to remove code duplication * #13127: Fix rebase issues * #13127: Extract extract_legacy_shape function, make get_physical_shape pure function --- docs/source/ttnn/ttnn/dependencies/tt_lib.rst | 8 +++--- .../unit_tests/gtests/test_ccl_on_galaxy.cpp | 9 ++++++- .../moreh_clip_grad_norm_op.cpp | 6 ++--- .../moreh_clip_grad_norm_op.hpp | 6 ++--- .../op_library/moreh_dot/moreh_dot_op.cpp | 10 +++----- .../op_library/moreh_dot/moreh_dot_op.hpp | 2 +- .../moreh_dot_backward_op.cpp | 2 +- .../moreh_dot_backward_op.hpp | 2 +- .../moreh_layernorm/moreh_layernorm_op.cpp | 25 +++++-------------- .../moreh_layernorm/moreh_layernorm_op.hpp | 2 +- .../moreh_layernorm_backward_op.cpp | 6 ++--- .../moreh_layernorm_backward_op.hpp | 4 +-- .../moreh_matmul/moreh_matmul_op.cpp | 21 ++++++---------- .../moreh_matmul/moreh_matmul_op.hpp | 2 +- ttnn/cpp/ttnn/operation.hpp | 7 +++--- .../ccl/all_gather/device/all_gather_op.cpp | 6 ++--- .../ccl/all_gather/device/all_gather_op.hpp | 2 +- .../device/reduce_scatter_op.cpp | 6 ++--- .../device/reduce_scatter_op.hpp | 2 +- .../bcast/device/bcast_device_operation.cpp | 8 +++--- .../bcast/device/bcast_device_operation.hpp | 2 +- .../concat/device/concat_device_operation.cpp | 6 ++--- .../concat/device/concat_device_operation.hpp | 2 +- .../device/all_gather_matmul_op.cpp | 8 +++--- .../device/all_gather_matmul_op.hpp | 2 +- .../operations/matmul/device/matmul_op.cpp | 15 ++++------- .../operations/matmul/device/matmul_op.hpp | 2 +- .../device/moreh_dot_device_operation.cpp | 10 +++----- .../device/moreh_dot_device_operation.hpp | 2 +- .../moreh_dot_backward_device_operation.hpp | 2 +- .../moreh_group_norm_device_operation.cpp | 16 +++++------- .../moreh_group_norm_device_operation.hpp | 2 +- .../reduction/prod/device/prod_nc_op.cpp | 2 +- .../reduction/prod/device/prod_nc_op.hpp | 2 +- ttnn/cpp/ttnn/run_operation.cpp | 23 +++++++++++++++-- ttnn/cpp/ttnn/tensor/tensor.cpp | 5 ++++ ttnn/cpp/ttnn/tensor/tensor.hpp | 8 ++++++ ttnn/cpp/ttnn/tensor/types.cpp | 25 +++++++++++++++++++ ttnn/cpp/ttnn/tensor/types.hpp | 2 ++ 39 files changed, 153 insertions(+), 119 deletions(-) diff --git a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst index 73ae0517fe9..7a87a746005 100644 --- a/docs/source/ttnn/ttnn/dependencies/tt_lib.rst +++ b/docs/source/ttnn/ttnn/dependencies/tt_lib.rst @@ -34,7 +34,7 @@ New Device Operation struct { void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; }; @@ -48,7 +48,7 @@ New Device Operation with a member int some_member void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; }; @@ -61,7 +61,7 @@ New Device Operation with Optional Input Tensors struct { void validate(const std::vector &input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, @@ -80,7 +80,7 @@ and create_output_tensors with the additional parameter for the output_tensors. struct { void validate_with_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector> create_output_tensors(const std::vector &input_tensors, const std::vector>& output_tensors) const; operation::ProgramWithOptionalOutputTensors create_program(const std::vector& input_tensors, std::vector> &output_tensors) const; diff --git a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp index 027537ae3a5..df3476bd545 100644 --- a/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp +++ b/tests/ttnn/unit_tests/gtests/test_ccl_on_galaxy.cpp @@ -27,7 +27,14 @@ std::vector run_operation( const operation::OptionalTensors& optional_output_tensors = {}) { static_assert(operation::detail::is_device_operation(), "ttnn::run_operation can only dispatch Device Operations!"); // Create output tensor vector by examining the number of output shapes created by the device operation - std::vector outputs(operation::DeviceOperation(devop).compute_output_shapes(input_tensors).size()); + auto output_shapes = operation::DeviceOperation(devop).compute_output_shapes(input_tensors); + size_t output_shapes_size = 0; + if (std::holds_alternative>(output_shapes)) { + output_shapes_size = std::get>(output_shapes).size(); + } else { + output_shapes_size = std::get>(output_shapes).size(); + } + std::vector outputs(output_shapes_size); // Populate the workers of the output tensors, based on the input tensors. This is needed for the async engine. for (int i = 0; i < outputs.size(); i++) { outputs[i] = Tensor(operation::get_workers_for_op_output(std::move(input_tensors), std::move(optional_input_tensors))); diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp index b054fbedff6..20276ea7676 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.cpp @@ -48,7 +48,7 @@ void MorehClipGradNormStep1::validate( check_tensor(tmp_pow_sum, "moreh_clip_grad_norm_step1", "tmp_pow_sum"); }; -std::vector MorehClipGradNormStep1::compute_output_shapes(const std::vector &) const { return {}; } +std::vector MorehClipGradNormStep1::compute_output_shapes(const std::vector &) const { return {}; } std::vector MorehClipGradNormStep1::create_output_tensors(const std::vector &) const { return {}; } @@ -105,7 +105,7 @@ void MorehClipGradNormStep2::validate(const std::vector &input_tensors) check_tensor(total_norm, "moreh_clip_grad_norm_step2", "total_norm"); } -std::vector MorehClipGradNormStep2::compute_output_shapes(const std::vector &) const { return {}; } +std::vector MorehClipGradNormStep2::compute_output_shapes(const std::vector &) const { return {}; } std::vector MorehClipGradNormStep2::create_output_tensors(const std::vector &) const { return {}; } @@ -146,7 +146,7 @@ void MorehClipGradNormStep3::validate( check_tensor(clip_coef_clamped, "moreh_clip_grad_norm_step3", "clip_coef_clamped"); } -std::vector MorehClipGradNormStep3::compute_output_shapes(const std::vector &) const { return {}; } +std::vector MorehClipGradNormStep3::compute_output_shapes(const std::vector &) const { return {}; } std::vector MorehClipGradNormStep3::create_output_tensors(const std::vector &) const { return {}; } diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp index c946befe11d..3e84fee79c3 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_clip_grad_norm/moreh_clip_grad_norm_op.hpp @@ -32,7 +32,7 @@ struct MorehClipGradNormStep1 { void validate( const std::vector &input_tensors, const std::vector> &optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector &) const; + std::vector compute_output_shapes(const std::vector &) const; std::vector create_output_tensors(const std::vector &) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, @@ -49,7 +49,7 @@ struct MorehClipGradNormStep2 { float norm_type; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &) const; + std::vector compute_output_shapes(const std::vector &) const; std::vector create_output_tensors(const std::vector &) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &) const; @@ -64,7 +64,7 @@ struct MorehClipGradNormStep3 { void validate( const std::vector &input_tensors, const std::vector> &optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector &) const; + std::vector compute_output_shapes(const std::vector &) const; std::vector create_output_tensors(const std::vector &) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp index dabf8dd64ba..2b38e2f6bf4 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.cpp @@ -46,13 +46,11 @@ void MorehDot::validate(const std::vector& input_tensors) const { "Operands to matmul need to be allocated in buffers on device!"); } -std::vector MorehDot::compute_output_shapes(const std::vector& input_tensors) const { +std::vector MorehDot::compute_output_shapes(const std::vector& input_tensors) const { const auto& input_tensor = input_tensors.at(0); - auto output_shape = input_tensor.get_legacy_shape(); - auto padding = output_shape.padding(); - output_shape[3] = TILE_WIDTH; - padding[3] = Padding::PadDimension{0, 31}; - return {tt::tt_metal::LegacyShape(output_shape, padding)}; + auto output_shape = input_tensor.get_logical_shape(); + output_shape[3] = 1; + return {output_shape}; } std::vector MorehDot::create_output_tensors(const std::vector& input_tensors) const { diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp index 0be70fedaa7..288b8b07729 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot/moreh_dot_op.hpp @@ -27,7 +27,7 @@ struct MorehDot { const DataType output_dtype; // TODO: Uplift output_dtype as an option for general dot/bmm void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &output_tensors) const; diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.cpp index dfa0fec1c50..7aa118e7a93 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.cpp @@ -63,7 +63,7 @@ void MorehDotBackward::validate( } } -std::vector MorehDotBackward::compute_output_shapes(const std::vector& inputs) const { +std::vector MorehDotBackward::compute_output_shapes(const std::vector& inputs) const { // Inplace return {}; } diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.hpp index 6e073dd5723..c00a1b7b96a 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_dot_backward/moreh_dot_backward_op.hpp @@ -29,7 +29,7 @@ operation::ProgramWithCallbacks moreh_dot_backward_single_core( struct MorehDotBackward { void validate( const std::vector &inputs, const std::vector> &optional_inputs) const; - std::vector compute_output_shapes(const std::vector &inputs) const; + std::vector compute_output_shapes(const std::vector &inputs) const; std::vector create_output_tensors(const std::vector &inputs) const; operation::ProgramWithCallbacks create_program( const std::vector &inputs, diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp index bfac7fe8658..7f09ff62e74 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.cpp @@ -405,39 +405,26 @@ void MorehLayerNorm::validate_with_output_tensors( } } -std::vector MorehLayerNorm::compute_output_shapes(const std::vector& input_tensors) const { +std::vector MorehLayerNorm::compute_output_shapes(const std::vector& input_tensors) const { auto input = input_tensors.at(0); // compute mean_rstd_shape - tt::tt_metal::LegacyShape input_shape = input.get_legacy_shape(); - auto input_shape_without_padding = input_shape.without_padding(); + auto input_shape = input.get_logical_shape(); auto input_rank = input_shape.rank(); auto output_rank = input_rank - normalized_dims; - std::vector output_size_vec; - auto dimensions_pads = std::vector(); + std::vector output_shape_vec; // special case handling if (output_rank == 1) { - output_size_vec.push_back(32); - dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = 31}); + output_shape_vec.push_back(1); } for (uint32_t dim = 0 ; dim < output_rank; dim++) { - auto input_shape_without_padding_size = input_shape_without_padding[dim]; - if (is_hw_dim(dim, output_rank)) { - output_size_vec.push_back(round_up_to_mul32(input_shape_without_padding_size)); - - auto padding_back = output_size_vec[dim] - input_shape_without_padding_size; - dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = padding_back}); - } else { - output_size_vec.push_back(input_shape_without_padding_size); - dimensions_pads.push_back(Padding::PadDimension{.front = 0, .back = 0}); - } + output_shape_vec.push_back(input_shape[dim]); } - const auto padding = Padding(dimensions_pads, Padding::PadValue::Any); - auto mean_rstd_output_shape = tt::tt_metal::LegacyShape(output_size_vec, padding); + ttnn::SimpleShape mean_rstd_output_shape(std::move(output_shape_vec)); return {input_shape, mean_rstd_output_shape, mean_rstd_output_shape}; } diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp index 18edae72fd4..de0111eb97e 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm/moreh_layernorm_op.hpp @@ -34,7 +34,7 @@ struct MorehLayerNorm { const std::vector &input_tensors, const std::vector> &optional_input_tensors, const std::vector> &output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors( const std::vector& input_tensors, const std::vector>& output_tensors) const; operation::ProgramWithCallbacks create_program( diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.cpp index 2f09730b4a6..f0ff2ff97f1 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.cpp @@ -62,10 +62,10 @@ void MorehLayerNormBackwardInputGrad::validate_with_output_tensors( } } -std::vector MorehLayerNormBackwardInputGrad::compute_output_shapes( +std::vector MorehLayerNormBackwardInputGrad::compute_output_shapes( const std::vector& input_tensors) const { auto input = input_tensors.at(0); - auto input_shape = input.get_legacy_shape(); + auto input_shape = input.get_logical_shape(); // The shapes of the input and output are always the same. return {input_shape}; @@ -131,7 +131,7 @@ void MorehLayerNormBackwardGammaBetaGrad::validate_with_output_tensors( } } -std::vector MorehLayerNormBackwardGammaBetaGrad::compute_output_shapes( +std::vector MorehLayerNormBackwardGammaBetaGrad::compute_output_shapes( const std::vector& input_tensors) const { TT_THROW("The compute_output_shapes function in MorehLayerNormBackwardGammaBetaGrad is not implemented."); return {}; diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp index 6e46832f6e7..f66c8f1061c 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_layernorm_backward/moreh_layernorm_backward_op.hpp @@ -30,7 +30,7 @@ struct MorehLayerNormBackwardInputGrad { const std::vector &input_tensors, const std::vector> &optional_input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector& input_tensors, const std::vector>& output_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, @@ -46,7 +46,7 @@ struct MorehLayerNormBackwardGammaBetaGrad { void validate_with_output_tensors( const std::vector &input_tensors, const std::vector>& output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector& input_tensors, const std::vector>& output_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &output_tensors) const; diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp index 7de14906c48..3a491640609 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.cpp @@ -32,18 +32,16 @@ inline bool is_dot_forward(const Tensor& input, const Tensor& other, bool transp return is_1d_tensor(input) && is_1d_tensor(other) && is_same_shape(input, other); } -tt::tt_metal::LegacyShape compute_output_shape( +ttnn::SimpleShape compute_output_shape( const tt::tt_metal::LegacyShape& input_shape, const tt::tt_metal::LegacyShape& other_shape, bool transpose_input, bool transpose_other) { - const auto& input_shape_wo_padding = input_shape.without_padding(); - const auto& other_shape_wo_padding = other_shape.without_padding(); + const auto& logical_input_shape = input_shape.logical_shape(); + const auto& logical_other_shape = other_shape.logical_shape(); - auto h = (transpose_input) ? (input_shape[-1]) : (input_shape[-2]); - auto w = (transpose_other) ? (other_shape[-2]) : (other_shape[-1]); - auto h_wo_padding = (transpose_input) ? (input_shape_wo_padding[-1]) : (input_shape_wo_padding[-2]); - auto w_wo_padding = (transpose_other) ? (other_shape_wo_padding[-2]) : (other_shape_wo_padding[-1]); + auto h = (transpose_input) ? (logical_input_shape[-1]) : (logical_input_shape[-2]); + auto w = (transpose_other) ? (logical_other_shape[-2]) : (logical_other_shape[-1]); std::vector input_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); std::vector other_dim(tt::tt_metal::MAX_NUM_DIMENSIONS, 1); @@ -72,12 +70,7 @@ tt::tt_metal::LegacyShape compute_output_shape( output_dim[output_rank - 2] = h; output_dim[output_rank - 1] = w; - tt::tt_metal::LegacyShape output_shape{output_dim}; - auto padding = output_shape.padding(); - // padding for t logmatrix dims - padding[output_rank - 2] = Padding::PadDimension{0, h - h_wo_padding}; - padding[output_rank - 1] = Padding::PadDimension{0, w - w_wo_padding}; - return {tt::tt_metal::LegacyShape(output_shape, padding)}; + return {ttnn::SimpleShape(std::move(output_dim))}; } } // namespace @@ -159,7 +152,7 @@ operation::ProgramWithCallbacks MorehMatmul::create_program( } // Must be provided in the case where an optional output tensor was not provided -std::vector MorehMatmul::compute_output_shapes( +std::vector MorehMatmul::compute_output_shapes( const std::vector& input_tensors) const { return {compute_output_shape( input_tensors.at(0).get_legacy_shape(), diff --git a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp index 82ae710c93f..5297ee9add6 100644 --- a/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp +++ b/ttnn/cpp/ttnn/deprecated/tt_dnn/op_library/moreh_matmul/moreh_matmul_op.hpp @@ -39,7 +39,7 @@ struct MorehMatmul { const std::vector &input_tensors, const std::vector> &optional_input_tensors, const std::vector> &output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors( const std::vector &input_tensors, const std::vector> &output_tensors) const; operation::ProgramWithCallbacks create_program( diff --git a/ttnn/cpp/ttnn/operation.hpp b/ttnn/cpp/ttnn/operation.hpp index 5f5efc1b85e..e20faf18fdf 100644 --- a/ttnn/cpp/ttnn/operation.hpp +++ b/ttnn/cpp/ttnn/operation.hpp @@ -384,6 +384,7 @@ template struct DeviceOperation final { using storage_t = std::array; using OutputTensors = OutputTensorsT; + using ComputedShapes = std::variant, std::vector>; inline const std::string get_type_name() const { return this->get_type_name_impl_(this->type_erased_storage); } @@ -395,7 +396,7 @@ struct DeviceOperation final { this->type_erased_storage, input_tensors, optional_input_tensors, optional_output_tensors); } - inline const std::vector compute_output_shapes(const Tensors& input_tensors) const { + inline const ComputedShapes compute_output_shapes(const Tensors& input_tensors) const { return this->compute_output_shapes_impl_(this->type_erased_storage, input_tensors); } @@ -544,7 +545,7 @@ struct DeviceOperation final { } }}, compute_output_shapes_impl_{ - [](const storage_t& storage, const Tensors& input_tensors) -> const std::vector { + [](const storage_t& storage, const Tensors& input_tensors) -> const ComputedShapes { const auto& operation = *reinterpret_cast*>(&storage); return operation.compute_output_shapes(input_tensors); }}, @@ -753,7 +754,7 @@ struct DeviceOperation final { const Tensors&, const std::vector>&, const OptionalTensors&); - const std::vector (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&); + const ComputedShapes (*compute_output_shapes_impl_)(const storage_t& value, const Tensors&); const OutputTensors (*create_output_tensors_impl_)(const storage_t& value, const Tensors&, const OptionalTensors&); CacheableProgram (*create_program_impl_)( diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp index 026cf6d0ddb..3461d5f1da7 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.cpp @@ -165,10 +165,10 @@ void AllGather::validate(const std::vector &input_tensors) const { } } -std::vector AllGather::compute_output_shapes(const std::vector &input_tensors) const { - auto shape = input_tensors[0].get_legacy_shape(); +std::vector AllGather::compute_output_shapes(const std::vector &input_tensors) const { + auto shape = input_tensors[0].get_logical_shape(); shape[this->dim] *= this->ring_size; - return std::vector(input_tensors.size(), shape); + return std::vector(input_tensors.size(), shape); } std::vector AllGather::create_output_tensors(const std::vector &input_tensors) const { diff --git a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp index 041561bcc87..607bb80af49 100644 --- a/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/all_gather/device/all_gather_op.hpp @@ -132,7 +132,7 @@ struct AllGather { const ccl::Topology topology; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program(const std::vector& input_tensors, std::vector &output_tensors) const; }; diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp index a573b0ff262..c87d5e35a93 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.cpp @@ -22,13 +22,13 @@ void ReduceScatter::validate(const std::vector& input_tensors) const { } } -std::vector ReduceScatter::compute_output_shapes(const std::vector& input_tensors) const { - auto shape = input_tensors[0].get_legacy_shape(); +std::vector ReduceScatter::compute_output_shapes(const std::vector& input_tensors) const { + auto shape = input_tensors[0].get_logical_shape(); TT_FATAL( shape[this->scatter_dim] % this->ring_size == 0, "The size of the scatter dimension must be a multiple of the ring size"); shape[this->scatter_dim] /= this->ring_size; - return std::vector(input_tensors.size(), shape); + return std::vector(input_tensors.size(), shape); } std::vector ReduceScatter::create_output_tensors(const std::vector& input_tensors) const { diff --git a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp index 752a42020a4..996d3078ca0 100644 --- a/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp +++ b/ttnn/cpp/ttnn/operations/ccl/reduce_scatter/device/reduce_scatter_op.hpp @@ -25,7 +25,7 @@ struct ReduceScatter { const std::optional user_defined_num_buffers_per_channel; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &output_tensors) const; diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.cpp index 0a310ac4b75..d2dd0bc14dd 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.cpp @@ -47,9 +47,9 @@ void EltwiseBinaryBroadcast::validate_with_output_tensors(const std::vector output_shape_required = this->compute_output_shapes(input_tensors); + const std::vector output_shape_required = this->compute_output_shapes(input_tensors); const auto& out_tensor = output_tensors.at(0).value(); - TT_FATAL(out_tensor.get_legacy_shape() == output_shape_required.at(0), "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_legacy_shape()); + TT_FATAL(out_tensor.get_logical_shape() == output_shape_required.at(0), "The input tensors need a shape of {}, however the output tensor is only {}", output_shape_required, out_tensor.get_legacy_shape()); } if (this->in_place) { TT_FATAL(input_tensor_a.memory_config().memory_layout == this->output_mem_config.memory_layout, "Error"); @@ -109,9 +109,9 @@ void EltwiseBinaryBroadcast::validate_with_output_tensors(const std::vector EltwiseBinaryBroadcast::compute_output_shapes(const std::vector &input_tensors) const { +std::vector EltwiseBinaryBroadcast::compute_output_shapes(const std::vector &input_tensors) const { const auto& input_tensor = input_tensors.at(0); - return {input_tensor.get_legacy_shape()}; + return {input_tensor.get_logical_shape()}; } diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.hpp index a7fcb22f395..a2a56b717f3 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/bcast_device_operation.hpp @@ -29,7 +29,7 @@ struct EltwiseBinaryBroadcast { const bool in_place; void validate_with_output_tensors(const std::vector &input_tensors, const std::vector> &output_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors, const std::vector> &output_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &output_tensors) const; diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp index 472b107a4e6..1e3d958c661 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.cpp @@ -61,11 +61,11 @@ void ConcatDeviceOperation::validate(const std::vector &input_tensors) c } } -std::vector ConcatDeviceOperation::compute_output_shapes(const std::vector &input_tensors) const { - tt::tt_metal::LegacyShape shape_out = input_tensors[0].get_legacy_shape(); +std::vector ConcatDeviceOperation::compute_output_shapes(const std::vector &input_tensors) const { + ttnn::SimpleShape shape_out = input_tensors[0].get_logical_shape(); shape_out[this->dim] = 0; for (const Tensor &in_ref : input_tensors) { - tt::tt_metal::LegacyShape curr_shape = in_ref.get_legacy_shape(); + ttnn::SimpleShape curr_shape = in_ref.get_logical_shape(); shape_out[this->dim] += curr_shape[this->dim]; } return {shape_out}; diff --git a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp index 0e5a35500a1..86fe135637b 100644 --- a/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/data_movement/concat/device/concat_device_operation.hpp @@ -15,7 +15,7 @@ struct ConcatDeviceOperation { uint32_t dim; const MemoryConfig output_mem_config; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &output_tensors) const; diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp index 7ec8010cf47..c6be2478fff 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.cpp @@ -54,15 +54,15 @@ void AllGatherMatmul::validate(const std::vector &input_tensors, const s } } -std::vector AllGatherMatmul::compute_output_shapes(const std::vector &input_tensors) const { +std::vector AllGatherMatmul::compute_output_shapes(const std::vector &input_tensors) const { // All Gather shape - tt::tt_metal::LegacyShape all_gather_output_shape = this->all_gather_struct.compute_output_shapes({input_tensors[0]})[0]; - tt::tt_metal::LegacyShape datacopy_output_shape = all_gather_output_shape; + ttnn::SimpleShape all_gather_output_shape = this->all_gather_struct.compute_output_shapes({input_tensors[0]})[0]; + ttnn::SimpleShape datacopy_output_shape = all_gather_output_shape; // Matmul shape - tt::tt_metal::LegacyShape matmul_output_shapes = this->matmul_struct.compute_output_shapes({input_tensors[1], input_tensors[2]})[0]; + ttnn::SimpleShape matmul_output_shapes = this->matmul_struct.compute_output_shapes({input_tensors[1], input_tensors[2]})[0]; return {all_gather_output_shape, matmul_output_shapes, datacopy_output_shape}; } diff --git a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp index 3d57614fefe..6dc88b1086d 100644 --- a/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/experimental/ccl/all_gather_matmul/device/all_gather_matmul_op.hpp @@ -42,7 +42,7 @@ struct AllGatherMatmul { /* General */ void validate(const std::vector &input_tensors, const std::vector>& optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector create_output_tensors(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector& input_tensors, diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp index 4d487336e79..2a6af0f88e8 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.cpp @@ -1298,29 +1298,24 @@ void Matmul::validate( chosen_program_config); } -std::vector Matmul::compute_output_shapes(const std::vector& input_tensors) const { - const tt::tt_metal::LegacyShape& input_shape_a = input_tensors.at(0).get_legacy_shape(); - const tt::tt_metal::LegacyShape& input_shape_b = input_tensors.at(1).get_legacy_shape(); +std::vector Matmul::compute_output_shapes(const std::vector& input_tensors) const { + ttnn::SimpleShape input_shape_a = input_tensors.at(0).get_logical_shape(); + ttnn::SimpleShape input_shape_b = input_tensors.at(1).get_logical_shape(); const uint32_t a_rank = input_shape_a.rank(); const uint32_t b_rank = input_shape_b.rank(); const uint32_t out_rank = std::max(a_rank, b_rank); const uint32_t rank_difference = out_rank - a_rank; - tt::tt_metal::LegacyShape output_shape = (b_rank > a_rank) ? input_shape_b : input_shape_a; - auto dimensions_pads = std::vector(); + ttnn::SimpleShape output_shape = (b_rank > a_rank) ? input_shape_b : input_shape_a; for (auto index = 0; index < rank_difference; index++) { TT_FATAL(input_shape_b[index] == 1, "When in1 rank greater than in0 rank front dimensions need to be 1"); output_shape[index] = input_shape_b[index]; - dimensions_pads.push_back(input_shape_b.padding()[index]); } for (auto index = 0; index < a_rank - 1; index++) { output_shape[rank_difference + index] = input_shape_a[index]; - dimensions_pads.push_back(input_shape_a.padding()[index]); } output_shape[-1] = input_shape_b[-1]; - dimensions_pads.push_back(input_shape_b.padding()[b_rank - 1]); - const auto padding = Padding(dimensions_pads, Padding::PadValue::Any); - return {tt::tt_metal::LegacyShape(output_shape, padding)}; + return {std::move(output_shape)}; } std::vector Matmul::create_output_tensors(const std::vector& input_tensors) const { diff --git a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp index b08678d871e..32eb87cd13a 100644 --- a/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp +++ b/ttnn/cpp/ttnn/operations/matmul/device/matmul_op.hpp @@ -169,7 +169,7 @@ struct Matmul { void validate( const std::vector &input_tensors, const std::vector> &optional_input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; + std::vector compute_output_shapes(const std::vector &input_tensors) const; std::vector compute_output_shapes_dram_sharded( const std::vector &input_tensors, uint32_t N_unpadded) const; std::vector create_output_tensors(const std::vector &input_tensors) const; diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp index 151f50ae67f..ac997b24f68 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.cpp @@ -50,14 +50,12 @@ void MorehDotOperation::validate_on_program_cache_hit( MorehDotOperation::shape_return_value_t MorehDotOperation::compute_output_shapes( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { if (tensor_args.output.has_value()) { - return tensor_args.output.value().get_shape(); + return tensor_args.output.value().get_logical_shape(); } const auto& input = tensor_args.input_a; - auto output_shape = input.get_shape().value; - auto padding = output_shape.padding(); - output_shape[3] = tt::constants::TILE_WIDTH; - padding[3] = Padding::PadDimension{0, 31}; - return ttnn::Shape{tt::tt_metal::LegacyShape(output_shape, padding)}; + auto output_shape = input.get_logical_shape(); + output_shape[3] = 1; + return ttnn::SimpleShape{std::move(output_shape)}; } MorehDotOperation::tensor_return_value_t MorehDotOperation::create_output_tensors( diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.hpp index 7c02317988e..727b282a362 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot/device/moreh_dot_device_operation.hpp @@ -23,7 +23,7 @@ struct MorehDotOperation { const std::optional& output; }; - using shape_return_value_t = Shape; + using shape_return_value_t = SimpleShape; using tensor_return_value_t = Tensor; struct SingleCore { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp index d7185780040..34ae223deb8 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_dot_op_backward/device/moreh_dot_backward_device_operation.hpp @@ -26,7 +26,7 @@ struct MorehDotBackwardOperation { const std::vector> output_tensors; }; - using shape_return_value_t = std::vector>; + using shape_return_value_t = std::vector>; using tensor_return_value_t = std::vector>; struct SingleCore { diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp index a4a6a19b365..945cfa6aafa 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.cpp @@ -84,20 +84,16 @@ MorehGroupNormOperation::shape_return_value_t MorehGroupNormOperation::compute_o const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { using namespace tt::constants; // mean, rstd (1, 1, N, num_groups) - const auto output_shape = tensor_args.input.get_shape(); - const auto N = output_shape.value[0]; + const auto output_shape = tensor_args.input.get_logical_shape(); + const auto N = output_shape[0]; const auto num_groups = operation_attributes.num_groups; - const std::vector mean_rstd_origin_shape{ + std::vector mean_rstd_origin_shape{ 1, 1, - TILE_HEIGHT * ((N + TILE_HEIGHT - 1) / TILE_HEIGHT), - TILE_WIDTH * ((num_groups + TILE_WIDTH - 1) / TILE_WIDTH)}; + N, + num_groups}; - auto mean_rstd_padding = output_shape.value.padding(); - mean_rstd_padding[2] = Padding::PadDimension{0, TILE_HEIGHT - (N % TILE_HEIGHT)}; - mean_rstd_padding[3] = Padding::PadDimension{0, TILE_WIDTH - (num_groups % TILE_WIDTH)}; - - Shape mean_rstd_shape = Shape(tt::tt_metal::LegacyShape(mean_rstd_origin_shape, mean_rstd_padding)); + SimpleShape mean_rstd_shape(std::move(mean_rstd_origin_shape)); return {output_shape, mean_rstd_shape, mean_rstd_shape}; } diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.hpp b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.hpp index 338cd28123a..480aac7cf01 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_group_norm/device/moreh_group_norm_device_operation.hpp @@ -27,7 +27,7 @@ struct MorehGroupNormOperation { const std::optional rstd; }; - using shape_return_value_t = std::vector>; + using shape_return_value_t = std::vector>; using tensor_return_value_t = std::vector>; struct MorehGroupNormFactory { diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp index ea9e217f356..cafcb99c1be 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.cpp @@ -44,7 +44,7 @@ std::vector Prod::create_output_tensors(const std::vector& input return {}; } -std::vector Prod::compute_output_shapes(const std::vector& inputs) const { +std::vector Prod::compute_output_shapes(const std::vector& inputs) const { // Inplace return {}; diff --git a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp index 7d92526127b..5552f120a4d 100644 --- a/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp +++ b/ttnn/cpp/ttnn/operations/reduction/prod/device/prod_nc_op.hpp @@ -23,7 +23,7 @@ using namespace tt_metal; struct Prod { int64_t dim; void validate(const std::vector &inputs) const; - std::vector compute_output_shapes(const std::vector &inputs) const; + std::vector compute_output_shapes(const std::vector &inputs) const; std::vector create_output_tensors(const std::vector &inputs) const; operation::ProgramWithCallbacks create_program( const std::vector &inputs, std::vector &outputs) const; diff --git a/ttnn/cpp/ttnn/run_operation.cpp b/ttnn/cpp/ttnn/run_operation.cpp index cfa2be30c4e..78366e52fef 100644 --- a/ttnn/cpp/ttnn/run_operation.cpp +++ b/ttnn/cpp/ttnn/run_operation.cpp @@ -302,6 +302,21 @@ template OptionalTensors run_without_autoformat( const OptionalTensors& optional_output_tensors, uint8_t cq_id); +std::vector extract_legacy_shapes( + const std::variant, std::vector>&& shapes, const std::function& layout_provider) { + if (std::holds_alternative>(shapes)) { + return std::get>(std::move(shapes)); + } + const auto& simple_shapes = std::get>(shapes); + std::vector legacy_shapes; + legacy_shapes.reserve(simple_shapes.size()); + for (size_t idx = 0; idx < simple_shapes.size(); idx++) { + auto layout = layout_provider(idx); + legacy_shapes.emplace_back(simple_shapes[idx].as_vector(), get_physical_shape(simple_shapes[idx], layout).as_vector()); + } + return legacy_shapes; +} + // To be deprecated/removed in favor of new implementation where ops specifically request how to format inputs/outputss Tensors run_with_autoformat( DeviceOperation&& operation, @@ -314,7 +329,9 @@ Tensors run_with_autoformat( using ttnn::operations::experimental::auto_format::AutoFormat; ZoneScoped; Device* device = detail::get_device(input_tensors, optional_input_tensors); - auto output_shapes = operation.compute_output_shapes(input_tensors); + auto output_shapes = extract_legacy_shapes(operation.compute_output_shapes(input_tensors), [](size_t) { + return Layout::TILE; + }); Tensors formatted_input_tensors; formatted_input_tensors.reserve(input_tensors.size()); @@ -372,7 +389,9 @@ Tensors run_with_autoformat( using ttnn::operations::experimental::auto_format::AutoFormat; ZoneScoped; Device* device = detail::get_device(input_tensors, optional_input_tensors); - auto output_shapes = operation.compute_output_shapes(input_tensors); + auto output_shapes = extract_legacy_shapes(operation.compute_output_shapes(input_tensors), [&](size_t idx) { + return output_layouts[idx]; + }); TT_ASSERT(input_tensors.size() == input_formatting.size()); TT_ASSERT(optional_input_tensors.size() == optional_input_formatting.size()); diff --git a/ttnn/cpp/ttnn/tensor/tensor.cpp b/ttnn/cpp/ttnn/tensor/tensor.cpp index 4f0fe5d95e2..642b2acec64 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.cpp +++ b/ttnn/cpp/ttnn/tensor/tensor.cpp @@ -696,6 +696,11 @@ Tensor create_device_tensor( } } +Tensor create_device_tensor( + const ttnn::SimpleShape& logical_shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional& tile) { + return create_device_tensor(logical_shape, get_physical_shape(logical_shape, layout, tile), data_type, layout, device, memory_config, tile); +} + Tensor create_device_tensor( const ttnn::Shape& shape, DataType data_type, Layout layout, Device* device, const MemoryConfig& memory_config, const std::optional& tile) { return create_device_tensor(shape.logical_shape(), shape.padded_shape(), data_type, layout, device, memory_config, tile); diff --git a/ttnn/cpp/ttnn/tensor/tensor.hpp b/ttnn/cpp/ttnn/tensor/tensor.hpp index e23832be836..4f201073097 100644 --- a/ttnn/cpp/ttnn/tensor/tensor.hpp +++ b/ttnn/cpp/ttnn/tensor/tensor.hpp @@ -293,6 +293,14 @@ struct Tensor { } }; +Tensor create_device_tensor( + const ttnn::SimpleShape &logical_shape, + DataType dtype, + Layout layout, + Device *device, + const MemoryConfig &memory_config = {.memory_layout = tt::tt_metal::TensorMemoryLayout::INTERLEAVED}, + const std::optional& tile = std::nullopt); + Tensor create_device_tensor( const ttnn::SimpleShape &logical_shape, const ttnn::SimpleShape &padded_shape, diff --git a/ttnn/cpp/ttnn/tensor/types.cpp b/ttnn/cpp/ttnn/tensor/types.cpp index ebf3001f5b8..7814c9d2a18 100644 --- a/ttnn/cpp/ttnn/tensor/types.cpp +++ b/ttnn/cpp/ttnn/tensor/types.cpp @@ -5,6 +5,31 @@ #include #include "ttnn/tensor/types.hpp" +namespace ttnn { + +SimpleShape get_physical_shape(const SimpleShape& logical_shape, Layout layout, const std::optional& tile) { + SimpleShape physical_shape = logical_shape; + if (layout == Layout::TILE) { + auto tile_height = tt::constants::TILE_HEIGHT; + auto tile_width = tt::constants::TILE_WIDTH; + if (tile.has_value()) { + auto tile_shape = tile.value().get_tile_shape(); + tile_height = tile_shape[0]; + tile_width = tile_shape[1]; + } + auto rank = physical_shape.rank(); + if (rank >= 1) { + physical_shape[rank - 1] = (physical_shape[rank - 1] + tile_width - 1) / tile_width * tile_width; + if (rank >= 2) { + physical_shape[rank - 2] = (physical_shape[rank - 2] + tile_height - 1) / tile_height * tile_height; + } + } + } + return physical_shape; +} + +} + namespace tt { namespace tt_metal { diff --git a/ttnn/cpp/ttnn/tensor/types.hpp b/ttnn/cpp/ttnn/tensor/types.hpp index baffe41d56c..2004f4fb19d 100644 --- a/ttnn/cpp/ttnn/tensor/types.hpp +++ b/ttnn/cpp/ttnn/tensor/types.hpp @@ -69,6 +69,8 @@ class SimpleShape { std::vector value; }; +SimpleShape get_physical_shape(const SimpleShape& logical_shape, Layout layout, const std::optional& tile = std::nullopt); + } // namespace ttnn inline std::ostream &operator<<(std::ostream &os, const ttnn::SimpleShape &shape) { From 76c7789c4096ddfcc52b8b278e76e2c0e1e0bd18 Mon Sep 17 00:00:00 2001 From: Kalaivani Baskar <156762498+KalaivaniMCW@users.noreply.github.com> Date: Thu, 10 Oct 2024 08:39:08 +0530 Subject: [PATCH 52/58] #13408: Pytorch tracing sweeps - eltwise (#13437) #13408: Pytorch sweeps set 1 --- .github/workflows/ttnn-run-sweeps.yaml | 18 + .../eltwise/binary/add/add_all_pytorch2.py | 593 ++++++++++++++++++ .../eltwise/binary/eq/eq_scalar_pytorch2.py | 89 +++ .../floor_divide/floor_divide_pytorch2.py | 97 +++ .../eltwise/binary/gt/gt_scalar_pytorch2.py | 80 +++ .../eltwise/binary/le/le_tensor_pytorch2.py | 88 +++ .../binary/multiply/mul_tensor_pytorch2.py | 472 ++++++++++++++ .../remainder/remainder_scalar_pytorch2.py | 78 +++ .../sweeps/eltwise/unary/abs/abs_pytorch2.py | 76 +++ .../unary/bitwise/bitwise_not_pytorch2.py | 82 +++ .../eltwise/unary/ceil/ceil_pytorch2.py | 81 +++ .../sweeps/eltwise/unary/cos/cos_pytorch2.py | 73 +++ .../sweeps/eltwise/unary/elu/elu_pytorch2.py | 77 +++ .../sweeps/eltwise/unary/exp/exp_pytorch2.py | 87 +++ .../eltwise/unary/floor/floor_pytorch2.py | 82 +++ .../eltwise/unary/gelu/gelu_pytorch2.py | 129 ++++ .../unary/hardsigmoid/hardsigmoid_pytorch2.py | 89 +++ .../unary/leaky_relu/leaky_relu_pytorch2.py | 95 +++ .../sweeps/eltwise/unary/log/log_pytorch2.py | 73 +++ .../unit_testing/misc/test_unary_ops_ttnn.py | 25 +- .../{ => eltwise}/test_activation.py | 38 +- .../operations/eltwise/unary/unary_pybind.hpp | 2 +- ttnn/ttnn/operations/unary.py | 22 +- 23 files changed, 2515 insertions(+), 31 deletions(-) create mode 100644 tests/sweep_framework/sweeps/eltwise/binary/add/add_all_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/binary/eq/eq_scalar_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/binary/floor_divide/floor_divide_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/binary/gt/gt_scalar_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/binary/le/le_tensor_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/binary/multiply/mul_tensor_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_scalar_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/abs/abs_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/bitwise/bitwise_not_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/ceil/ceil_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/cos/cos_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/elu/elu_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/exp/exp_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/floor/floor_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/gelu/gelu_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/hardsigmoid/hardsigmoid_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/leaky_relu/leaky_relu_pytorch2.py create mode 100644 tests/sweep_framework/sweeps/eltwise/unary/log/log_pytorch2.py rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_activation.py (88%) diff --git a/.github/workflows/ttnn-run-sweeps.yaml b/.github/workflows/ttnn-run-sweeps.yaml index e15b7880d44..d234535b80f 100644 --- a/.github/workflows/ttnn-run-sweeps.yaml +++ b/.github/workflows/ttnn-run-sweeps.yaml @@ -13,10 +13,15 @@ on: - add - ccl.line_all_gather - ccl.all_gather_n300 + - eltwise.unary.abs.abs_pytorch2 - eltwise.unary.relu.relu - eltwise.unary.relu.relu_pytorch2 - eltwise.unary.gelu.gelu + - eltwise.unary.gelu.gelu_pytorch2 + - eltwise.unary.hardsigmoid.hardsigmoid_pytorch2 + - eltwise.unary.leaky_relu.leaky_relu_pytorch2 - eltwise.unary.cos.cos + - eltwise.unary.cos.cos_pytorch2 - eltwise.unary.sin.sin - eltwise.unary.sin.sin_pytorch2 - eltwise.unary.tril.tril_pytorch2 @@ -29,12 +34,16 @@ on: - eltwise.unary.rdiv.rdiv - eltwise.unary.frac.frac - eltwise.unary.ceil.ceil + - eltwise.unary.ceil.ceil_pytorch2 - eltwise.unary.trunc.trunc - eltwise.unary.floor.floor + - eltwise.unary.floor.floor_pytorch2 - eltwise.unary.clone.clone - eltwise.unary.elu.elu + - eltwise.unary.elu.elu_pytorch2 - eltwise.unary.erfc.erfc - eltwise.unary.exp.exp + - eltwise.unary.exp.exp_pytorch2 - eltwise.unary.exp2.exp2 - eltwise.unary.expm1.expm1 - eltwise.unary.tanh.tanh @@ -44,12 +53,14 @@ on: - eltwise.unary.deg2rad.deg2rad - eltwise.unary.relu6.relu6 - eltwise.unary.log.log + - eltwise.unary.log.log_pytorch2 - eltwise.unary.log1p.log1p - eltwise.unary.log2.log2 - eltwise.unary.log10.log10 - eltwise.unary.bitwise.bitwise_and - eltwise.unary.bitwise.bitwise_left_shift - eltwise.unary.bitwise.bitwise_not + - eltwise.unary.bitwise.bitwise_not_pytorch2 - eltwise.unary.bitwise.bitwise_or - eltwise.unary.bitwise.bitwise_right_shift - eltwise.unary.bitwise.bitwise_xor @@ -99,8 +110,10 @@ on: - eltwise.unary.isnan - eltwise.unary.isneginf - eltwise.unary.isposinf + - eltwise.binary.add.add_all_pytorch2 - eltwise.binary.subtract.subtract - eltwise.binary.multiply.multiply + - eltwise.binary.multiply.mul_tensor_pytorch2 - eltwise.binary.div.div - eltwise.binary.div_no_nan.div_no_nan - eltwise.binary.logical_or.logical_or_ @@ -115,9 +128,14 @@ on: - eltwise.binary.remainder.remainder - eltwise.binary.squared_difference.squared_difference - eltwise.binary.squared_difference_output.squared_difference_output + - eltwise.binary.remainder.remainder_scalar_pytorch2 - eltwise.binary.bcast.bcast_h_sharded - eltwise.binary.bcast.bcast + - eltwise.binary.eq.eq_scalar_pytorch2 + - eltwise.binary.gt.gt_scalar_pytorch2 + - eltwise.binary.le.le_tensor_pytorch2 - eltwise.binary.fmod.fmod + - eltwise.binary.floor_divide.floor_divide_pytorch2 - eltwise.binary.logaddexp.logaddexp - eltwise.binary.ldexp.ldexp - eltwise.binary.hypot.hypot diff --git a/tests/sweep_framework/sweeps/eltwise/binary/add/add_all_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/binary/add/add_all_pytorch2.py new file mode 100644 index 00000000000..408f0c117fa --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/binary/add/add_all_pytorch2.py @@ -0,0 +1,593 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + {"self": [0, 1], "other": [0, 1]}, + {"self": [0], "other": [0]}, + {"self": [1, 1, 1024], "other": [1, 1, 1024]}, + {"self": [1, 1, 16, 32], "other": [1, 1, 16, 32]}, + {"self": [1, 1, 3072], "other": [1, 1, 3072]}, + {"self": [1, 1, 4096], "other": [1, 1, 4096]}, + {"self": [1, 1, 512], "other": [1, 1, 512]}, + {"self": [1, 1, 7, 64], "other": [1, 1, 7, 64]}, + {"self": [1, 1, 768], "other": [1, 1, 768]}, + {"self": [1, 1, 768], "other": [1, 768]}, + {"self": [1, 10, 1024], "other": [1, 10, 1024]}, + {"self": [1, 10, 512], "other": [1, 10, 512]}, + {"self": [1, 10, 768], "other": [1, 10, 768]}, + {"self": [1, 1008, 7, 7], "other": [1, 1008, 7, 7]}, + {"self": [1, 1024, 10, 10], "other": [1, 1024, 10, 10]}, + {"self": [1, 1024, 14, 14], "other": [1, 1024, 14, 14]}, + {"self": [1, 1024, 16, 16], "other": [1, 1024, 16, 16]}, + {"self": [1, 1024, 160], "other": [1, 1024, 160]}, + {"self": [1, 1024, 256], "other": [256]}, + {"self": [1, 1024, 45, 80], "other": [1, 1024, 1, 1]}, + {"self": [1, 1024, 45, 80], "other": [1, 1024, 45, 80]}, + {"self": [1, 1024, 50, 68], "other": [1, 1024, 1, 1]}, + {"self": [1, 1024, 50, 68], "other": [1, 1024, 50, 68]}, + {"self": [1, 1024, 640], "other": [1, 1024, 640]}, + {"self": [1, 1024, 7, 7], "other": [1, 1024, 7, 7]}, + {"self": [1, 1024], "other": [1, 1024]}, + {"self": [1, 104, 28, 28], "other": [1, 104, 28, 28]}, + {"self": [1, 1056, 48, 48], "other": [1, 1056, 48, 48]}, + {"self": [1, 112, 14, 14], "other": [1, 112, 14, 14]}, + {"self": [1, 112, 15, 15], "other": [1, 112, 15, 15]}, + {"self": [1, 112, 20, 20], "other": [1, 112, 20, 20]}, + {"self": [1, 112, 24, 24], "other": [1, 112, 24, 24]}, + {"self": [1, 12, 1, 10], "other": [1, 1, 1, 10]}, + {"self": [1, 12, 1, 10], "other": [1, 12, 1, 10]}, + {"self": [1, 12, 1, 1], "other": [1, 1, 1, 1]}, + {"self": [1, 12, 1, 1], "other": [1, 12, 1, 1]}, + {"self": [1, 12, 1, 24], "other": [1, 1, 1, 24]}, + {"self": [1, 12, 1, 2], "other": [1, 1, 1, 2]}, + {"self": [1, 12, 1, 2], "other": [1, 12, 1, 2]}, + {"self": [1, 12, 1, 46], "other": [1, 1, 1, 46]}, + {"self": [1, 12, 10, 10], "other": [1, 1, 1, 10]}, + {"self": [1, 12, 10, 10], "other": [1, 12, 10, 10]}, + {"self": [1, 12, 12, 12], "other": [1, 1, 1, 12]}, + {"self": [1, 12, 128], "other": [1, 12, 128]}, + {"self": [1, 12, 14, 14], "other": [1, 1, 1, 14]}, + {"self": [1, 12, 197, 197], "other": [1, 12, 197, 197]}, + {"self": [1, 12, 201, 201], "other": [1, 1, 1, 201]}, + {"self": [1, 12, 24, 24], "other": [1, 1, 24, 24]}, + {"self": [1, 12, 25, 25], "other": [1, 1, 1, 25]}, + {"self": [1, 12, 3072], "other": [1, 12, 3072]}, + {"self": [1, 12, 45, 45], "other": [1, 1, 45, 45]}, + {"self": [1, 12, 7, 7], "other": [1, 1, 1, 7]}, + {"self": [1, 12, 768], "other": [1, 12, 768]}, + {"self": [1, 12, 9, 9], "other": [1, 1, 1, 9]}, + {"self": [1, 120, 17, 17], "other": [1, 120, 17, 17]}, + {"self": [1, 120, 28, 28], "other": [1, 120, 28, 28]}, + {"self": [1, 1200, 320], "other": [1, 1200, 320]}, + {"self": [1, 1232, 14, 14], "other": [1, 1232, 14, 14]}, + {"self": [1, 128, 100, 136], "other": [1, 128, 1, 1]}, + {"self": [1, 128, 128, 128], "other": [1, 128, 128, 128]}, + {"self": [1, 128, 1536], "other": [1, 128, 1536]}, + {"self": [1, 128, 180, 320], "other": [1, 128, 1, 1]}, + {"self": [1, 128, 200, 272], "other": [1, 128, 1, 1]}, + {"self": [1, 128, 28, 28], "other": [1, 128, 28, 28]}, + {"self": [1, 128, 56, 56], "other": [1, 128, 56, 56]}, + {"self": [1, 128, 75, 75], "other": [1, 128, 75, 75]}, + {"self": [1, 128, 90, 160], "other": [1, 128, 1, 1]}, + {"self": [1, 1280, 16, 16], "other": [1, 1280, 1, 1]}, + {"self": [1, 1280, 16, 16], "other": [1, 1280, 16, 16]}, + {"self": [1, 1280, 8, 8], "other": [1, 1280, 1, 1]}, + {"self": [1, 1280, 8, 8], "other": [1, 1280, 8, 8]}, + {"self": [1, 1344, 14, 14], "other": [1, 1344, 14, 14]}, + {"self": [1, 136, 19, 19], "other": [1, 136, 19, 19]}, + {"self": [1, 1370, 1280], "other": [1, 1370, 1280]}, + {"self": [1, 1392, 14, 14], "other": [1, 1392, 14, 14]}, + {"self": [1, 14, 128], "other": [1, 14, 128]}, + {"self": [1, 14, 14, 384], "other": [1, 14, 14, 384]}, + {"self": [1, 14, 14, 512], "other": [1, 14, 14, 512]}, + {"self": [1, 14, 3072], "other": [1, 14, 3072]}, + {"self": [1, 14, 768], "other": [1, 14, 768]}, + {"self": [1, 144, 28, 28], "other": [1, 144, 28, 28]}, + {"self": [1, 144, 7, 7], "other": [1, 144, 7, 7]}, + {"self": [1, 1445, 192], "other": [1, 1445, 192]}, + {"self": [1, 15, 1024], "other": [1, 15, 1024]}, + {"self": [1, 15, 512], "other": [1, 15, 512]}, + {"self": [1, 1500, 768], "other": [1, 1500, 768]}, + {"self": [1, 1500, 768], "other": [1500, 768]}, + {"self": [1, 1512, 7, 7], "other": [1, 1512, 7, 7]}, + {"self": [1, 16, 1, 10], "other": [1, 1, 1, 10]}, + {"self": [1, 16, 1, 10], "other": [1, 16, 1, 10]}, + {"self": [1, 16, 1, 1], "other": [1, 1, 1, 1]}, + {"self": [1, 16, 1, 1], "other": [1, 16, 1, 1]}, + {"self": [1, 16, 1, 2], "other": [1, 1, 1, 2]}, + {"self": [1, 16, 1, 2], "other": [1, 16, 1, 2]}, + {"self": [1, 16, 1, 60], "other": [1, 1, 1, 60]}, + {"self": [1, 16, 1, 6], "other": [1, 1, 1, 6]}, + {"self": [1, 16, 10, 10], "other": [1, 1, 1, 10]}, + {"self": [1, 16, 10, 10], "other": [1, 16, 10, 10]}, + {"self": [1, 16, 112, 112], "other": [1, 16, 112, 112]}, + {"self": [1, 16, 16, 384], "other": [1, 16, 16, 384]}, + {"self": [1, 16, 16, 512], "other": [1, 16, 16, 512]}, + {"self": [1, 16, 160, 160], "other": [1, 16, 160, 160]}, + {"self": [1, 16, 19, 19], "other": [1, 1, 19, 19]}, + {"self": [1, 16, 197, 197], "other": [1, 16, 197, 197]}, + {"self": [1, 16, 256, 256], "other": [1, 1, 1, 256]}, + {"self": [1, 16, 5, 5], "other": [1, 1, 1, 5]}, + {"self": [1, 16, 59, 59], "other": [1, 1, 59, 59]}, + {"self": [1, 16, 6, 49, 49], "other": [1, 16, 1, 49, 49]}, + {"self": [1, 16, 6, 64, 64], "other": [1, 16, 1, 64, 64]}, + {"self": [1, 16, 768], "other": [1, 16, 768]}, + {"self": [1, 16, 8, 49, 49], "other": [1, 16, 1, 49, 49]}, + {"self": [1, 16, 8, 64, 64], "other": [1, 16, 1, 64, 64]}, + {"self": [1, 16, 9, 9], "other": [1, 1, 1, 9]}, + {"self": [1, 160, 14, 14], "other": [1, 160, 14, 14]}, + {"self": [1, 160, 24, 24], "other": [1, 160, 24, 24]}, + {"self": [1, 160, 7, 7], "other": [1, 160, 7, 7]}, + {"self": [1, 16384, 256], "other": [256]}, + {"self": [1, 16384, 32], "other": [1, 16384, 32]}, + {"self": [1, 168, 28, 28], "other": [1, 168, 28, 28]}, + {"self": [1, 18, 56, 56], "other": [1, 18, 56, 56]}, + {"self": [1, 19, 1024], "other": [1, 19, 1024]}, + {"self": [1, 192, 28, 28], "other": [1, 192, 28, 28]}, + {"self": [1, 192, 32, 42], "other": [1, 192, 32, 42]}, + {"self": [1, 192, 7, 7], "other": [1, 192, 7, 7]}, + {"self": [1, 192, 8, 8], "other": [1, 192, 8, 8]}, + {"self": [1, 1920, 7, 7], "other": [1, 1920, 7, 7]}, + {"self": [1, 19200, 64], "other": [1, 19200, 64]}, + {"self": [1, 193, 768], "other": [1, 193, 768]}, + {"self": [1, 196, 768], "other": [1, 196, 768]}, + {"self": [1, 197, 1024], "other": [1, 197, 1024]}, + {"self": [1, 197, 768], "other": [1, 197, 768]}, + {"self": [1, 201, 768], "other": [1, 201, 768]}, + {"self": [1, 2016, 7, 7], "other": [1, 2016, 7, 7]}, + {"self": [1, 2048, 23, 40], "other": [1, 2048, 1, 1]}, + {"self": [1, 2048, 23, 40], "other": [1, 2048, 23, 40]}, + {"self": [1, 2048, 25, 34], "other": [1, 2048, 1, 1]}, + {"self": [1, 2048, 25, 34], "other": [1, 2048, 25, 34]}, + {"self": [1, 2048, 7, 7], "other": [1, 2048, 7, 7]}, + {"self": [1, 2048, 768], "other": [1, 2048, 768]}, + {"self": [1, 2048, 768], "other": [2048, 768]}, + {"self": [1, 208, 14, 14], "other": [1, 208, 14, 14]}, + {"self": [1, 208, 9, 9], "other": [1, 208, 9, 9]}, + {"self": [1, 216, 28, 28], "other": [1, 216, 28, 28]}, + {"self": [1, 224, 56, 56], "other": [1, 224, 56, 56]}, + {"self": [1, 232, 10, 10], "other": [1, 232, 10, 10]}, + {"self": [1, 232, 56, 56], "other": [1, 232, 56, 56]}, + {"self": [1, 24, 28, 28], "other": [1, 24, 28, 28]}, + {"self": [1, 24, 49, 49], "other": [1, 24, 49, 49]}, + {"self": [1, 24, 56, 56], "other": [1, 24, 56, 56]}, + {"self": [1, 24, 60, 60], "other": [1, 24, 60, 60]}, + {"self": [1, 24, 64, 64], "other": [1, 24, 64, 64]}, + {"self": [1, 24, 65, 65], "other": [1, 24, 65, 65]}, + {"self": [1, 24, 768], "other": [1, 24, 768]}, + {"self": [1, 24, 80, 80], "other": [1, 24, 80, 80]}, + {"self": [1, 240, 28, 28], "other": [1, 240, 28, 28]}, + {"self": [1, 25, 768], "other": [1, 25, 768]}, + {"self": [1, 2520, 7, 7], "other": [1, 2520, 7, 7]}, + {"self": [1, 256, 100, 136], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 100, 136], "other": [1, 256, 100, 136]}, + {"self": [1, 256, 1024], "other": [1, 256, 1024]}, + {"self": [1, 256, 128, 128], "other": [1, 256, 128, 128]}, + {"self": [1, 256, 1280], "other": [1, 256, 1280]}, + {"self": [1, 256, 14, 14], "other": [1, 256, 14, 14]}, + {"self": [1, 256, 180, 320], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 180, 320], "other": [1, 256, 180, 320]}, + {"self": [1, 256, 200, 272], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 200, 272], "other": [1, 256, 200, 272]}, + {"self": [1, 256, 256], "other": [1, 256, 256]}, + {"self": [1, 256, 256], "other": [256]}, + {"self": [1, 256, 28, 28], "other": [1, 256, 28, 28]}, + {"self": [1, 256, 38, 38], "other": [1, 256, 38, 38]}, + {"self": [1, 256, 384], "other": [1, 256, 384]}, + {"self": [1, 256, 45, 80], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 50, 68], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 50, 68], "other": [1, 256, 50, 68]}, + {"self": [1, 256, 512], "other": [1, 256, 512]}, + {"self": [1, 256, 56, 56], "other": [1, 256, 56, 56]}, + {"self": [1, 256, 64, 64], "other": [1, 256, 64, 64]}, + {"self": [1, 256, 75, 75], "other": [1, 256, 75, 75]}, + {"self": [1, 256, 90, 160], "other": [1, 256, 1, 1]}, + {"self": [1, 272, 12, 12], "other": [1, 272, 12, 12]}, + {"self": [1, 28, 28, 192], "other": [1, 28, 28, 192]}, + {"self": [1, 28, 28, 256], "other": [1, 28, 28, 256]}, + {"self": [1, 288, 14, 14], "other": [1, 288, 14, 14]}, + {"self": [1, 2904, 24, 24], "other": [1, 2904, 24, 24]}, + {"self": [1, 3, 16, 16, 2], "other": [1, 3, 16, 16, 2]}, + {"self": [1, 3, 300, 300], "other": [1, 3, 300, 300]}, + {"self": [1, 3, 32, 32, 2], "other": [1, 3, 32, 32, 2]}, + {"self": [1, 3, 320, 320], "other": [1, 3, 320, 320]}, + {"self": [1, 3, 64, 64, 2], "other": [1, 3, 64, 64, 2]}, + {"self": [1, 3, 800, 1066], "other": [1, 3, 800, 1066]}, + {"self": [1, 300, 512], "other": [1, 300, 512]}, + {"self": [1, 3024, 7, 7], "other": [1, 3024, 7, 7]}, + {"self": [1, 32, 1536], "other": [1, 32, 1536]}, + {"self": [1, 32, 24576], "other": [1, 32, 24576]}, + {"self": [1, 32, 28, 28], "other": [1, 32, 28, 28]}, + {"self": [1, 32, 32, 192], "other": [1, 32, 32, 192]}, + {"self": [1, 32, 32, 256], "other": [1, 32, 32, 256]}, + {"self": [1, 32, 49, 49], "other": [1, 32, 49, 49]}, + {"self": [1, 32, 56, 56], "other": [1, 32, 56, 56]}, + {"self": [1, 32, 64, 64], "other": [1, 32, 64, 64]}, + {"self": [1, 32, 75, 75], "other": [1, 32, 75, 75]}, + {"self": [1, 32, 95, 95], "other": [1, 32, 95, 95]}, + {"self": [1, 320, 14, 14], "other": [1, 320, 14, 14]}, + {"self": [1, 320, 64, 64], "other": [1, 320, 1, 1]}, + {"self": [1, 320, 64, 64], "other": [1, 320, 64, 64]}, + {"self": [1, 336, 14, 14], "other": [1, 336, 14, 14]}, + {"self": [1, 336, 56, 56], "other": [1, 336, 56, 56]}, + {"self": [1, 36, 28, 28], "other": [1, 36, 28, 28]}, + {"self": [1, 3712, 7, 7], "other": [1, 3712, 7, 7]}, + {"self": [1, 4, 12, 49, 49], "other": [1, 4, 1, 49, 49]}, + {"self": [1, 4, 12, 64, 64], "other": [1, 4, 1, 64, 64]}, + {"self": [1, 4, 16, 49, 49], "other": [1, 4, 1, 49, 49]}, + {"self": [1, 4, 16, 64, 64], "other": [1, 4, 1, 64, 64]}, + {"self": [1, 4, 768], "other": [1, 4, 768]}, + {"self": [1, 4, 768], "other": [4, 768]}, + {"self": [1, 40, 14, 14], "other": [1, 40, 14, 14]}, + {"self": [1, 40, 28, 28], "other": [1, 40, 28, 28]}, + {"self": [1, 40, 30, 30], "other": [1, 40, 30, 30]}, + {"self": [1, 40, 40, 40], "other": [1, 40, 40, 40]}, + {"self": [1, 400, 7, 7], "other": [1, 400, 7, 7]}, + {"self": [1, 408, 14, 14], "other": [1, 408, 14, 14]}, + {"self": [1, 4096, 256], "other": [256]}, + {"self": [1, 4096, 320], "other": [1, 4096, 320]}, + {"self": [1, 4096, 64], "other": [1, 4096, 64]}, + {"self": [1, 432, 14, 14], "other": [1, 432, 14, 14]}, + {"self": [1, 440, 7, 7], "other": [1, 440, 7, 7]}, + {"self": [1, 448, 28, 28], "other": [1, 448, 28, 28]}, + {"self": [1, 45, 3072], "other": [1, 45, 3072]}, + {"self": [1, 45, 768], "other": [1, 45, 768]}, + {"self": [1, 48, 14, 14], "other": [1, 48, 14, 14]}, + {"self": [1, 48, 33, 33], "other": [1, 48, 33, 33]}, + {"self": [1, 48, 38, 38], "other": [1, 48, 38, 38]}, + {"self": [1, 48, 56, 56], "other": [1, 48, 56, 56]}, + {"self": [1, 4800, 128], "other": [1, 4800, 128]}, + {"self": [1, 5, 1024], "other": [1, 5, 1024]}, + {"self": [1, 5, 16, 32], "other": [1, 5, 16, 32]}, + {"self": [1, 5, 4096], "other": [1, 5, 4096]}, + {"self": [1, 50, 1024], "other": [1, 50, 1024]}, + {"self": [1, 50, 768], "other": [1, 50, 768]}, + {"self": [1, 512, 100, 136], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 100, 136], "other": [1, 512, 100, 136]}, + {"self": [1, 512, 14, 14], "other": [1, 512, 14, 14]}, + {"self": [1, 512, 23, 40], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 25, 34], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 28, 28], "other": [1, 512, 28, 28]}, + {"self": [1, 512, 32, 32], "other": [1, 512, 32, 32]}, + {"self": [1, 512, 45, 80], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 50, 68], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 7, 7], "other": [1, 512, 7, 7]}, + {"self": [1, 512, 90, 160], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 90, 160], "other": [1, 512, 90, 160]}, + {"self": [1, 528, 96, 96], "other": [1, 528, 96, 96]}, + {"self": [1, 56, 48, 48], "other": [1, 56, 48, 48]}, + {"self": [1, 56, 56, 128], "other": [1, 56, 56, 128]}, + {"self": [1, 56, 56, 96], "other": [1, 56, 56, 96]}, + {"self": [1, 576, 14, 14], "other": [1, 576, 14, 14]}, + {"self": [1, 59, 1024], "other": [1, 59, 1024]}, + {"self": [1, 6, 1, 15], "other": [1, 1, 1, 15]}, + {"self": [1, 6, 1, 15], "other": [1, 6, 1, 15]}, + {"self": [1, 6, 1, 17], "other": [1, 1, 1, 17]}, + {"self": [1, 6, 1, 17], "other": [1, 6, 1, 17]}, + {"self": [1, 6, 1, 1], "other": [1, 1, 1, 1]}, + {"self": [1, 6, 1, 1], "other": [1, 6, 1, 1]}, + {"self": [1, 6, 1, 2], "other": [1, 1, 1, 2]}, + {"self": [1, 6, 1, 2], "other": [1, 6, 1, 2]}, + {"self": [1, 6, 15, 15], "other": [1, 1, 1, 15]}, + {"self": [1, 6, 15, 15], "other": [1, 6, 15, 15]}, + {"self": [1, 64, 120, 160], "other": [1, 64, 120, 160]}, + {"self": [1, 64, 1280], "other": [1, 64, 1280]}, + {"self": [1, 64, 14, 14], "other": [1, 64, 14, 14]}, + {"self": [1, 64, 180, 320], "other": [1, 64, 1, 1]}, + {"self": [1, 64, 200, 272], "other": [1, 64, 1, 1]}, + {"self": [1, 64, 240, 320], "other": [1, 64, 240, 320]}, + {"self": [1, 64, 256, 256], "other": [1, 64, 256, 256]}, + {"self": [1, 64, 28, 28], "other": [1, 64, 28, 28]}, + {"self": [1, 64, 3, 49, 49], "other": [1, 64, 1, 49, 49]}, + {"self": [1, 64, 3, 64, 64], "other": [1, 64, 1, 64, 64]}, + {"self": [1, 64, 30, 40], "other": [1, 64, 30, 40]}, + {"self": [1, 64, 360, 640], "other": [1, 64, 1, 1]}, + {"self": [1, 64, 4, 49, 49], "other": [1, 64, 1, 49, 49]}, + {"self": [1, 64, 4, 64, 64], "other": [1, 64, 1, 64, 64]}, + {"self": [1, 64, 400, 544], "other": [1, 64, 1, 1]}, + {"self": [1, 64, 480, 640], "other": [1, 64, 480, 640]}, + {"self": [1, 64, 56, 56], "other": [1, 64, 56, 56]}, + {"self": [1, 64, 60, 80], "other": [1, 64, 60, 80]}, + {"self": [1, 64, 6144], "other": [1, 64, 6144]}, + {"self": [1, 64, 64, 128], "other": [1, 64, 64, 128]}, + {"self": [1, 64, 64, 96], "other": [1, 64, 64, 96]}, + {"self": [1, 64, 9, 9], "other": [1, 1, 1, 9]}, + {"self": [1, 640, 32, 32], "other": [1, 640, 1, 1]}, + {"self": [1, 640, 32, 32], "other": [1, 640, 32, 32]}, + {"self": [1, 672, 28, 28], "other": [1, 672, 28, 28]}, + {"self": [1, 672, 7, 7], "other": [1, 672, 7, 7]}, + {"self": [1, 696, 28, 28], "other": [1, 696, 28, 28]}, + {"self": [1, 7, 3072], "other": [1, 7, 3072]}, + {"self": [1, 7, 4544], "other": [1, 7, 4544]}, + {"self": [1, 7, 7, 1024], "other": [1, 7, 7, 1024]}, + {"self": [1, 7, 7, 768], "other": [1, 7, 7, 768]}, + {"self": [1, 7, 768], "other": [1, 7, 768]}, + {"self": [1, 71, 7, 64], "other": [1, 71, 7, 64]}, + {"self": [1, 71, 7, 7], "other": [7, 7]}, + {"self": [1, 72, 14, 14], "other": [1, 72, 14, 14]}, + {"self": [1, 72, 56, 56], "other": [1, 72, 56, 56]}, + {"self": [1, 720, 14, 14], "other": [1, 720, 14, 14]}, + {"self": [1, 728, 19, 19], "other": [1, 728, 19, 19]}, + {"self": [1, 728, 38, 38], "other": [1, 728, 38, 38]}, + {"self": [1, 7392, 12, 12], "other": [1, 7392, 12, 12]}, + {"self": [1, 768, 384], "other": [384]}, + {"self": [1, 784, 7, 7], "other": [1, 784, 7, 7]}, + {"self": [1, 8, 1, 10], "other": [1, 1, 1, 10]}, + {"self": [1, 8, 1, 10], "other": [1, 8, 1, 10]}, + {"self": [1, 8, 1, 1], "other": [1, 1, 1, 1]}, + {"self": [1, 8, 1, 1], "other": [1, 8, 1, 1]}, + {"self": [1, 8, 1, 2], "other": [1, 1, 1, 2]}, + {"self": [1, 8, 1, 2], "other": [1, 8, 1, 2]}, + {"self": [1, 8, 10, 10], "other": [1, 1, 1, 10]}, + {"self": [1, 8, 10, 10], "other": [1, 8, 10, 10]}, + {"self": [1, 8, 256, 2048], "other": [1, 1, 1, 2048]}, + {"self": [1, 8, 768], "other": [1, 8, 768]}, + {"self": [1, 8, 8, 1024], "other": [1, 8, 8, 1024]}, + {"self": [1, 8, 8, 768], "other": [1, 8, 8, 768]}, + {"self": [1, 80, 10, 10], "other": [1, 80, 10, 10]}, + {"self": [1, 80, 14, 14], "other": [1, 80, 14, 14]}, + {"self": [1, 80, 15, 15], "other": [1, 80, 15, 15]}, + {"self": [1, 80, 20, 20], "other": [1, 80, 20, 20]}, + {"self": [1, 80, 56, 56], "other": [1, 80, 56, 56]}, + {"self": [1, 88, 17, 17], "other": [1, 88, 17, 17]}, + {"self": [1, 888, 7, 7], "other": [1, 888, 7, 7]}, + {"self": [1, 896, 14, 14], "other": [1, 896, 14, 14]}, + {"self": [1, 9, 1024], "other": [1, 9, 1024]}, + {"self": [1, 9, 128], "other": [1, 9, 128]}, + {"self": [1, 9, 16384], "other": [1, 9, 16384]}, + {"self": [1, 9, 2048], "other": [1, 9, 2048]}, + {"self": [1, 9, 3072], "other": [1, 9, 3072]}, + {"self": [1, 9, 4096], "other": [1, 9, 4096]}, + {"self": [1, 9, 768], "other": [1, 9, 768]}, + {"self": [1, 9, 8192], "other": [1, 9, 8192]}, + {"self": [1, 912, 7, 7], "other": [1, 912, 7, 7]}, + {"self": [1, 96, 14, 14], "other": [1, 96, 14, 14]}, + {"self": [1, 96, 19, 19], "other": [1, 96, 19, 19]}, + {"self": [1, 96, 56, 56], "other": [1, 96, 56, 56]}, + {"self": [1, 96, 7, 7], "other": [1, 96, 7, 7]}, + {"self": [1, 96, 80], "other": [1, 96, 80]}, + {"self": [10, 10], "other": [10, 10]}, + {"self": [100, 1, 256], "other": [100, 1, 256]}, + {"self": [12, 24, 24], "other": [12, 24, 24]}, + {"self": [13600, 1, 4], "other": [1, 9, 4]}, + {"self": [15, 15], "other": [15, 15]}, + {"self": [16, 6, 49, 49], "other": [1, 6, 49, 49]}, + {"self": [16, 6, 64, 64], "other": [1, 6, 64, 64]}, + {"self": [16, 8, 49, 49], "other": [1, 8, 49, 49]}, + {"self": [16, 8, 64, 64], "other": [1, 8, 64, 64]}, + {"self": [2, 7, 512], "other": [1, 7, 512]}, + {"self": [2, 7, 512], "other": [2, 7, 512]}, + {"self": [2, 8, 7, 7], "other": [2, 1, 7, 7]}, + {"self": [2048, 262], "other": [262]}, + {"self": [221, 1, 4], "other": [1, 9, 4]}, + {"self": [25, 4], "other": [25, 1]}, + {"self": [3234, 1], "other": [3234, 1]}, + {"self": [3234, 2], "other": [3234, 2]}, + {"self": [3234], "other": [3234]}, + {"self": [3400, 1, 4], "other": [1, 9, 4]}, + {"self": [4, 12, 49, 49], "other": [1, 12, 49, 49]}, + {"self": [4, 12, 64, 64], "other": [1, 12, 64, 64]}, + {"self": [4, 16, 49, 49], "other": [1, 16, 49, 49]}, + {"self": [4, 16, 64, 64], "other": [1, 16, 64, 64]}, + {"self": [59, 1024], "other": [59, 1024]}, + {"self": [63, 1, 4], "other": [1, 9, 4]}, + {"self": [64, 3, 49, 49], "other": [1, 3, 49, 49]}, + {"self": [64, 3, 64, 64], "other": [1, 3, 64, 64]}, + {"self": [64, 4, 49, 49], "other": [1, 4, 49, 49]}, + {"self": [64, 4, 64, 64], "other": [1, 4, 64, 64]}, + {"self": [850, 1, 4], "other": [1, 9, 4]}, + {"self": [8732, 1], "other": [8732, 1]}, + {"self": [8732, 2], "other": [8732, 2]}, + {"self": [8732], "other": [8732]}, + {"self": [], "other": []}, + {"self": [920, 1, 256], "other": [256]}, + {"self": [920, 1, 256], "other": [920, 1, 256]}, + {"self": [1, 1, 1, 42], "other": -6.0}, + {"self": [1, 1, 1, 42], "other": 0.5}, + {"self": [1, 1, 1, 42], "other": 1.0}, + {"self": [1, 1, 1, 42], "other": 1.0}, + {"self": [1, 1, 1, 42], "other": 2.0}, + {"self": [1, 1, 1024], "other": 1.0}, + {"self": [1, 1, 1], "other": 1e-06}, + {"self": [1, 1, 224, 224], "other": -0.030000000000000027}, + {"self": [1, 1, 224, 224], "other": -0.08799999999999997}, + {"self": [1, 1, 224, 224], "other": -0.18799999999999994}, + {"self": [1, 1, 3072], "other": 1.0}, + {"self": [1, 1, 32, 1], "other": -6.0}, + {"self": [1, 1, 32, 1], "other": 0.5}, + {"self": [1, 1, 32, 1], "other": 1.0}, + {"self": [1, 1, 32, 1], "other": 1.0}, + {"self": [1, 1, 32, 1], "other": 2.0}, + {"self": [1, 1, 4096], "other": 1.0}, + {"self": [1, 1, 40], "other": 1e-06}, + {"self": [1, 10, 1], "other": 1e-06}, + {"self": [1, 1024, 1, 1], "other": 0.0}, + {"self": [1, 1024, 1, 1], "other": 1e-05}, + {"self": [1, 10], "other": 0.0}, + {"self": [1, 10], "other": 1.0}, + {"self": [1, 12, 3072], "other": 1.0}, + {"self": [1, 128, 1, 1], "other": 0.0}, + {"self": [1, 128, 1, 1], "other": 1e-05}, + {"self": [1, 14, 3072], "other": 1.0}, + {"self": [1, 15, 1024], "other": 1.0}, + {"self": [1, 15, 1], "other": 1e-06}, + {"self": [1, 19], "other": 2.0}, + {"self": [1, 1], "other": 0.0}, + {"self": [1, 1], "other": 16.0}, + {"self": [1, 1], "other": 2.0}, + {"self": [1, 2048, 1, 1], "other": 0.0}, + {"self": [1, 2048, 1, 1], "other": 1e-05}, + {"self": [1, 23, 1], "other": 1e-06}, + {"self": [1, 256, 1, 1], "other": 0.0}, + {"self": [1, 256, 1, 1], "other": 1e-05}, + {"self": [1, 32, 6144], "other": 1.0}, + {"self": [1, 32, 6144], "other": 1.0}, + {"self": [1, 45, 3072], "other": 1.0}, + {"self": [1, 5, 4096], "other": 1.0}, + {"self": [1, 512, 1, 1], "other": 0.0}, + {"self": [1, 512, 1, 1], "other": 1e-05}, + {"self": [1, 59], "other": 2.0}, + {"self": [1, 64, 1, 1], "other": 0.0}, + {"self": [1, 64, 1, 1], "other": 1e-05}, + {"self": [1, 7, 3072], "other": 1.0}, + {"self": [1, 9, 128], "other": 1.0}, + {"self": [1, 9, 16384], "other": 1.0}, + {"self": [1, 9, 3072], "other": 1.0}, + {"self": [1, 9, 4096], "other": 1.0}, + {"self": [1, 9, 8192], "other": 1.0}, + {"self": [10, 10], "other": 0.0}, + {"self": [10, 10], "other": 8.0}, + {"self": [100], "other": 0.0}, + {"self": [1066], "other": 0.5}, + {"self": [10], "other": 0.5}, + {"self": [120], "other": 0.5}, + {"self": [128], "other": 0.5}, + {"self": [12], "other": 0.0}, + {"self": [136], "other": 0.0}, + {"self": [14], "other": 0.0}, + {"self": [15, 15], "other": 0.0}, + {"self": [15, 15], "other": 8.0}, + {"self": [160], "other": 0.5}, + {"self": [16], "other": 0.0}, + {"self": [17, 17], "other": 0.0}, + {"self": [17, 17], "other": 16.0}, + {"self": [19], "other": 0.5}, + {"self": [1], "other": 0.5}, + {"self": [2, 2], "other": 0.0}, + {"self": [2, 2], "other": 16.0}, + {"self": [20], "other": 0.5}, + {"self": [23], "other": 0.0}, + {"self": [24, 24], "other": 160.0}, + {"self": [240], "other": 0.5}, + {"self": [28], "other": 0.0}, + {"self": [2], "other": 0.5}, + {"self": [300], "other": 0.5}, + {"self": [30], "other": 0.5}, + {"self": [320], "other": 0.5}, + {"self": [32], "other": 0.0}, + {"self": [38], "other": 0.5}, + {"self": [3], "other": 0.5}, + {"self": [40], "other": 0.0}, + {"self": [40], "other": 0.5}, + {"self": [480], "other": 0.5}, + {"self": [50], "other": 0.0}, + {"self": [56], "other": 0.0}, + {"self": [5], "other": 0.5}, + {"self": [60], "other": 0.5}, + {"self": [640], "other": 0.5}, + {"self": [64], "other": 0.0}, + {"self": [68], "other": 0.0}, + {"self": [7], "other": 0.0}, + {"self": [800], "other": 0.5}, + {"self": [80], "other": 0.5}, + {"self": [], "other": 1}, + ], + # {"self": [s0 + 1, s0 + 1], "other": 16}, + # {"self": [s0 + 1, s0 + 1], "other": 0}, + # {"self": [1, 16, 1, "s0 + 1"], "other": [1, 1, 1, "s0 + 1"]}, + # {"self": [1, 16, 1, "s0 + 1"], "other": [1, 16, 1, "s0 + 1"]}, + # {"self": [1, 8, 1, "s0 + 1"], "other": [1, 1, 1, "s0 + 1"]}, + # {"self": [1, 8, 1, "s0 + 1"], "other": [1, 8, 1, "s0 + 1"]}, + # {"self": [1, 6, 1, "s0 + 1"], "other": [1, 1, 1, "s0 + 1"]}, + # {"self": [1, 6, 1, "s0 + 1"], "other": [1, 6, 1, "s0 + 1"]}, + # {"self": [1, 12, 1, "s0 + 1"], "other": [1, 1, 1, "s0 + 1"]}, + # {"self": [1, 12, 1, "s0 + 1"], "other": [1, 12, 1, "s0 + 1"]}, + # {"self": [1, 32, "s0", "s1"], "other": [1, 32, "s0", "s1"]}, + # {"self": [1, 12, 1, "s10 + 1"], "other": [1, 1, 1, "s10 + 1"]}, + # {"self": [1, 64, "s1", "s2"], "other": [1, 64, "s1", "s2"]}, + # {"self": [1, 128, "s1", "s2"], "other": [1, 128, "s1", "s2"]}, + # {"self": [1, 16, 1, "s10 + 1"], "other": [1, 1, 1, "s10 + 1"]}, + # {"self": [1, 256, "s1", "s2"], "other": [1, 256, "s1", "s2"]}, + # {"self": [1, "s0", 768], "other": [1, "s0", 768]} + "input_a_dtype": [ttnn.bfloat16], + "input_b_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_b_dtype, + input_a_layout, + input_b_layout, + input_a_memory_config, + input_b_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape["self"]) + + if isinstance(input_shape["other"], list): + torch_input_tensor_b = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype + )(input_shape["other"]) + else: + torch_input_tensor_b = torch.tensor(input_shape["other"], dtype=torch.float32) + # torch_input_tensor_b = input_shape["other"] + + golden_function = ttnn.get_golden_function(ttnn.add) + torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + # if isinstance(input_shape["other"], list): + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_b_dtype, + layout=input_b_layout, + device=device, + memory_config=input_b_memory_config, + ) + # else: + # input_tensor_b = input_shape["other"] + + start_time = start_measuring_time() + result = ttnn.add(input_tensor_a, input_tensor_b) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, pcc=0.9999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/binary/eq/eq_scalar_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/binary/eq/eq_scalar_pytorch2.py new file mode 100644 index 00000000000..6a8a955225f --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/binary/eq/eq_scalar_pytorch2.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1, 256], + [1, 16], + [1, 7], + [1, 7], + [16, 49, 49], + [16, 64, 64], + [1], + [1], + [4, 49, 49], + [4, 64, 64], + [64, 49, 49], + [64, 64, 64], + ], + "scalar": [1, 0, 1, 50256, 0, 0, 1, 50256, 0, 0, 0, 0], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + scalar, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.eq) + torch_output_tensor = golden_function(torch_input_tensor_a, scalar) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.eq(input_tensor_a, scalar, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/binary/floor_divide/floor_divide_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/binary/floor_divide/floor_divide_pytorch2.py new file mode 100644 index 00000000000..efb9b66b64b --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/binary/floor_divide/floor_divide_pytorch2.py @@ -0,0 +1,97 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [128], + ], + "scalar": [ + 2, + ], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +def mesh_device_fixture(): + device = ttnn.open_device(device_id=0) + assert ttnn.device.is_wormhole_b0(device), "This op is available for Wormhole_B0 only" + yield (device, "Wormhole_B0") + ttnn.close_device(device) + del device + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + scalar, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + torch_input_tensor_b = torch.tensor(scalar, dtype=torch.float32) + + golden_function = ttnn.get_golden_function(ttnn.floor_div) + torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.floor_div(input_tensor_a, input_tensor_b, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/binary/gt/gt_scalar_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/binary/gt/gt_scalar_pytorch2.py new file mode 100644 index 00000000000..0910775d42a --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/binary/gt/gt_scalar_pytorch2.py @@ -0,0 +1,80 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [10, 10], + [15, 15], + [], + ], + "scalar": [0, 0], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + scalar, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.gt) + torch_output_tensor = golden_function(torch_input_tensor_a, scalar) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.gt(input_tensor_a, scalar, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/binary/le/le_tensor_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/binary/le/le_tensor_pytorch2.py new file mode 100644 index 00000000000..d119ced4892 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/binary/le/le_tensor_pytorch2.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape_a": [[1, 1, 1]], + "input_shape_b": [[1, 1, 1]], + "input_a_dtype": [ttnn.bfloat16], + "input_b_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape_a, + input_shape_b, + input_a_dtype, + input_b_dtype, + input_a_layout, + input_b_layout, + input_a_memory_config, + input_b_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape_a) + torch_input_tensor_b = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype + )(input_shape_b) + + golden_function = ttnn.get_golden_function(ttnn.le) + torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_b_dtype, + layout=input_b_layout, + device=device, + memory_config=input_b_memory_config, + ) + start_time = start_measuring_time() + result = ttnn.le(input_tensor_a, input_tensor_b) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/binary/multiply/mul_tensor_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/binary/multiply/mul_tensor_pytorch2.py new file mode 100644 index 00000000000..37c1414c713 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/binary/multiply/mul_tensor_pytorch2.py @@ -0,0 +1,472 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + {"self": [0], "other": 0.5}, + {"self": [1, 1, 1, 10], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 12], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 14], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 15], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 17], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 1], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 201], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 2048], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 256], "other": -3.3895313892515355e38}, + {"self": [1, 1, 1, 25], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 2], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 42], "other": -0.75}, + {"self": [1, 1, 1, 42], "other": 1.25}, + {"self": [1, 1, 1, 42], "other": 1.9761904761904763}, + {"self": [1, 1, 1, 5], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 6], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1, 7], "other": -3.3895313892515355e38}, + {"self": [1, 1, 1, 8], "other": -3.3895313892515355e38}, + {"self": [1, 1, 1, 9], "other": -3.4028234663852886e38}, + {"self": [1, 1, 1024], "other": 0.03125}, + {"self": [1, 1, 1024], "other": 0.044715}, + {"self": [1, 1, 1024], "other": 0.125}, + {"self": [1, 1, 1024], "other": 0.5}, + {"self": [1, 1, 1024], "other": 0.7978845608028654}, + {"self": [1, 1, 224, 224], "other": 0.448}, + {"self": [1, 1, 224, 224], "other": 0.45}, + {"self": [1, 1, 224, 224], "other": 0.458}, + {"self": [1, 1, 256], "other": 1.0}, + {"self": [1, 1, 3072], "other": 0.044715}, + {"self": [1, 1, 3072], "other": 0.5}, + {"self": [1, 1, 3072], "other": 0.7978845608028654}, + {"self": [1, 1, 32, 1], "other": -0.75}, + {"self": [1, 1, 32, 1], "other": 1.25}, + {"self": [1, 1, 32, 1], "other": 1.5625}, + {"self": [1, 1, 4096], "other": 0.044715}, + {"self": [1, 1, 4096], "other": 0.5}, + {"self": [1, 1, 4096], "other": 0.7978845608028654}, + {"self": [1, 1, 480, 640], "other": 10.0}, + {"self": [1, 1, 512], "other": 0.04419417382415922}, + {"self": [1, 1, 768], "other": 0.03608439182435161}, + {"self": [1, 1, 768], "other": 0.125}, + {"self": [1, 12, 3072], "other": 0.044715}, + {"self": [1, 12, 3072], "other": 0.5}, + {"self": [1, 12, 3072], "other": 0.7978845608028654}, + {"self": [1, 12, 64, 64], "other": 16.0}, + {"self": [1, 14, 3072], "other": 0.044715}, + {"self": [1, 14, 3072], "other": 0.5}, + {"self": [1, 14, 3072], "other": 0.7978845608028654}, + {"self": [1, 15, 1024], "other": 0.044715}, + {"self": [1, 15, 1024], "other": 0.5}, + {"self": [1, 15, 1024], "other": 0.7978845608028654}, + {"self": [1, 16, 64, 64], "other": 16.0}, + {"self": [1, 160], "other": 1.0}, + {"self": [1, 19, 1024], "other": 0.125}, + {"self": [1, 19, 1024], "other": 32.0}, + {"self": [1, 1], "other": 0.0}, + {"self": [1, 1], "other": 16.0}, + {"self": [1, 1], "other": 50258.0}, + {"self": [1, 1], "other": 50259.0}, + {"self": [1, 1], "other": 50359.0}, + {"self": [1, 1], "other": 50363.0}, + {"self": [1, 23, 40], "other": 6.283185307179586}, + {"self": [1, 24, 49, 32], "other": 0.1767766952966369}, + {"self": [1, 24, 64, 64], "other": 16.0}, + {"self": [1, 24, 768], "other": 0.125}, + {"self": [1, 3, 16, 16, 2], "other": 2.0}, + {"self": [1, 3, 32, 32, 2], "other": 2.0}, + {"self": [1, 3, 64, 64, 2], "other": 2.0}, + {"self": [1, 3, 64, 64], "other": 16.0}, + {"self": [1, 32, 49, 32], "other": 0.1767766952966369}, + {"self": [1, 32, 6144], "other": 0.044715}, + {"self": [1, 32, 6144], "other": 0.5}, + {"self": [1, 32, 6144], "other": 0.79788456}, + {"self": [1, 32, 64, 64], "other": 16.0}, + {"self": [1, 4, 64, 64], "other": 16.0}, + {"self": [1, 45, 3072], "other": 0.044715}, + {"self": [1, 45, 3072], "other": 0.5}, + {"self": [1, 45, 3072], "other": 0.7978845608028654}, + {"self": [1, 5, 4096], "other": 0.044715}, + {"self": [1, 5, 4096], "other": 0.5}, + {"self": [1, 5, 4096], "other": 0.7978845608028654}, + {"self": [1, 50, 3072], "other": 1.702}, + {"self": [1, 50, 768], "other": 0.125}, + {"self": [1, 59, 1024], "other": 0.125}, + {"self": [1, 6, 64, 64], "other": 16.0}, + {"self": [1, 7, 3072], "other": 0.044715}, + {"self": [1, 7, 3072], "other": 0.5}, + {"self": [1, 7, 3072], "other": 0.7978845608028654}, + {"self": [1, 8, 64, 64], "other": 16.0}, + {"self": [1, 9, 128], "other": 0.044715}, + {"self": [1, 9, 128], "other": 0.5}, + {"self": [1, 9, 128], "other": 0.7978845608028654}, + {"self": [1, 9, 16384], "other": 0.044715}, + {"self": [1, 9, 16384], "other": 0.5}, + {"self": [1, 9, 16384], "other": 0.7978845608028654}, + {"self": [1, 9, 3072], "other": 0.044715}, + {"self": [1, 9, 3072], "other": 0.5}, + {"self": [1, 9, 3072], "other": 0.7978845608028654}, + {"self": [1, 9, 4096], "other": 0.044715}, + {"self": [1, 9, 4096], "other": 0.5}, + {"self": [1, 9, 4096], "other": 0.7978845608028654}, + {"self": [1, 9, 8192], "other": 0.044715}, + {"self": [1, 9, 8192], "other": 0.5}, + {"self": [1, 9, 8192], "other": 0.7978845608028654}, + {"self": [10, 10], "other": 16.0}, + {"self": [10, 10], "other": 8.0}, + {"self": [100], "other": 0.5}, + {"self": [1066], "other": 0.600375234521576}, + {"self": [120], "other": 0.5}, + {"self": [128], "other": 0.125}, + {"self": [128], "other": 0.25}, + {"self": [128], "other": 0.5}, + {"self": [128], "other": 1.0}, + {"self": [128], "other": 2.0}, + {"self": [12], "other": 32.0}, + {"self": [136], "other": 0.5}, + {"self": [14], "other": 0.5}, + {"self": [15, 15], "other": 16.0}, + {"self": [15, 15], "other": 8.0}, + {"self": [16, 6, 49, 32], "other": 0.1767766952966369}, + {"self": [16, 8, 49, 32], "other": 0.1767766952966369}, + {"self": [160], "other": -9.210340371976184}, + {"self": [160], "other": 0.5}, + {"self": [16], "other": 0.5}, + {"self": [16], "other": 32.0}, + {"self": [17, 17], "other": 16.0}, + {"self": [2, 2], "other": 16.0}, + {"self": [2, 7, 2048], "other": 1.702}, + {"self": [2, 7, 512], "other": 0.125}, + {"self": [23], "other": 31.304347826086957}, + {"self": [240], "other": 0.5}, + {"self": [28], "other": 0.25}, + {"self": [28], "other": 0.5}, + {"self": [300], "other": 1.6}, + {"self": [300], "other": 2.1333333333333333}, + {"self": [30], "other": 0.5}, + {"self": [320], "other": 0.5}, + {"self": [320], "other": 1.0}, + {"self": [320], "other": 1.5}, + {"self": [320], "other": 2.0}, + {"self": [3234, 2], "other": 0.5}, + {"self": [3234], "other": 0.5}, + {"self": [32], "other": 0.5}, + {"self": [4, 12, 49, 32], "other": 0.1767766952966369}, + {"self": [4, 16, 49, 32], "other": 0.1767766952966369}, + {"self": [40], "other": 0.5}, + {"self": [40], "other": 32.0}, + {"self": [480], "other": 0.5}, + {"self": [50], "other": 0.5}, + {"self": [56], "other": 0.125}, + {"self": [56], "other": 0.25}, + {"self": [56], "other": 0.5}, + {"self": [60], "other": 0.5}, + {"self": [64, 3, 49, 32], "other": 0.1767766952966369}, + {"self": [64, 4, 49, 32], "other": 0.1767766952966369}, + {"self": [640], "other": 0.5}, + {"self": [64], "other": 0.5}, + {"self": [68], "other": 0.5}, + {"self": [7], "other": 0.42857142857142855}, + {"self": [800], "other": 0.6}, + {"self": [80], "other": 0.5}, + {"self": [8732, 2], "other": 0.5}, + {"self": [8732], "other": 0.5}, + # vec other + {"self": [0, 1], "other": [0, 1]}, + {"self": [0], "other": []}, + {"self": [1, 1, 1, 17], "other": [1, 1, 1, 17]}, + {"self": [1, 1, 1, 1], "other": [1, 1, 1, 1]}, + {"self": [1, 1, 1, 2], "other": [1, 1, 1, 2]}, + {"self": [1, 1, 1, 42], "other": [1, 1, 1, 42]}, + {"self": [1, 1, 1024], "other": [1, 1, 1024]}, + {"self": [1, 1, 1024], "other": [1, 1, 1]}, + {"self": [1, 1, 16, 32], "other": [1, 1, 1, 32]}, + {"self": [1, 1, 3072], "other": [1, 1, 3072]}, + {"self": [1, 1, 32, 1], "other": [1, 1, 32, 1]}, + {"self": [1, 1, 4096], "other": [1, 1, 4096]}, + {"self": [1, 1, 512], "other": [1, 1, 1]}, + {"self": [1, 1, 7, 64], "other": [1, 1, 7, 64]}, + {"self": [1, 1, 768], "other": [1, 1, 1]}, + {"self": [1, 10, 1024], "other": [1, 10, 1]}, + {"self": [1, 10, 512], "other": [1, 10, 1]}, + {"self": [1, 10, 768], "other": [1, 10, 1]}, + {"self": [1, 1024, 1, 1], "other": [1, 1024, 1, 1]}, + {"self": [1, 1024, 2560], "other": [1, 1024, 2560]}, + {"self": [1, 1024, 45, 80], "other": [1, 1024, 1, 1]}, + {"self": [1, 1024, 50, 68], "other": [1, 1024, 1, 1]}, + {"self": [1, 1024, 7, 7], "other": [1, 1024, 1, 1]}, + {"self": [1, 104, 1, 1], "other": [1, 104, 28, 28]}, + {"self": [1, 1056, 1, 1], "other": [1, 1056, 48, 48]}, + {"self": [1, 10], "other": [1, 10]}, + {"self": [1, 12, 3072], "other": [1, 12, 3072]}, + {"self": [1, 120, 1, 1], "other": [1, 120, 14, 14]}, + {"self": [1, 120, 1, 1], "other": [1, 120, 28, 28]}, + {"self": [1, 120, 1, 1], "other": [1, 120, 40, 40]}, + {"self": [1, 120, 28, 28], "other": [1, 120, 1, 1]}, + {"self": [1, 120, 28, 28], "other": [1, 120, 28, 28]}, + {"self": [1, 1232, 1, 1], "other": [1, 1232, 14, 14]}, + {"self": [1, 128, 1, 1], "other": [1, 128, 1, 1]}, + {"self": [1, 128, 100, 136], "other": [1, 128, 1, 1]}, + {"self": [1, 128, 180, 320], "other": [1, 128, 1, 1]}, + {"self": [1, 128, 200, 272], "other": [1, 128, 1, 1]}, + {"self": [1, 128, 90, 160], "other": [1, 128, 1, 1]}, + {"self": [1, 1392, 1, 1], "other": [1, 1392, 14, 14]}, + {"self": [1, 14, 3072], "other": [1, 14, 3072]}, + {"self": [1, 144, 1, 1], "other": [1, 144, 14, 14]}, + {"self": [1, 144, 1, 1], "other": [1, 144, 28, 28]}, + {"self": [1, 15, 1024], "other": [1, 15, 1024]}, + {"self": [1, 15, 512], "other": [1, 15, 1]}, + {"self": [1, 1512, 1, 1], "other": [1, 1512, 7, 7]}, + {"self": [1, 16, 1, 1], "other": [1, 16, 56, 56]}, + {"self": [1, 184, 14, 14], "other": [1, 184, 14, 14]}, + {"self": [1, 192, 32, 42], "other": [1, 1, 1, 42]}, + {"self": [1, 192, 32, 42], "other": [1, 1, 32, 1]}, + {"self": [1, 1], "other": [1, 160]}, + {"self": [1, 200, 14, 14], "other": [1, 200, 14, 14]}, + {"self": [1, 2016, 1, 1], "other": [1, 2016, 7, 7]}, + {"self": [1, 2048, 1, 1], "other": [1, 2048, 1, 1]}, + {"self": [1, 2048, 23, 40], "other": [1, 2048, 1, 1]}, + {"self": [1, 2048, 25, 34], "other": [1, 2048, 1, 1]}, + {"self": [1, 208, 1, 1], "other": [1, 208, 14, 14]}, + {"self": [1, 216, 1, 1], "other": [1, 216, 28, 28]}, + {"self": [1, 224, 1, 1], "other": [1, 224, 56, 56]}, + {"self": [1, 232, 1, 1], "other": [1, 232, 56, 56]}, + {"self": [1, 24, 64, 64], "other": [24, 1, 1]}, + {"self": [1, 240, 1, 1], "other": [1, 240, 14, 14]}, + {"self": [1, 240, 28, 28], "other": [1, 240, 28, 28]}, + {"self": [1, 256, 1, 1], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 100, 136], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 128, 128], "other": [128, 1]}, + {"self": [1, 256, 128, 128], "other": [128]}, + {"self": [1, 256, 180, 320], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 200, 272], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 45, 80], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 50, 68], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 5120], "other": [1, 256, 5120]}, + {"self": [1, 256, 56, 56], "other": [1, 256, 1, 1]}, + {"self": [1, 256, 90, 160], "other": [1, 256, 1, 1]}, + {"self": [1, 288, 1, 1], "other": [1, 288, 7, 7]}, + {"self": [1, 2904, 1, 1], "other": [1, 2904, 24, 24]}, + {"self": [1, 3, 16, 16, 2], "other": [1, 3, 16, 16, 2]}, + {"self": [1, 3, 16, 16, 2], "other": []}, + {"self": [1, 3, 300, 300], "other": [300, 1]}, + {"self": [1, 3, 300, 300], "other": [300]}, + {"self": [1, 3, 32, 32, 2], "other": [1, 3, 32, 32, 2]}, + {"self": [1, 3, 32, 32, 2], "other": []}, + {"self": [1, 3, 320, 320], "other": [320, 1]}, + {"self": [1, 3, 320, 320], "other": [320]}, + {"self": [1, 3, 64, 64, 2], "other": [1, 3, 64, 64, 2]}, + {"self": [1, 3, 64, 64, 2], "other": []}, + {"self": [1, 3, 800, 1066], "other": [1066]}, + {"self": [1, 3, 800, 1066], "other": [800, 1]}, + {"self": [1, 3024, 1, 1], "other": [1, 3024, 7, 7]}, + {"self": [1, 32, 6144], "other": [1, 32, 6144]}, + {"self": [1, 32, 64, 64], "other": [32, 1, 1]}, + {"self": [1, 320, 1, 1], "other": [1, 320, 14, 14]}, + {"self": [1, 32], "other": [1, 32]}, + {"self": [1, 336, 1, 1], "other": [1, 336, 14, 14]}, + {"self": [1, 3712, 1, 1], "other": [1, 3712, 7, 7]}, + {"self": [1, 4096, 1280], "other": [1, 4096, 1280]}, + {"self": [1, 440, 1, 1], "other": [1, 440, 7, 7]}, + {"self": [1, 448, 1, 1], "other": [1, 448, 28, 28]}, + {"self": [1, 45, 3072], "other": [1, 45, 3072]}, + {"self": [1, 48, 1, 1], "other": [1, 48, 56, 56]}, + {"self": [1, 480, 1, 1], "other": [1, 480, 10, 10]}, + {"self": [1, 480, 1, 1], "other": [1, 480, 14, 14]}, + {"self": [1, 480, 1, 1], "other": [1, 480, 20, 20]}, + {"self": [1, 480, 14, 14], "other": [1, 480, 1, 1]}, + {"self": [1, 480, 14, 14], "other": [1, 480, 14, 14]}, + {"self": [1, 5, 16, 32], "other": [1, 5, 1, 32]}, + {"self": [1, 5, 4096], "other": [1, 5, 4096]}, + {"self": [1, 50, 3072], "other": [1, 50, 3072]}, + {"self": [1, 512, 1, 1], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 1, 1], "other": [1, 512, 38, 38]}, + {"self": [1, 512, 100, 136], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 23, 40], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 25, 34], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 28, 28], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 45, 80], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 50, 68], "other": [1, 512, 1, 1]}, + {"self": [1, 512, 90, 160], "other": [1, 512, 1, 1]}, + {"self": [1, 528, 1, 1], "other": [1, 528, 96, 96]}, + {"self": [1, 576, 1, 1], "other": [1, 576, 14, 14]}, + {"self": [1, 576, 1, 1], "other": [1, 576, 7, 7]}, + {"self": [1, 59], "other": [1, 59]}, + {"self": [1, 60], "other": [1, 60]}, + {"self": [1, 64, 1, 1], "other": [1, 64, 1, 1]}, + {"self": [1, 64, 1, 1], "other": [1, 64, 56, 56]}, + {"self": [1, 64, 120, 160], "other": [1, 1, 120, 160]}, + {"self": [1, 64, 120, 160], "other": [120, 1]}, + {"self": [1, 64, 120, 160], "other": [160]}, + {"self": [1, 64, 180, 320], "other": [1, 64, 1, 1]}, + {"self": [1, 64, 200, 272], "other": [1, 64, 1, 1]}, + {"self": [1, 64, 240, 320], "other": [240, 1]}, + {"self": [1, 64, 240, 320], "other": [320]}, + {"self": [1, 64, 30, 40], "other": [1, 1, 30, 40]}, + {"self": [1, 64, 30, 40], "other": [30, 1]}, + {"self": [1, 64, 30, 40], "other": [40]}, + {"self": [1, 64, 360, 640], "other": [1, 64, 1, 1]}, + {"self": [1, 64, 400, 544], "other": [1, 64, 1, 1]}, + {"self": [1, 64, 480, 640], "other": [480, 1]}, + {"self": [1, 64, 480, 640], "other": [640]}, + {"self": [1, 64, 5120], "other": [1, 64, 5120]}, + {"self": [1, 64, 60, 80], "other": [1, 1, 60, 80]}, + {"self": [1, 64, 60, 80], "other": [60, 1]}, + {"self": [1, 64, 60, 80], "other": [80]}, + {"self": [1, 672, 1, 1], "other": [1, 672, 10, 10]}, + {"self": [1, 672, 1, 1], "other": [1, 672, 14, 14]}, + {"self": [1, 672, 1, 1], "other": [1, 672, 20, 20]}, + {"self": [1, 672, 1, 1], "other": [1, 672, 7, 7]}, + {"self": [1, 672, 14, 14], "other": [1, 672, 1, 1]}, + {"self": [1, 672, 14, 14], "other": [1, 672, 14, 14]}, + {"self": [1, 672, 7, 7], "other": [1, 672, 1, 1]}, + {"self": [1, 696, 1, 1], "other": [1, 696, 28, 28]}, + {"self": [1, 7, 3072], "other": [1, 7, 3072]}, + {"self": [1, 71, 7, 64], "other": [1, 1, 7, 64]}, + {"self": [1, 72, 1, 1], "other": [1, 72, 28, 28]}, + {"self": [1, 72, 1, 1], "other": [1, 72, 40, 40]}, + {"self": [1, 72, 1, 1], "other": [1, 72, 56, 56]}, + {"self": [1, 72, 28, 28], "other": [1, 72, 1, 1]}, + {"self": [1, 72, 56, 56], "other": [1, 72, 56, 56]}, + {"self": [1, 7392, 1, 1], "other": [1, 7392, 12, 12]}, + {"self": [1, 768, 14, 14], "other": [1, 768, 1, 1]}, + {"self": [1, 784, 1, 1], "other": [1, 784, 7, 7]}, + {"self": [1, 888, 1, 1], "other": [1, 888, 7, 7]}, + {"self": [1, 896, 1, 1], "other": [1, 896, 14, 14]}, + {"self": [1, 9, 128], "other": [1, 9, 128]}, + {"self": [1, 9, 16384], "other": [1, 9, 16384]}, + {"self": [1, 9, 3072], "other": [1, 9, 3072]}, + {"self": [1, 9, 4096], "other": [1, 9, 4096]}, + {"self": [1, 9, 8192], "other": [1, 9, 8192]}, + {"self": [1, 96, 1, 1], "other": [1, 96, 14, 14]}, + {"self": [1, 960, 1, 1], "other": [1, 960, 7, 7]}, + {"self": [1, 960, 7, 7], "other": [1, 960, 1, 1]}, + {"self": [1, 960, 7, 7], "other": [1, 960, 7, 7]}, + {"self": [100], "other": []}, + {"self": [1024], "other": [1, 1, 1024]}, + {"self": [1024], "other": [1, 10, 1024]}, + {"self": [1024], "other": [1, 197, 1024]}, + {"self": [12], "other": []}, + {"self": [136], "other": []}, + {"self": [13], "other": []}, + {"self": [16, 1], "other": [1, 1, 32]}, + {"self": [16, 6, 64, 64], "other": [6, 1, 1]}, + {"self": [16, 8, 64, 64], "other": [8, 1, 1]}, + {"self": [17], "other": []}, + {"self": [1], "other": [1]}, + {"self": [2, 1], "other": []}, + {"self": [2, 7, 2048], "other": [2, 7, 2048]}, + {"self": [25], "other": []}, + {"self": [300], "other": []}, + {"self": [3234, 1], "other": [3234, 1]}, + {"self": [3234, 2], "other": [2]}, + {"self": [34], "other": []}, + {"self": [4, 12, 64, 64], "other": [12, 1, 1]}, + {"self": [4, 16, 64, 64], "other": [16, 1, 1]}, + {"self": [50], "other": []}, + {"self": [512], "other": [1, 1, 512]}, + {"self": [512], "other": [1, 10, 512]}, + {"self": [512], "other": [1, 15, 512]}, + {"self": [64, 3, 64, 64], "other": [3, 1, 1]}, + {"self": [64, 4, 64, 64], "other": [4, 1, 1]}, + {"self": [68], "other": []}, + {"self": [768], "other": [1, 1, 768]}, + {"self": [768], "other": [1, 10, 768]}, + {"self": [768], "other": [1, 197, 768]}, + {"self": [7], "other": []}, + {"self": [8732, 1], "other": [8732, 1]}, + {"self": [8732, 2], "other": [2]}, + {"self": [9], "other": []}, + {"self": [], "other": [0, 1]}, + {"self": [], "other": [1, 1, 768]}, + {"self": [], "other": [1, 24, 768]}, + {"self": [], "other": [3234, 1]}, + {"self": [], "other": [8732, 1]}, + ], + "input_a_dtype": [ttnn.bfloat16], + "input_b_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_b_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "input_b_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_b_dtype, + input_a_layout, + input_b_layout, + input_a_memory_config, + input_b_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape["self"]) + + if isinstance(input_shape["other"], list): + torch_input_tensor_b = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_b_dtype + )(input_shape["other"]) + else: + torch_input_tensor_b = torch.tensor(input_shape["other"], dtype=torch.float32) + # torch_input_tensor_b = input_shape["other"] + + golden_function = ttnn.get_golden_function(ttnn.mul) + torch_output_tensor = golden_function(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + # if isinstance(input_shape["other"], list): + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, + dtype=input_b_dtype, + layout=input_b_layout, + device=device, + memory_config=input_b_memory_config, + ) + # else: + # input_tensor_b = input_shape["other"] + + start_time = start_measuring_time() + result = ttnn.mul(input_tensor_a, input_tensor_b) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, pcc=0.99), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_scalar_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_scalar_pytorch2.py new file mode 100644 index 00000000000..5ac35667ba3 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/binary/remainder/remainder_scalar_pytorch2.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1], + ], + "scalar": [7], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + scalar, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.remainder) + torch_output_tensor = golden_function(torch_input_tensor_a, scalar) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.remainder(input_tensor_a, scalar, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/abs/abs_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/abs/abs_pytorch2.py new file mode 100644 index 00000000000..0e262cc9980 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/abs/abs_pytorch2.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [10, 10], + [15, 15], + ], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.abs) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.abs(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/bitwise/bitwise_not_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/bitwise/bitwise_not_pytorch2.py new file mode 100644 index 00000000000..91f5e28af58 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/bitwise/bitwise_not_pytorch2.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [[1, 23, 40]], + "input_a_dtype": [ttnn.int32], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +def mesh_device_fixture(): + device = ttnn.open_device(device_id=0) + assert ttnn.device.is_wormhole_b0(device), "This op is available for Wormhole_B0 only" + yield (device, "Wormhole_B0") + ttnn.close_device(device) + del device + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-2147483647, high=2147483648, dtype=torch.int64), input_a_dtype + )(input_shape) + + torch_input_tensor_a = torch.full(size=input_shape, fill_value=-2147483647).to(torch.int32) + + torch_output_tensor = torch.bitwise_not(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.bitwise_not(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/ceil/ceil_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/ceil/ceil_pytorch2.py new file mode 100644 index 00000000000..03fa811c4a1 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/ceil/ceil_pytorch2.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random, is_wormhole_b0 + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [[1066], [120], [128], [160], [240], [300], [30], [320], [40], [480], [60], [640], [800], [80]], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +def mesh_device_fixture(): + device = ttnn.open_device(device_id=0) + assert ttnn.device.is_wormhole_b0(device), "This op is available for Wormhole_B0 only" + yield (device, "Wormhole_B0") + ttnn.close_device(device) + del device + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.ceil) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.ceil(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/cos/cos_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/cos/cos_pytorch2.py new file mode 100644 index 00000000000..9cb49e3f023 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/cos/cos_pytorch2.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [[1, 160], [1, 23, 40, 64]], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=0, high=6.283185307179586, dtype=torch.float16), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.cos) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.cos(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/elu/elu_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/elu/elu_pytorch2.py new file mode 100644 index 00000000000..b69687ef84a --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/elu/elu_pytorch2.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 128, 28, 28], + ], + "alpha": [1.0], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + alpha, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.elu) + torch_output_tensor = golden_function(torch_input_tensor_a, alpha=alpha) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.elu(input_tensor_a, alpha=alpha, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/exp/exp_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/exp/exp_pytorch2.py new file mode 100644 index 00000000000..90f50db4656 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/exp/exp_pytorch2.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [0, 1], + [12, 1, 1], + [16, 1, 1], + [160], + [24, 1, 1], + [3, 1, 1], + [32, 1, 1], + [3234, 1], + [4, 1, 1], + [6, 1, 1], + [8, 1, 1], + [8732, 1], + [], + ], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-10, high=10, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.exp) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.exp(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/floor/floor_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/floor/floor_pytorch2.py new file mode 100644 index 00000000000..e3328020f49 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/floor/floor_pytorch2.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random, is_wormhole_b0 + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1" and "suite_2") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [[1, 1, 1, 42], [1, 1, 32, 1]], + "input_a_dtype": [ttnn.bfloat16, ttnn.bfloat8_b], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +def mesh_device_fixture(): + device = ttnn.open_device(device_id=0) + assert ttnn.device.is_wormhole_b0(device), "This op is available for Wormhole_B0 only" + yield (device, "Wormhole_B0") + ttnn.close_device(device) + del device + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a mesh_device_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.floor) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + output_tensor = ttnn.floor(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(output_tensor) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/gelu/gelu_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/gelu/gelu_pytorch2.py new file mode 100644 index 00000000000..595c6613e0d --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/gelu/gelu_pytorch2.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1, 3072], + [1, 10, 3072], + [1, 10, 768], + [1, 1024, 2560], + [1, 1024, 512], + [1, 1024, 640], + [1, 1200, 1280], + [1, 1370, 5120], + [1, 14, 14, 1536], + [1, 14, 14, 2048], + [1, 1445, 768], + [1, 1500, 3072], + [1, 1536], + [1, 16, 16, 1536], + [1, 16, 16, 2048], + [1, 16, 3072], + [1, 16384, 128], + [1, 19, 4096], + [1, 19200, 256], + [1, 196, 3072], + [1, 197, 3072], + [1, 197, 4096], + [1, 201, 3072], + [1, 2048, 768], + [1, 24, 3072], + [1, 25, 3072], + [1, 256, 1024], + [1, 256, 1280], + [1, 256, 256], + [1, 256, 4096], + [1, 256, 5120], + [1, 28, 28, 1024], + [1, 28, 28, 768], + [1, 300, 2048], + [1, 32, 32, 1024], + [1, 32, 32, 768], + [1, 4, 3072], + [1, 4096, 1280], + [1, 4096, 256], + [1, 4800, 512], + [1, 50, 3072], + [1, 50, 4096], + [1, 56, 56, 384], + [1, 56, 56, 512], + [1, 64, 5120], + [1, 64, 64, 384], + [1, 64, 64, 512], + [1, 7, 18176], + [1, 7, 7, 3072], + [1, 7, 7, 4096], + [1, 768, 1500], + [1, 768, 3000], + [1, 768, 384], + [1, 8, 8, 3072], + [1, 8, 8, 4096], + ], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float16), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.gelu) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.gelu(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/hardsigmoid/hardsigmoid_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/hardsigmoid/hardsigmoid_pytorch2.py new file mode 100644 index 00000000000..cd9266e2bc7 --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/hardsigmoid/hardsigmoid_pytorch2.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1024, 1, 1], + [1, 120, 1, 1], + [1, 144, 1, 1], + [1, 16, 1, 1], + [1, 240, 1, 1], + [1, 256, 1, 1], + [1, 288, 1, 1], + [1, 480, 1, 1], + [1, 512, 1, 1], + [1, 576, 1, 1], + [1, 672, 1, 1], + [1, 72, 1, 1], + [1, 768, 1, 1], + [1, 96, 1, 1], + [1, 960, 1, 1], + ], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float16), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.hardsigmoid) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.hardsigmoid(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/leaky_relu/leaky_relu_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/leaky_relu/leaky_relu_pytorch2.py new file mode 100644 index 00000000000..5f0078fb17b --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/leaky_relu/leaky_relu_pytorch2.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [ + [1, 1024, 16, 16], + [1, 128, 128, 128], + [1, 128, 1536], + [1, 128, 32, 32], + [1, 128, 64, 64], + [1, 256, 16, 16], + [1, 256, 32, 32], + [1, 256, 384], + [1, 256, 64, 64], + [1, 32, 24576], + [1, 32, 24576], + [1, 32, 256, 256], + [1, 32, 512, 512], + [1, 512, 16, 16], + [1, 512, 32, 32], + [1, 512, 96], + [1, 64, 128, 128], + [1, 64, 256, 256], + [1, 64, 6144], + ], + "negative_slope": [0.1], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + negative_slope, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=-100, high=100, dtype=torch.float16), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.leaky_relu) + torch_output_tensor = golden_function(torch_input_tensor_a, negative_slope=negative_slope) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.leaky_relu(input_tensor_a, negative_slope=negative_slope, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/sweep_framework/sweeps/eltwise/unary/log/log_pytorch2.py b/tests/sweep_framework/sweeps/eltwise/unary/log/log_pytorch2.py new file mode 100644 index 00000000000..e53c4b11c7f --- /dev/null +++ b/tests/sweep_framework/sweeps/eltwise/unary/log/log_pytorch2.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple +from functools import partial + +import torch +import random +import ttnn +from tests.sweep_framework.utils import gen_shapes +from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt + +from tests.ttnn.utils_for_testing import check_with_pcc, start_measuring_time, stop_measuring_time +from models.utility_functions import torch_random + +# Override the default timeout in seconds for hang detection. +TIMEOUT = 30 + +random.seed(0) + +# Parameters provided to the test vector generator are defined here. +# They are defined as dict-type suites that contain the arguments to the run function as keys, and lists of possible inputs as values. +# Each suite has a key name (in this case "suite_1") which will associate the test vectors to this specific suite of inputs. +# Developers can create their own generator functions and pass them to the parameters as inputs. +parameters = { + "nightly": { + "input_shape": [[1, 1], [10, 10], [15, 15], [17, 17], [2, 2]], + "input_a_dtype": [ttnn.bfloat16], + "input_a_layout": [ttnn.TILE_LAYOUT], + "input_a_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + "output_memory_config": [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG], + }, +} + + +# This is the run instructions for the test, defined by the developer. +# The run function must take the above-defined parameters as inputs. +# The runner will call this run function with each test vector, and the returned results from this function will be stored. +# If you defined a device_mesh_fixture above, the object you yielded will be passed into this function as 'device'. Otherwise, it will be the default ttnn device opened by the infra. +def run( + input_shape, + input_a_dtype, + input_a_layout, + input_a_memory_config, + output_memory_config, + *, + device, +) -> list: + data_seed = random.randint(0, 20000000) + torch.manual_seed(data_seed) + + torch_input_tensor_a = gen_func_with_cast_tt( + partial(torch_random, low=1, high=100, dtype=torch.float32), input_a_dtype + )(input_shape) + + golden_function = ttnn.get_golden_function(ttnn.log) + torch_output_tensor = golden_function(torch_input_tensor_a) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, + dtype=input_a_dtype, + layout=input_a_layout, + device=device, + memory_config=input_a_memory_config, + ) + + start_time = start_measuring_time() + result = ttnn.log(input_tensor_a, memory_config=output_memory_config) + output_tensor = ttnn.to_torch(result) + e2e_perf = stop_measuring_time(start_time) + + return [check_with_pcc(torch_output_tensor, output_tensor, 0.999), e2e_perf] diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_unary_ops_ttnn.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_unary_ops_ttnn.py index bda1f32c355..7cf8ea27cfd 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_unary_ops_ttnn.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_unary_ops_ttnn.py @@ -463,34 +463,13 @@ def test_unary_gelu_ttnn(input_shapes, fast_and_approx, device): (torch.Size([1, 3, 320, 384])), ), ) -@pytest.mark.parametrize("negative_slope", [1.0, 5.0, 10.0]) +@pytest.mark.parametrize("negative_slope", [1.0, 5.0, 10.0, 0.1]) def test_unary_leaky_relu_ttnn(input_shapes, negative_slope, device): in_data, input_tensor = data_gen_with_range(input_shapes, -10, 10, device) _, output_tensor = data_gen_with_range(input_shapes, -1, 1, device) cq_id = 0 - ttnn.leaky_relu(input_tensor, slope=negative_slope, output_tensor=output_tensor, queue_id=cq_id) - golden_tensor = torch.nn.functional.leaky_relu(in_data, negative_slope) - - comp_pass = compare_pcc([output_tensor], [golden_tensor]) - assert comp_pass - - -@pytest.mark.parametrize( - "input_shapes", - ( - (torch.Size([1, 1, 32, 32])), - (torch.Size([1, 1, 320, 384])), - (torch.Size([1, 3, 320, 384])), - ), -) -@pytest.mark.parametrize("negative_slope", [1.0, 5.0, 10.0]) -def test_unary_leaky_relu_ttnn(input_shapes, negative_slope, device): - in_data, input_tensor = data_gen_with_range(input_shapes, -10, 10, device) - _, output_tensor = data_gen_with_range(input_shapes, -1, 1, device) - - cq_id = 0 - ttnn.leaky_relu(input_tensor, slope=negative_slope, output_tensor=output_tensor, queue_id=cq_id) + ttnn.leaky_relu(input_tensor, negative_slope=negative_slope, output_tensor=output_tensor, queue_id=cq_id) golden_tensor = torch.nn.functional.leaky_relu(in_data, negative_slope) comp_pass = compare_pcc([output_tensor], [golden_tensor]) diff --git a/tests/ttnn/unit_tests/operations/test_activation.py b/tests/ttnn/unit_tests/operations/eltwise/test_activation.py similarity index 88% rename from tests/ttnn/unit_tests/operations/test_activation.py rename to tests/ttnn/unit_tests/operations/eltwise/test_activation.py index 56eb3293c33..2f3f4b41865 100644 --- a/tests/ttnn/unit_tests/operations/test_activation.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_activation.py @@ -186,6 +186,38 @@ def torch_prelu(x, *args, weight, **kwargs): return result +def run_activation_test_elu(device, h, w, scalar, ttnn_function, pcc=0.99): + torch.manual_seed(0) + + torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16) + golden_function = ttnn.get_golden_function(ttnn_function) + torch_output_tensor = golden_function(torch_input_tensor_a, alpha=scalar) + + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + + output_tensor = ttnn_function(input_tensor_a, alpha=scalar) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = ttnn.from_device(output_tensor) + output_tensor = ttnn.to_torch(output_tensor) + assert_with_pcc(torch_output_tensor, output_tensor, pcc) + + +def run_activation_test_leaky_relu(device, h, w, scalar, ttnn_function, pcc=0.99): + torch.manual_seed(0) + + torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16) + golden_function = ttnn.get_golden_function(ttnn_function) + torch_output_tensor = golden_function(torch_input_tensor_a, negative_slope=scalar) + + input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device) + + output_tensor = ttnn_function(input_tensor_a, negative_slope=scalar) + output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT) + output_tensor = ttnn.from_device(output_tensor) + output_tensor = ttnn.to_torch(output_tensor) + assert_with_pcc(torch_output_tensor, output_tensor, pcc) + + def run_activation_test_scalarB(device, h, w, scalar, ttnn_function, pcc=0.99): torch.manual_seed(0) @@ -222,7 +254,7 @@ def run_activation_test_scalarB_key(device, h, w, value, ttnn_function, pcc=0.99 @pytest.mark.parametrize("h", [64]) @pytest.mark.parametrize("w", [128]) def test_scalarB_elu(device, h, w, scalar): - run_activation_test_scalarB(device, h, w, scalar, ttnn.elu) + run_activation_test_elu(device, h, w, scalar, ttnn.elu) @pytest.mark.parametrize("alpha", [1, 2.5, 5.0]) @@ -268,11 +300,11 @@ def test_scalarB_heaviside(device, h, w, value): run_activation_test_scalarB_key(device, h, w, value, ttnn.heaviside) -@pytest.mark.parametrize("scalar", [-0.5, 0, 0.5]) +@pytest.mark.parametrize("scalar", [-0.5, 0, 0.1, 0.01, 0.5]) @pytest.mark.parametrize("h", [64]) @pytest.mark.parametrize("w", [128]) def test_scalarB_leaky_relu(device, h, w, scalar): - run_activation_test_scalarB(device, h, w, scalar, ttnn.leaky_relu) + run_activation_test_leaky_relu(device, h, w, scalar, ttnn.leaky_relu) @pytest.mark.parametrize("weight", [-0.5, 1.0, 0.5]) diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp index 763767c0bb7..4a46a796570 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/unary_pybind.hpp @@ -1449,7 +1449,7 @@ void py_module(py::module& module) { )doc"); detail::bind_unary_operation_with_float_parameter(module, ttnn::heaviside, "value", "The value parameter for the Heaviside function", ""); - detail::bind_unary_operation_with_float_parameter(module, ttnn::leaky_relu, "slope", "The slope parameter for the Leaky ReLU function", ""); + detail::bind_unary_operation_with_float_parameter(module, ttnn::leaky_relu, "negative_slope", "The slope parameter for the Leaky ReLU function", ""); detail::bind_unary_operation_with_float_parameter(module, ttnn::relu_max, "upper_limit", "The max value for ReLU function", "This function caps off the input to a max value and a min value of 0"); detail::bind_unary_operation_with_float_parameter(module, ttnn::relu_min, "lower_limit", "The min value for ReLU function", "This will carry out ReLU operation at min value instead of the standard 0"); diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/unary.py index 48d0d9eca5f..c282c2b21b2 100644 --- a/ttnn/ttnn/operations/unary.py +++ b/ttnn/ttnn/operations/unary.py @@ -66,8 +66,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_): "gelu": torch.nn.functional.gelu, "rsqrt": torch.rsqrt, # Unaries with float parameter - "elu": torch.nn.functional.elu, - "leaky_relu": torch.nn.functional.leaky_relu, # "prelu": torch_prelu, # Alias for leaky_relu. TODO(#8544): implement PReLU properly # Other unaries (composite operations) "softplus": torch.nn.functional.softplus, @@ -151,8 +149,6 @@ def _golden_function(input_tensor: ttnn.Tensor, **_): ttnn.gelu, ttnn.rsqrt, # Unaries with float parameter - ttnn.elu, - ttnn.leaky_relu, # ttnn.prelu, # Alias for leaky_relu. TODO(#8544): implement PReLU properly # Unaries using op_chain ttnn.log_sigmoid, @@ -236,6 +232,24 @@ def _golden_function_pow(input_tensor_a, exponent, *args, **kwargs): ttnn.attach_golden_function(ttnn.pow, golden_function=_golden_function_pow) +def _golden_function_elu(input_tensor_a, *args, alpha=1.0, **kwargs): + import torch + + return torch.nn.functional.elu(input_tensor_a, alpha=alpha) + + +ttnn.attach_golden_function(ttnn.elu, golden_function=_golden_function_elu) + + +def _golden_function_leaky_relu(input_tensor_a, *args, negative_slope=0.01, **kwargs): + import torch + + return torch.nn.functional.leaky_relu(input_tensor_a, negative_slope=negative_slope) + + +ttnn.attach_golden_function(ttnn.leaky_relu, golden_function=_golden_function_leaky_relu) + + def _golden_function_relu_min(input_tensor_a, *args, lower_limit, **kwargs): import torch From 9ffd5f7cfda3eeb7426752d0b75b089f30a88662 Mon Sep 17 00:00:00 2001 From: Kalaivani Baskar <156762498+KalaivaniMCW@users.noreply.github.com> Date: Thu, 10 Oct 2024 09:23:23 +0530 Subject: [PATCH 53/58] #13630: Move eltwise test files and operations under eltwise folder (#13631) --- tests/end_to_end_tests/test_ttnn.py | 2 +- .../backward_complex_utility_funcs.py | 0 .../backward/complex_ops/test_backward_abs.py | 4 ++-- .../complex_ops/test_backward_angle.py | 4 ++-- .../complex_ops/test_backward_complex_add.py | 2 +- .../complex_ops/test_backward_complex_div.py | 2 +- .../complex_ops/test_backward_complex_mul.py | 2 +- .../complex_ops/test_backward_complex_sub.py | 2 +- .../complex_ops/test_backward_conj.py | 4 ++-- .../complex_ops/test_backward_imag.py | 2 +- .../complex_ops/test_backward_polar.py | 2 +- .../complex_ops/test_backward_real.py | 2 +- .../complex_ops/test_backward_recip.py | 4 ++-- .../backward/test_backward_abs.py | 2 +- .../backward/test_backward_acos.py | 2 +- .../backward/test_backward_acosh.py | 2 +- .../backward/test_backward_add.py | 2 +- .../backward/test_backward_addalpha.py | 2 +- .../backward/test_backward_addcdiv.py | 2 +- .../backward/test_backward_addcmul.py | 2 +- .../backward/test_backward_asin.py | 2 +- .../backward/test_backward_asinh.py | 2 +- .../backward/test_backward_assign.py | 2 +- .../backward/test_backward_atan.py | 2 +- .../backward/test_backward_atan2.py | 2 +- .../backward/test_backward_atanh.py | 2 +- .../backward/test_backward_bias_gelu.py | 2 +- .../backward/test_backward_ceil.py | 2 +- .../backward/test_backward_celu.py | 2 +- .../backward/test_backward_clamp.py | 2 +- .../backward/test_backward_concat.py | 2 +- .../backward/test_backward_cos.py | 2 +- .../backward/test_backward_cosh.py | 2 +- .../backward/test_backward_deg2rad.py | 2 +- .../backward/test_backward_digamma.py | 2 +- .../backward/test_backward_div.py | 6 +++++- .../backward/test_backward_div_no_nan.py | 2 +- .../backward/test_backward_elu.py | 2 +- .../backward/test_backward_embedding.py | 0 .../backward/test_backward_erf.py | 2 +- .../backward/test_backward_erfc.py | 2 +- .../backward/test_backward_erfinv.py | 2 +- .../backward/test_backward_exp.py | 2 +- .../backward/test_backward_exp2.py | 2 +- .../backward/test_backward_expm1.py | 2 +- .../backward/test_backward_fill.py | 2 +- .../backward/test_backward_fill_zero.py | 2 +- .../backward/test_backward_floor.py | 2 +- .../backward/test_backward_fmod.py | 2 +- .../backward/test_backward_frac.py | 2 +- .../backward/test_backward_gelu.py | 2 +- .../backward/test_backward_hardshrink.py | 2 +- .../backward/test_backward_hardsigmoid.py | 2 +- .../backward/test_backward_hardswish.py | 2 +- .../backward/test_backward_hardtanh.py | 2 +- .../backward/test_backward_hypot.py | 2 +- .../backward/test_backward_i0.py | 2 +- .../backward/test_backward_ldexp.py | 2 +- .../backward/test_backward_leaky_relu.py | 2 +- .../backward/test_backward_lerp.py | 2 +- .../backward/test_backward_lgamma.py | 2 +- .../backward/test_backward_log.py | 2 +- .../backward/test_backward_log10.py | 2 +- .../backward/test_backward_log1p.py | 2 +- .../backward/test_backward_log2.py | 2 +- .../backward/test_backward_log_sigmoid.py | 2 +- .../backward/test_backward_logaddexp.py | 2 +- .../backward/test_backward_logaddexp2.py | 2 +- .../backward/test_backward_logit.py | 2 +- .../backward/test_backward_logiteps.py | 2 +- .../backward/test_backward_max.py | 2 +- .../backward/test_backward_min.py | 2 +- .../backward/test_backward_mul.py | 2 +- .../backward/test_backward_mvlgamma.py | 2 +- .../backward/test_backward_neg.py | 2 +- .../backward/test_backward_polygamma.py | 6 +++++- .../backward/test_backward_pow.py | 2 +- .../backward/test_backward_prod.py | 2 +- .../backward/test_backward_rad2deg.py | 2 +- .../backward/test_backward_rdiv.py | 2 +- .../backward/test_backward_reciprocal.py | 2 +- .../backward/test_backward_relu.py | 2 +- .../backward/test_backward_relu6.py | 2 +- .../backward/test_backward_remainder.py | 2 +- .../backward/test_backward_repeat.py | 2 +- .../backward/test_backward_round.py | 2 +- .../backward/test_backward_rpow.py | 2 +- .../backward/test_backward_rsqrt.py | 2 +- .../backward/test_backward_rsub.py | 2 +- .../backward/test_backward_selu.py | 2 +- .../backward/test_backward_sigmoid.py | 2 +- .../backward/test_backward_sign.py | 2 +- .../backward/test_backward_silu.py | 2 +- .../backward/test_backward_sin.py | 2 +- .../backward/test_backward_sinh.py | 2 +- .../backward/test_backward_softplus.py | 2 +- .../backward/test_backward_softshrink.py | 2 +- .../backward/test_backward_softsign.py | 2 +- .../backward/test_backward_sqrt.py | 2 +- .../backward/test_backward_square.py | 2 +- .../test_backward_squared_difference.py | 2 +- .../backward/test_backward_sub.py | 2 +- .../backward/test_backward_subalpha.py | 2 +- .../backward/test_backward_tan.py | 2 +- .../backward/test_backward_tanh.py | 2 +- .../backward/test_backward_tanhshrink.py | 2 +- .../backward/test_backward_threshold.py | 2 +- .../backward/test_backward_trunc.py | 2 +- .../backward/test_backward_where.py | 2 +- .../backward/test_backward_xlogy.py | 2 +- .../{ => eltwise}/backward/utility_funcs.py | 0 .../complex/test_complex_conj.py | 4 ++-- .../{ => eltwise}/complex/utility_funcs.py | 0 .../operations/{ => eltwise}/test_add.py | 0 .../operations/{ => eltwise}/test_backward.py | 2 +- .../{ => eltwise}/test_binary_composite.py | 2 +- .../{ => eltwise}/test_binary_scalar.py | 2 +- .../operations/{ => eltwise}/test_complex.py | 0 .../{ => eltwise}/test_complex_tensor.py | 4 ++-- .../{ => eltwise}/test_composite.py | 2 +- .../{ => eltwise}/test_elt_binary.py | 0 .../test_eltwise_logical_and_.py | 0 .../{ => eltwise}/test_eltwise_typecast.py | 0 .../operations/{ => eltwise}/test_inplace.py | 0 .../operations/{ => eltwise}/test_math.py | 0 .../{ => eltwise}/test_math_binary.py | 0 .../operations/{ => eltwise}/test_mul.py | 0 .../operations/{ => eltwise}/test_pow.py | 2 +- .../{ => eltwise}/test_relational.py | 0 .../operations/{ => eltwise}/test_sub.py | 0 .../operations/{ => eltwise}/test_ternary.py | 0 .../{ => eltwise}/test_ternary_composite.py | 6 +++++- .../operations/{ => eltwise}/test_unary.py | 2 +- ttnn/cpp/pybind11/operations/__init__.hpp | 21 ++++++++++--------- ttnn/ttnn/__init__.py | 1 + ttnn/ttnn/operations/eltwise/__init__.py | 12 +++++++++++ ttnn/ttnn/operations/{ => eltwise}/binary.py | 0 .../{ => eltwise}/binary_backward.py | 2 +- .../{ => eltwise}/binary_complex.py | 0 .../{ => eltwise}/complex_binary_backward.py | 0 .../{ => eltwise}/complex_unary_backward.py | 0 ttnn/ttnn/operations/{ => eltwise}/ternary.py | 0 .../{ => eltwise}/ternary_backward.py | 0 ttnn/ttnn/operations/{ => eltwise}/unary.py | 0 .../{ => eltwise}/unary_backward.py | 0 .../operations/{ => eltwise}/unary_complex.py | 0 146 files changed, 160 insertions(+), 134 deletions(-) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/backward_complex_utility_funcs.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_abs.py (93%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_angle.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_complex_add.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_complex_div.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_complex_mul.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_complex_sub.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_conj.py (90%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_imag.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_polar.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_real.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/complex_ops/test_backward_recip.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_abs.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_acos.py (87%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_acosh.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_add.py (96%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_addalpha.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_addcdiv.py (91%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_addcmul.py (91%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_asin.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_asinh.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_assign.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_atan.py (91%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_atan2.py (93%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_atanh.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_bias_gelu.py (96%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_ceil.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_celu.py (93%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_clamp.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_concat.py (98%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_cos.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_cosh.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_deg2rad.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_digamma.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_div.py (98%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_div_no_nan.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_elu.py (93%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_embedding.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_erf.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_erfc.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_erfinv.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_exp.py (93%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_exp2.py (87%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_expm1.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_fill.py (96%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_fill_zero.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_floor.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_fmod.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_frac.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_gelu.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_hardshrink.py (93%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_hardsigmoid.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_hardswish.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_hardtanh.py (93%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_hypot.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_i0.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_ldexp.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_leaky_relu.py (93%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_lerp.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_lgamma.py (91%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_log.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_log10.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_log1p.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_log2.py (91%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_log_sigmoid.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_logaddexp.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_logaddexp2.py (92%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_logit.py (91%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_logiteps.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_max.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_min.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_mul.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_mvlgamma.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_neg.py (96%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_polygamma.py (96%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_pow.py (98%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_prod.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_rad2deg.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_rdiv.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_reciprocal.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_relu.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_relu6.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_remainder.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_repeat.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_round.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_rpow.py (92%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_rsqrt.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_rsub.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_selu.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_sigmoid.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_sign.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_silu.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_sin.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_sinh.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_softplus.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_softshrink.py (95%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_softsign.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_sqrt.py (93%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_square.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_squared_difference.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_sub.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_subalpha.py (96%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_tan.py (92%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_tanh.py (94%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_tanhshrink.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_threshold.py (90%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_trunc.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_where.py (96%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/test_backward_xlogy.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/backward/utility_funcs.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/complex/test_complex_conj.py (89%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/complex/utility_funcs.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_add.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_backward.py (98%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_binary_composite.py (99%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_binary_scalar.py (90%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_complex.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_complex_tensor.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_composite.py (99%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_elt_binary.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_eltwise_logical_and_.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_eltwise_typecast.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_inplace.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_math.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_math_binary.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_mul.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_pow.py (88%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_relational.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_sub.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_ternary.py (100%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_ternary_composite.py (97%) rename tests/ttnn/unit_tests/operations/{ => eltwise}/test_unary.py (99%) create mode 100644 ttnn/ttnn/operations/eltwise/__init__.py rename ttnn/ttnn/operations/{ => eltwise}/binary.py (100%) rename ttnn/ttnn/operations/{ => eltwise}/binary_backward.py (99%) rename ttnn/ttnn/operations/{ => eltwise}/binary_complex.py (100%) rename ttnn/ttnn/operations/{ => eltwise}/complex_binary_backward.py (100%) rename ttnn/ttnn/operations/{ => eltwise}/complex_unary_backward.py (100%) rename ttnn/ttnn/operations/{ => eltwise}/ternary.py (100%) rename ttnn/ttnn/operations/{ => eltwise}/ternary_backward.py (100%) rename ttnn/ttnn/operations/{ => eltwise}/unary.py (100%) rename ttnn/ttnn/operations/{ => eltwise}/unary_backward.py (100%) rename ttnn/ttnn/operations/{ => eltwise}/unary_complex.py (100%) diff --git a/tests/end_to_end_tests/test_ttnn.py b/tests/end_to_end_tests/test_ttnn.py index 4166c7d7558..5d1e2dd2401 100644 --- a/tests/end_to_end_tests/test_ttnn.py +++ b/tests/end_to_end_tests/test_ttnn.py @@ -7,7 +7,7 @@ import ttnn import torch -import ttnn.operations.binary +import ttnn.operations.eltwise.binary @pytest.mark.eager_host_side diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/backward_complex_utility_funcs.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/backward_complex_utility_funcs.py similarity index 100% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/backward_complex_utility_funcs.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/backward_complex_utility_funcs.py diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_abs.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_abs.py similarity index 93% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_abs.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_abs.py index 3305f9ef35e..4da79b2f824 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_abs.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_abs.py @@ -11,11 +11,11 @@ import pytest import ttnn from loguru import logger -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_angle.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_angle.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_angle.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_angle.py index 86c8b1156c5..429bb518d69 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_angle.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_angle.py @@ -11,11 +11,11 @@ import pytest import ttnn from loguru import logger -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_add.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_add.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_add.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_add.py index 164d78baf73..bcf1fd9861c 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_add.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_add.py @@ -15,7 +15,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_div.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_div.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_div.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_div.py index 3aa19df1c67..9572487d078 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_div.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_div.py @@ -19,7 +19,7 @@ is_wormhole_b0, is_blackhole, ) -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_mul.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_mul.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_mul.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_mul.py index f5e88bc9c4f..588bbf32f00 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_mul.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_mul.py @@ -16,7 +16,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_sub.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_sub.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_sub.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_sub.py index ee43e94e8c0..60fb342ae1c 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_complex_sub.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_complex_sub.py @@ -10,7 +10,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_conj.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_conj.py similarity index 90% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_conj.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_conj.py index 25ce48408fb..3845ae09789 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_conj.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_conj.py @@ -11,11 +11,11 @@ import pytest import ttnn from loguru import logger -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_imag.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_imag.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_imag.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_imag.py index e006f5d82d9..10b4ef3a74b 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_imag.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_imag.py @@ -15,7 +15,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_polar.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_polar.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_polar.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_polar.py index 9c1e910ccd5..36ae664a3cd 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_polar.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_polar.py @@ -16,7 +16,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_real.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_real.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_real.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_real.py index 36e3e680e98..89a94d02360 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_real.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_real.py @@ -15,7 +15,7 @@ from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_recip.py b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_recip.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_recip.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_recip.py index ccfa6c64215..3a5e8cba83d 100644 --- a/tests/ttnn/unit_tests/operations/backward/complex_ops/test_backward_recip.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/complex_ops/test_backward_recip.py @@ -11,11 +11,11 @@ import pytest import ttnn from loguru import logger -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0, skip_for_wormhole_b0, is_blackhole -from tests.ttnn.unit_tests.operations.backward.complex_ops.backward_complex_utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.complex_ops.backward_complex_utility_funcs import ( Complex, convert_to_torch_tensor, random_complex_tensor, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_abs.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_abs.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_abs.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_abs.py index 0e4663b7f52..630242e0dc2 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_abs.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_abs.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_acos.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_acos.py similarity index 87% rename from tests/ttnn/unit_tests/operations/backward/test_backward_acos.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_acos.py index e83d92d594d..d4e4a190747 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_acos.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_acos.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_acosh.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_acosh.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_acosh.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_acosh.py index 08e9da6b616..137918090ef 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_acosh.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_acosh.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_add.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_add.py similarity index 96% rename from tests/ttnn/unit_tests/operations/backward/test_backward_add.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_add.py index 19029f0ae6f..835ac41dadd 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_add.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_add.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_addalpha.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_addalpha.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/test_backward_addalpha.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_addalpha.py index 7ae316a5297..048d087cf6e 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_addalpha.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_addalpha.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_addcdiv.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_addcdiv.py similarity index 91% rename from tests/ttnn/unit_tests/operations/backward/test_backward_addcdiv.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_addcdiv.py index 70007b0bcfe..2df72410854 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_addcdiv.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_addcdiv.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_addcmul.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_addcmul.py similarity index 91% rename from tests/ttnn/unit_tests/operations/backward/test_backward_addcmul.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_addcmul.py index 8252447408d..cf97d53a7d4 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_addcmul.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_addcmul.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_asin.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_asin.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_asin.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_asin.py index e23e5eac232..39e48022133 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_asin.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_asin.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_asinh.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_asinh.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_asinh.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_asinh.py index eb231a9a450..023071ad3b2 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_asinh.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_asinh.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_assign.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_assign.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/test_backward_assign.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_assign.py index 48d9115fa36..3115c447a1b 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_assign.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_assign.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_atan.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_atan.py similarity index 91% rename from tests/ttnn/unit_tests/operations/backward/test_backward_atan.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_atan.py index 686bf6af4d9..91aae353528 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_atan.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_atan.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, compare_pcc, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_atan2.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_atan2.py similarity index 93% rename from tests/ttnn/unit_tests/operations/backward/test_backward_atan2.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_atan2.py index bff00df6cb4..528c129380d 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_atan2.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_atan2.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, compare_pcc, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_atanh.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_atanh.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_atanh.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_atanh.py index c69a6f6fb59..ef877ccbf02 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_atanh.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_atanh.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_bias_gelu.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_bias_gelu.py similarity index 96% rename from tests/ttnn/unit_tests/operations/backward/test_backward_bias_gelu.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_bias_gelu.py index 9ce2c7b55f5..af2b406e292 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_bias_gelu.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_bias_gelu.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_ceil.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_ceil.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_ceil.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_ceil.py index b9b540ec2e0..04c07035493 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_ceil.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_ceil.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_celu.py similarity index 93% rename from tests/ttnn/unit_tests/operations/backward/test_backward_celu.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_celu.py index b42841ad356..7be5986f39f 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_celu.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_celu.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_clamp.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_clamp.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/test_backward_clamp.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_clamp.py index e213e1103d9..7720d27bf2f 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_clamp.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_clamp.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_concat.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_concat.py similarity index 98% rename from tests/ttnn/unit_tests/operations/backward/test_backward_concat.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_concat.py index e239e2d1aff..e589a9d3b86 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_concat.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_concat.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_cos.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_cos.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_cos.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_cos.py index 0a4fd1e29ba..9ce9c1a0788 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_cos.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_cos.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc from math import pi diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_cosh.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_cosh.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/test_backward_cosh.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_cosh.py index b61c6fe0332..702289300a1 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_cosh.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_cosh.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, compare_pcc, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_deg2rad.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_deg2rad.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_deg2rad.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_deg2rad.py index 694679c3edc..63955715069 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_deg2rad.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_deg2rad.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_digamma.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_digamma.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_digamma.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_digamma.py index 10a5fc3bf94..b776a6312dc 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_digamma.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_digamma.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_div.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_div.py similarity index 98% rename from tests/ttnn/unit_tests/operations/backward/test_backward_div.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_div.py index c8f1bd584cc..66543b493ac 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_div.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_div.py @@ -5,7 +5,11 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, data_gen_with_val, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( + data_gen_with_range, + data_gen_with_val, + compare_pcc, +) from models.utility_functions import ( is_wormhole_b0, is_blackhole, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_div_no_nan.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_div_no_nan.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_div_no_nan.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_div_no_nan.py index 70d3ba39cc4..60bc50101f6 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_div_no_nan.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_div_no_nan.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_elu.py similarity index 93% rename from tests/ttnn/unit_tests/operations/backward/test_backward_elu.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_elu.py index 8114d3022c1..f93abf03e8b 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_elu.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_elu.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_embedding.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_embedding.py similarity index 100% rename from tests/ttnn/unit_tests/operations/backward/test_backward_embedding.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_embedding.py diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_erf.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_erf.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_erf.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_erf.py index 0278ad60e29..336965087e9 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_erf.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_erf.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_erfc.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_erfc.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_erfc.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_erfc.py index 19a0cfc206c..6d91131c415 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_erfc.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_erfc.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_erfinv.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_erfinv.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_erfinv.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_erfinv.py index 4c3197ce763..d235f11019e 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_erfinv.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_erfinv.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_exp.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_exp.py similarity index 93% rename from tests/ttnn/unit_tests/operations/backward/test_backward_exp.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_exp.py index 9de5f8cbc1e..ab779433edf 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_exp.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_exp.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_exp2.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_exp2.py similarity index 87% rename from tests/ttnn/unit_tests/operations/backward/test_backward_exp2.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_exp2.py index 825e8232150..52511cdeaff 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_exp2.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_exp2.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_expm1.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_expm1.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_expm1.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_expm1.py index 5ea5b9672b2..df8fbdf8153 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_expm1.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_expm1.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_fill.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_fill.py similarity index 96% rename from tests/ttnn/unit_tests/operations/backward/test_backward_fill.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_fill.py index 16a9711867b..228e139f05e 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_fill.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_fill.py @@ -6,7 +6,7 @@ import pytest import ttnn from models.utility_functions import is_wormhole_b0, is_blackhole -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, compare_all_close, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_fill_zero.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_fill_zero.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_fill_zero.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_fill_zero.py index cabdb82b516..6c946d303a3 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_fill_zero.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_fill_zero.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_floor.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_floor.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_floor.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_floor.py index 0229e5dca28..afe14c79ec4 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_floor.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_floor.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_fmod.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_fmod.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/test_backward_fmod.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_fmod.py index e9f0cc64293..d5ecc666e50 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_fmod.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_fmod.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range from models.utility_functions import skip_for_grayskull diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_frac.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_frac.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_frac.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_frac.py index 2e256084db6..1f148d4fab4 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_frac.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_frac.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_gelu.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_gelu.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/test_backward_gelu.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_gelu.py index 236073e3cb9..1b514166b83 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_gelu.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_gelu.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardshrink.py similarity index 93% rename from tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardshrink.py index 5700e708382..96812608e00 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_hardshrink.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardshrink.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_hardsigmoid.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardsigmoid.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_hardsigmoid.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardsigmoid.py index d7ec46359bd..535c0844211 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_hardsigmoid.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardsigmoid.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_hardswish.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardswish.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_hardswish.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardswish.py index 14b51e51175..8241bb4b379 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_hardswish.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardswish.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_hardtanh.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardtanh.py similarity index 93% rename from tests/ttnn/unit_tests/operations/backward/test_backward_hardtanh.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardtanh.py index d8e20cf43c3..75b0f22adab 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_hardtanh.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hardtanh.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_hypot.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hypot.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_hypot.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hypot.py index 5e28b34e894..05c199fb804 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_hypot.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_hypot.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_i0.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_i0.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_i0.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_i0.py index fa9a40f2e8f..20aa2720a18 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_i0.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_i0.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_ldexp.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_ldexp.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_ldexp.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_ldexp.py index 36893e981ef..67ede15225b 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_ldexp.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_ldexp.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_leaky_relu.py similarity index 93% rename from tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_leaky_relu.py index 5c33b4f1664..41f87a4e4a6 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_leaky_relu.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_leaky_relu.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_lerp.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_lerp.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/test_backward_lerp.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_lerp.py index 9bc3a6584cf..676a5885b45 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_lerp.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_lerp.py @@ -6,7 +6,7 @@ import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_lgamma.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_lgamma.py similarity index 91% rename from tests/ttnn/unit_tests/operations/backward/test_backward_lgamma.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_lgamma.py index d85b05be42d..2d066175651 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_lgamma.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_lgamma.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( compare_pcc, data_gen_with_range, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_log.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/test_backward_log.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log.py index 9b460ec1e32..85db970d27b 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_log.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_val, compare_pcc, data_gen_with_range, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_log10.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log10.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_log10.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log10.py index 20ca394a6bf..0c8e4aea027 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_log10.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log10.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_log1p.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log1p.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_log1p.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log1p.py index f1bf22ff16a..c0e04042a46 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_log1p.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log1p.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_log2.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log2.py similarity index 91% rename from tests/ttnn/unit_tests/operations/backward/test_backward_log2.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log2.py index 5d8256b1527..820fa92cee3 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_log2.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log2.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( compare_pcc, data_gen_with_range, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_log_sigmoid.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log_sigmoid.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/test_backward_log_sigmoid.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log_sigmoid.py index 82ba28fd9b4..bd3def80766 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_log_sigmoid.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_log_sigmoid.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, compare_pcc, data_gen_with_val, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_logaddexp.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logaddexp.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_logaddexp.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logaddexp.py index 4215c7de991..db0c1d6453f 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_logaddexp.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logaddexp.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_logaddexp2.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logaddexp2.py similarity index 92% rename from tests/ttnn/unit_tests/operations/backward/test_backward_logaddexp2.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logaddexp2.py index 89009a1ba6c..6a2f56a2745 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_logaddexp2.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logaddexp2.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, compare_pcc, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_logit.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logit.py similarity index 91% rename from tests/ttnn/unit_tests/operations/backward/test_backward_logit.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logit.py index ac7cc3811ab..576c0d7be79 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_logit.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logit.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( compare_pcc, data_gen_with_range, data_gen_with_val, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logiteps.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logiteps.py index 5546090dcf7..bbe3733444b 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_logiteps.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_logiteps.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, data_gen_with_val, compare_pcc, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_max.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_max.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_max.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_max.py index bede78fff64..9025bf37420 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_max.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_max.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_min.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_min.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_min.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_min.py index 26a3712b881..dcf48a8a79c 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_min.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_min.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_mul.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_mul.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/test_backward_mul.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_mul.py index 06becdfca1e..71d142b4692 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_mul.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_mul.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_mvlgamma.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_mvlgamma.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_mvlgamma.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_mvlgamma.py index a3877f5403a..1b4bace286e 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_mvlgamma.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_mvlgamma.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_neg.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_neg.py similarity index 96% rename from tests/ttnn/unit_tests/operations/backward/test_backward_neg.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_neg.py index 573b8fc8822..a103128a4db 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_neg.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_neg.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_polygamma.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_polygamma.py similarity index 96% rename from tests/ttnn/unit_tests/operations/backward/test_backward_polygamma.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_polygamma.py index d2a1670bbb3..b2053185bea 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_polygamma.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_polygamma.py @@ -5,7 +5,11 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range, data_gen_with_val +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( + compare_pcc, + data_gen_with_range, + data_gen_with_val, +) from models.utility_functions import ( is_wormhole_b0, is_blackhole, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_pow.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_pow.py similarity index 98% rename from tests/ttnn/unit_tests/operations/backward/test_backward_pow.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_pow.py index 393184e7bb9..aa5d72dc4c0 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_pow.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_pow.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, compare_pcc, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_prod.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_prod.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/test_backward_prod.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_prod.py index bd0db0d4eda..1dada61fe41 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_prod.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_prod.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_pt_tt, data_gen_pt_tt_prod, compare_results, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_rad2deg.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rad2deg.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_rad2deg.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rad2deg.py index 54f72d9e1fe..8eaed4c7c89 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_rad2deg.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rad2deg.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_rdiv.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rdiv.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/test_backward_rdiv.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rdiv.py index 1b7afe5b9ed..e00b9da6eed 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_rdiv.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rdiv.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_pt_tt, compare_results +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_pt_tt, compare_results @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_reciprocal.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_reciprocal.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/test_backward_reciprocal.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_reciprocal.py index 009dab096ca..a226f2f6137 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_reciprocal.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_reciprocal.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( compare_pcc, data_gen_with_range, data_gen_with_val, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_relu.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_relu.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_relu.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_relu.py index 25aaa9c840f..76bcad4413d 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_relu.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_relu.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_relu6.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_relu6.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_relu6.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_relu6.py index a34cdedbcb7..ec96169e8aa 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_relu6.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_relu6.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_remainder.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_remainder.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/test_backward_remainder.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_remainder.py index e20164644f1..1bb21aaa033 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_remainder.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_remainder.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range from models.utility_functions import skip_for_grayskull diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_repeat.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_repeat.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_repeat.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_repeat.py index 718ec611a5c..5aead86970f 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_repeat.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_repeat.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_pt_tt, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_pt_tt, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_round.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_round.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_round.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_round.py index 345f027743e..cbf5af1c47c 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_round.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_round.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_rpow.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rpow.py similarity index 92% rename from tests/ttnn/unit_tests/operations/backward/test_backward_rpow.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rpow.py index f82cc3d4ffd..fcaa965b3af 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_rpow.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rpow.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( compare_pcc, data_gen_with_range, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_rsqrt.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rsqrt.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/test_backward_rsqrt.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rsqrt.py index 7411f22ff94..6430d29cdd1 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_rsqrt.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rsqrt.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rsub.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rsub.py index ec8647a90ab..452846d4a8e 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_rsub.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_rsub.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_selu.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_selu.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_selu.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_selu.py index f46107c1257..60a32f18abe 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_selu.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_selu.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_sigmoid.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sigmoid.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_sigmoid.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sigmoid.py index c8ccd418f1e..6c0d38576ac 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_sigmoid.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sigmoid.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_sign.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sign.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_sign.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sign.py index d2a36a493de..abbd7a1335d 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_sign.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sign.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_silu.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_silu.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/test_backward_silu.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_silu.py index 21b41919f66..1274ede323c 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_silu.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_silu.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_sin.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sin.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_sin.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sin.py index f3b22857581..79d4b4ad05c 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_sin.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sin.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc from math import pi diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_sinh.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sinh.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/test_backward_sinh.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sinh.py index c1c233bfaa9..26b511683bf 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_sinh.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sinh.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc from models.utility_functions import ( is_wormhole_b0, diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_softplus.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softplus.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/test_backward_softplus.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softplus.py index f1fcc514cf9..7d191595c64 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_softplus.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softplus.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softshrink.py similarity index 95% rename from tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softshrink.py index bf3651a8fac..275be0d5d39 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_softshrink.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softshrink.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, compare_results, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_softsign.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softsign.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_softsign.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softsign.py index b83ab110341..a4c5d1837a3 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_softsign.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_softsign.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_sqrt.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sqrt.py similarity index 93% rename from tests/ttnn/unit_tests/operations/backward/test_backward_sqrt.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sqrt.py index d5ab971374c..e435b33c390 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_sqrt.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sqrt.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_square.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_square.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_square.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_square.py index 4cf8da1b250..86bfae70d69 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_square.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_square.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_squared_difference.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_squared_difference.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_squared_difference.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_squared_difference.py index 965c769ab38..11404a49303 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_squared_difference.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_squared_difference.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sub.py similarity index 97% rename from tests/ttnn/unit_tests/operations/backward/test_backward_sub.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sub.py index a9f83a26c5e..135103a03f9 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_sub.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_sub.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_subalpha.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_subalpha.py similarity index 96% rename from tests/ttnn/unit_tests/operations/backward/test_backward_subalpha.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_subalpha.py index d0c6ac83de0..2b6870dd002 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_subalpha.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_subalpha.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_tan.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_tan.py similarity index 92% rename from tests/ttnn/unit_tests/operations/backward/test_backward_tan.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_tan.py index d460fcd3aab..a5ab6581cf7 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_tan.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_tan.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, compare_results, ) diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_tanh.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_tanh.py similarity index 94% rename from tests/ttnn/unit_tests/operations/backward/test_backward_tanh.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_tanh.py index b9206bf6a3b..e4e74470648 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_tanh.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_tanh.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_tanhshrink.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_tanhshrink.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_tanhshrink.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_tanhshrink.py index 4d9183f3a01..cf74e5bd8bc 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_tanhshrink.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_tanhshrink.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_threshold.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_threshold.py similarity index 90% rename from tests/ttnn/unit_tests/operations/backward/test_backward_threshold.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_threshold.py index 34e62c8c5fd..05fdabe68e5 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_threshold.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_threshold.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_trunc.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_trunc.py similarity index 88% rename from tests/ttnn/unit_tests/operations/backward/test_backward_trunc.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_trunc.py index ffc71489351..5ef3b2b91ac 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_trunc.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_trunc.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import compare_pcc, data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import compare_pcc, data_gen_with_range @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_where.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_where.py similarity index 96% rename from tests/ttnn/unit_tests/operations/backward/test_backward_where.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_where.py index a8da1533eea..8a28bccb2fc 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_where.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_where.py @@ -6,7 +6,7 @@ import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/test_backward_xlogy.py b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_xlogy.py similarity index 89% rename from tests/ttnn/unit_tests/operations/backward/test_backward_xlogy.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_xlogy.py index 64a4af879ad..ae7990a164c 100644 --- a/tests/ttnn/unit_tests/operations/backward/test_backward_xlogy.py +++ b/tests/ttnn/unit_tests/operations/eltwise/backward/test_backward_xlogy.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/backward/utility_funcs.py b/tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py similarity index 100% rename from tests/ttnn/unit_tests/operations/backward/utility_funcs.py rename to tests/ttnn/unit_tests/operations/eltwise/backward/utility_funcs.py diff --git a/tests/ttnn/unit_tests/operations/complex/test_complex_conj.py b/tests/ttnn/unit_tests/operations/eltwise/complex/test_complex_conj.py similarity index 89% rename from tests/ttnn/unit_tests/operations/complex/test_complex_conj.py rename to tests/ttnn/unit_tests/operations/eltwise/complex/test_complex_conj.py index 383dcb896b5..2ffa699cedb 100644 --- a/tests/ttnn/unit_tests/operations/complex/test_complex_conj.py +++ b/tests/ttnn/unit_tests/operations/eltwise/complex/test_complex_conj.py @@ -7,11 +7,11 @@ import pytest import ttnn from loguru import logger -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0, skip_for_grayskull -from tests.ttnn.unit_tests.operations.complex.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.complex.utility_funcs import ( convert_complex_to_torch_tensor, random_complex_tensor, ) diff --git a/tests/ttnn/unit_tests/operations/complex/utility_funcs.py b/tests/ttnn/unit_tests/operations/eltwise/complex/utility_funcs.py similarity index 100% rename from tests/ttnn/unit_tests/operations/complex/utility_funcs.py rename to tests/ttnn/unit_tests/operations/eltwise/complex/utility_funcs.py diff --git a/tests/ttnn/unit_tests/operations/test_add.py b/tests/ttnn/unit_tests/operations/eltwise/test_add.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_add.py rename to tests/ttnn/unit_tests/operations/eltwise/test_add.py diff --git a/tests/ttnn/unit_tests/operations/test_backward.py b/tests/ttnn/unit_tests/operations/eltwise/test_backward.py similarity index 98% rename from tests/ttnn/unit_tests/operations/test_backward.py rename to tests/ttnn/unit_tests/operations/eltwise/test_backward.py index afad991bd33..5530a9ababb 100644 --- a/tests/ttnn/unit_tests/operations/test_backward.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_backward.py @@ -9,7 +9,7 @@ import ttnn from models.utility_functions import is_wormhole_b0, is_blackhole -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_val, compare_all_close, ) diff --git a/tests/ttnn/unit_tests/operations/test_binary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py similarity index 99% rename from tests/ttnn/unit_tests/operations/test_binary_composite.py rename to tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py index f5460e85cf9..89e74e1f85b 100644 --- a/tests/ttnn/unit_tests/operations/test_binary_composite.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py @@ -6,7 +6,7 @@ import pytest import random import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( data_gen_with_range, data_gen_with_range_int, compare_pcc, diff --git a/tests/ttnn/unit_tests/operations/test_binary_scalar.py b/tests/ttnn/unit_tests/operations/eltwise/test_binary_scalar.py similarity index 90% rename from tests/ttnn/unit_tests/operations/test_binary_scalar.py rename to tests/ttnn/unit_tests/operations/eltwise/test_binary_scalar.py index a421f155fca..a7adfeaa031 100644 --- a/tests/ttnn/unit_tests/operations/test_binary_scalar.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_binary_scalar.py @@ -6,7 +6,7 @@ import pytest import ttnn import random -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/test_complex.py b/tests/ttnn/unit_tests/operations/eltwise/test_complex.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_complex.py rename to tests/ttnn/unit_tests/operations/eltwise/test_complex.py diff --git a/tests/ttnn/unit_tests/operations/test_complex_tensor.py b/tests/ttnn/unit_tests/operations/eltwise/test_complex_tensor.py similarity index 88% rename from tests/ttnn/unit_tests/operations/test_complex_tensor.py rename to tests/ttnn/unit_tests/operations/eltwise/test_complex_tensor.py index b12ec3b0b4d..98e080977f6 100644 --- a/tests/ttnn/unit_tests/operations/test_complex_tensor.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_complex_tensor.py @@ -7,11 +7,11 @@ import pytest import ttnn from loguru import logger -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_pcc, comp_equal, comp_allclose from models.utility_functions import is_wormhole_b0 -from tests.ttnn.unit_tests.operations.complex.utility_funcs import ( +from tests.ttnn.unit_tests.operations.eltwise.complex.utility_funcs import ( convert_complex_to_torch_tensor, random_complex_tensor, ) diff --git a/tests/ttnn/unit_tests/operations/test_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_composite.py similarity index 99% rename from tests/ttnn/unit_tests/operations/test_composite.py rename to tests/ttnn/unit_tests/operations/eltwise/test_composite.py index 89479b9e858..5f43cd2ee17 100644 --- a/tests/ttnn/unit_tests/operations/test_composite.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_composite.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc from models.utility_functions import skip_for_grayskull, is_wormhole_b0, is_blackhole diff --git a/tests/ttnn/unit_tests/operations/test_elt_binary.py b/tests/ttnn/unit_tests/operations/eltwise/test_elt_binary.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_elt_binary.py rename to tests/ttnn/unit_tests/operations/eltwise/test_elt_binary.py diff --git a/tests/ttnn/unit_tests/operations/test_eltwise_logical_and_.py b/tests/ttnn/unit_tests/operations/eltwise/test_eltwise_logical_and_.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_eltwise_logical_and_.py rename to tests/ttnn/unit_tests/operations/eltwise/test_eltwise_logical_and_.py diff --git a/tests/ttnn/unit_tests/operations/test_eltwise_typecast.py b/tests/ttnn/unit_tests/operations/eltwise/test_eltwise_typecast.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_eltwise_typecast.py rename to tests/ttnn/unit_tests/operations/eltwise/test_eltwise_typecast.py diff --git a/tests/ttnn/unit_tests/operations/test_inplace.py b/tests/ttnn/unit_tests/operations/eltwise/test_inplace.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_inplace.py rename to tests/ttnn/unit_tests/operations/eltwise/test_inplace.py diff --git a/tests/ttnn/unit_tests/operations/test_math.py b/tests/ttnn/unit_tests/operations/eltwise/test_math.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_math.py rename to tests/ttnn/unit_tests/operations/eltwise/test_math.py diff --git a/tests/ttnn/unit_tests/operations/test_math_binary.py b/tests/ttnn/unit_tests/operations/eltwise/test_math_binary.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_math_binary.py rename to tests/ttnn/unit_tests/operations/eltwise/test_math_binary.py diff --git a/tests/ttnn/unit_tests/operations/test_mul.py b/tests/ttnn/unit_tests/operations/eltwise/test_mul.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_mul.py rename to tests/ttnn/unit_tests/operations/eltwise/test_mul.py diff --git a/tests/ttnn/unit_tests/operations/test_pow.py b/tests/ttnn/unit_tests/operations/eltwise/test_pow.py similarity index 88% rename from tests/ttnn/unit_tests/operations/test_pow.py rename to tests/ttnn/unit_tests/operations/eltwise/test_pow.py index d0eed9fc9d0..51296079a5e 100644 --- a/tests/ttnn/unit_tests/operations/test_pow.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_pow.py @@ -5,7 +5,7 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/test_relational.py b/tests/ttnn/unit_tests/operations/eltwise/test_relational.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_relational.py rename to tests/ttnn/unit_tests/operations/eltwise/test_relational.py diff --git a/tests/ttnn/unit_tests/operations/test_sub.py b/tests/ttnn/unit_tests/operations/eltwise/test_sub.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_sub.py rename to tests/ttnn/unit_tests/operations/eltwise/test_sub.py diff --git a/tests/ttnn/unit_tests/operations/test_ternary.py b/tests/ttnn/unit_tests/operations/eltwise/test_ternary.py similarity index 100% rename from tests/ttnn/unit_tests/operations/test_ternary.py rename to tests/ttnn/unit_tests/operations/eltwise/test_ternary.py diff --git a/tests/ttnn/unit_tests/operations/test_ternary_composite.py b/tests/ttnn/unit_tests/operations/eltwise/test_ternary_composite.py similarity index 97% rename from tests/ttnn/unit_tests/operations/test_ternary_composite.py rename to tests/ttnn/unit_tests/operations/eltwise/test_ternary_composite.py index 9d75a2eabfb..2b9207dbbeb 100644 --- a/tests/ttnn/unit_tests/operations/test_ternary_composite.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_ternary_composite.py @@ -5,7 +5,11 @@ import torch import pytest import ttnn -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, data_gen_with_val, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import ( + data_gen_with_range, + data_gen_with_val, + compare_pcc, +) @pytest.mark.parametrize( diff --git a/tests/ttnn/unit_tests/operations/test_unary.py b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py similarity index 99% rename from tests/ttnn/unit_tests/operations/test_unary.py rename to tests/ttnn/unit_tests/operations/eltwise/test_unary.py index b4a3fca3c4c..e70e76a3b2a 100644 --- a/tests/ttnn/unit_tests/operations/test_unary.py +++ b/tests/ttnn/unit_tests/operations/eltwise/test_unary.py @@ -9,7 +9,7 @@ import ttnn from tests.ttnn.utils_for_testing import assert_with_pcc, assert_equal -from tests.ttnn.unit_tests.operations.backward.utility_funcs import data_gen_with_range, compare_pcc +from tests.ttnn.unit_tests.operations.eltwise.backward.utility_funcs import data_gen_with_range, compare_pcc from models.utility_functions import torch_random, skip_for_grayskull, is_wormhole_b0, is_blackhole diff --git a/ttnn/cpp/pybind11/operations/__init__.hpp b/ttnn/cpp/pybind11/operations/__init__.hpp index 9a0f48ac50d..886c8a3c8f4 100644 --- a/ttnn/cpp/pybind11/operations/__init__.hpp +++ b/ttnn/cpp/pybind11/operations/__init__.hpp @@ -53,25 +53,29 @@ void py_module(py::module& module) { auto m_examples = module.def_submodule("examples", "examples of operations"); examples::py_module(m_examples); + auto m_ccl = module.def_submodule("ccl", "collective communication operations"); + ccl::py_bind_all_gather(m_ccl); + ccl::py_bind_reduce_scatter(m_ccl); + + // Eltwise operations: unary, binary, ternary, backward, complex auto m_unary = module.def_submodule("unary", "unary operations"); unary::py_module(m_unary); auto m_binary = module.def_submodule("binary", "binary operations"); binary::py_module(m_binary); + auto m_ternary = module.def_submodule("ternary", "ternary operations"); + ternary::py_module(m_ternary); + + auto m_unary_backward = module.def_submodule("unary_backward", "unary_backward operations"); + unary_backward::py_module(m_unary_backward); + auto m_binary_backward = module.def_submodule("binary_backward", "binary_backward operations"); binary_backward::py_module(m_binary_backward); auto m_ternary_backward = module.def_submodule("ternary_backward", "ternary_backward operations"); ternary_backward::py_module(m_ternary_backward); - auto m_unary_backward = module.def_submodule("unary_backward", "unary_backward operations"); - unary_backward::py_module(m_unary_backward); - - auto m_ccl = module.def_submodule("ccl", "collective communication operations"); - ccl::py_bind_all_gather(m_ccl); - ccl::py_bind_reduce_scatter(m_ccl); - auto m_complex = module.def_submodule("complex", "complex tensor creation"); complex::py_module(m_complex); @@ -81,9 +85,6 @@ void py_module(py::module& module) { auto m_complex_unary_backward = module.def_submodule("complex_unary_backward", "complex_unary_backward operations"); complex_unary_backward::py_module(m_complex_unary_backward); - auto m_ternary = module.def_submodule("ternary", "ternary operations"); - ternary::py_module(m_ternary); - auto m_creation = module.def_submodule("creation", "creation operations"); creation::py_module(m_creation); diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py index a0e3bf481fa..91db0d8e039 100644 --- a/ttnn/ttnn/__init__.py +++ b/ttnn/ttnn/__init__.py @@ -236,6 +236,7 @@ def auto_register_ttnn_cpp_operations(module): import ttnn.experimental_loader.golden_functions import ttnn.operations +import ttnn.operations.eltwise sub = ttnn.subtract sub_ = ttnn.subtract_ diff --git a/ttnn/ttnn/operations/eltwise/__init__.py b/ttnn/ttnn/operations/eltwise/__init__.py new file mode 100644 index 00000000000..8a3ac51ecef --- /dev/null +++ b/ttnn/ttnn/operations/eltwise/__init__.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pkgutil + +__all__ = [] + +for loader, module_name, is_pkg in pkgutil.walk_packages(__path__): + __all__.append(module_name) + _module = loader.find_spec(module_name).loader.load_module(module_name) + globals()[module_name] = _module diff --git a/ttnn/ttnn/operations/binary.py b/ttnn/ttnn/operations/eltwise/binary.py similarity index 100% rename from ttnn/ttnn/operations/binary.py rename to ttnn/ttnn/operations/eltwise/binary.py diff --git a/ttnn/ttnn/operations/binary_backward.py b/ttnn/ttnn/operations/eltwise/binary_backward.py similarity index 99% rename from ttnn/ttnn/operations/binary_backward.py rename to ttnn/ttnn/operations/eltwise/binary_backward.py index ac833338f2f..255e72c5d80 100644 --- a/ttnn/ttnn/operations/binary_backward.py +++ b/ttnn/ttnn/operations/eltwise/binary_backward.py @@ -5,7 +5,7 @@ import sys import ttnn -from ttnn.operations.complex_binary_backward import ( +from ttnn.operations.eltwise.complex_binary_backward import ( _golden_function_complex_add, _golden_function_complex_sub, _golden_function_complex_mul, diff --git a/ttnn/ttnn/operations/binary_complex.py b/ttnn/ttnn/operations/eltwise/binary_complex.py similarity index 100% rename from ttnn/ttnn/operations/binary_complex.py rename to ttnn/ttnn/operations/eltwise/binary_complex.py diff --git a/ttnn/ttnn/operations/complex_binary_backward.py b/ttnn/ttnn/operations/eltwise/complex_binary_backward.py similarity index 100% rename from ttnn/ttnn/operations/complex_binary_backward.py rename to ttnn/ttnn/operations/eltwise/complex_binary_backward.py diff --git a/ttnn/ttnn/operations/complex_unary_backward.py b/ttnn/ttnn/operations/eltwise/complex_unary_backward.py similarity index 100% rename from ttnn/ttnn/operations/complex_unary_backward.py rename to ttnn/ttnn/operations/eltwise/complex_unary_backward.py diff --git a/ttnn/ttnn/operations/ternary.py b/ttnn/ttnn/operations/eltwise/ternary.py similarity index 100% rename from ttnn/ttnn/operations/ternary.py rename to ttnn/ttnn/operations/eltwise/ternary.py diff --git a/ttnn/ttnn/operations/ternary_backward.py b/ttnn/ttnn/operations/eltwise/ternary_backward.py similarity index 100% rename from ttnn/ttnn/operations/ternary_backward.py rename to ttnn/ttnn/operations/eltwise/ternary_backward.py diff --git a/ttnn/ttnn/operations/unary.py b/ttnn/ttnn/operations/eltwise/unary.py similarity index 100% rename from ttnn/ttnn/operations/unary.py rename to ttnn/ttnn/operations/eltwise/unary.py diff --git a/ttnn/ttnn/operations/unary_backward.py b/ttnn/ttnn/operations/eltwise/unary_backward.py similarity index 100% rename from ttnn/ttnn/operations/unary_backward.py rename to ttnn/ttnn/operations/eltwise/unary_backward.py diff --git a/ttnn/ttnn/operations/unary_complex.py b/ttnn/ttnn/operations/eltwise/unary_complex.py similarity index 100% rename from ttnn/ttnn/operations/unary_complex.py rename to ttnn/ttnn/operations/eltwise/unary_complex.py From 7d48d0a98a3a5e1101e9286860628e398696fe76 Mon Sep 17 00:00:00 2001 From: Joseph Chu <122298491+cfjchu@users.noreply.github.com> Date: Wed, 9 Oct 2024 22:12:56 -0700 Subject: [PATCH 54/58] #13650: Add example of hybrid TP/DP llama-70b model on TG (#13651) - Tile T3000 model configuration 4x on TG --- models/MODEL_HYBRID_TP_DP.md | 44 ++++ .../tests/test_llama_perf_decode.py | 218 ++++++++++++++++++ tests/scripts/tg/run_tg_model_perf_tests.sh | 1 + 3 files changed, 263 insertions(+) create mode 100644 models/MODEL_HYBRID_TP_DP.md diff --git a/models/MODEL_HYBRID_TP_DP.md b/models/MODEL_HYBRID_TP_DP.md new file mode 100644 index 00000000000..299cfdc369c --- /dev/null +++ b/models/MODEL_HYBRID_TP_DP.md @@ -0,0 +1,44 @@ +# Hybrid Tensor and Data Parallelism Implementation + +This short guide explains how to add hybrid tensor and data parallelism to your model using submesh tiling across a larger mesh. + +## Overview of Changes + +The main changes involve: + +1. Creating multiple submeshes from the main mesh +2. Running the model on each submesh +3. Capturing and replaying a trace across all submeshes in parallel + +## Key Implementation Details + +### 1. Submesh Creation + +```python + # Work with submesh device as you would with a regular ttnn.MeshDevice + submesh_devices: List[ttnn.MeshDevice] = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring) +``` + +### 2. Compile & Run the Model on Each Submesh + +```python + # Run the model on each submesh + for submesh_device in submesh_devices: + model(..., device=submesh_device) +``` + +### 3. Capture & Replay the Trace + +```python + + # Capture Model Trace spanning all submeshes + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + for submesh_device in submesh_devices: + model(..., device=submesh) # Run the model on each submesh + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + + # Execute Model Trace across all submeshes in parallel + ttnn.execute_trace(mesh_device, trace_id, blocking=False) + ttnn.release_trace(mesh_device, trace_id) + +``` diff --git a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py index fbccd4176c3..3526206b852 100644 --- a/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py +++ b/models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py @@ -28,6 +28,8 @@ ) from models.perf.perf_utils import prep_perf_report +from collections import defaultdict + def get_decode_time(profiler, start_token, end_token): total_time = 0 @@ -254,3 +256,219 @@ def test_Llama_perf_host( tokenizer_path, cache_path, ) + + +def run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel( + mesh_device, + llama_version, + batch, + seq_len, + max_context_len, + model_config, + n_layers, + n_devices, + generation_length, + expected_compile_time, + expected_inference_time, + ckpt_dir, + tokenizer_path, + cache_path, +): + # Prepare paths and devices + skip_model_load = should_skip_model_load() + + logger.info(f"Running num_layer: {n_layers}") + + generator = Llama.build( + ckpt_dir, + tokenizer_path, + max_seq_len=max_context_len, + max_batch_size=batch, + n_layers=1, + skip_model_load=skip_model_load, + ) + hugging_face_reference_model, tokenizer = generator.model, generator.tokenizer + hugging_face_reference_model.eval() + # state_dict = hugging_face_reference_model.state_dict() + state_dict = load_llama_state_dict(ckpt_dir, n_layers=n_layers) + configuration = hugging_face_reference_model.params + + # Prepare input ----------------------------------------------------------------------- + torch.manual_seed(0) + total_len = min(max_context_len, generation_length + 1) + n_iters = 100 # Number of iterations to run in order to get a perf estimate + tokens = torch.randint(0, 10000, (batch, 1), dtype=torch.long) + # Clear global profiler state before starting measurements + profiler.clear() + + submesh_to_metadata = defaultdict(dict) + submeshes = mesh_device.create_submeshes((2, 4), ttnn.MeshType.Ring) + for submesh in submeshes: + # Set up model ----------------------------------------------------------------------- + logger.info("Moving weights to devices; might take some time...") + profiler.start("TT_llama_model_setup") + tt_model = TtLlamaModel_optimized( + submesh, + state_dict, + BASE_URL, + n_layers, + model_config, + configuration, + cache_path=cache_path, + read_cache=True, + ) + + for i in submesh.get_device_ids(): + device = submesh.get_device(i) + ttnn.synchronize_device(device) + + profiler.end("TT_llama_model_setup") + + ##### Prepare Inputs ##### + prev_pos = total_len - 1 + tt_inp_emb, prev_pos, rot_mat, cache_idxs = tt_model.prepare_inputs(tokens, prev_pos) + tt_inp_emb = ttnn.to_device(tt_inp_emb, submesh, memory_config=ttnn.DRAM_MEMORY_CONFIG) + tt_inp_emb = tt_model.tt_embd(tt_inp_emb) + tt_inp_emb = ttnn.interleaved_to_sharded(tt_inp_emb, tt_model.model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"]) + + rot_mat = ttnn.to_device(rot_mat, submesh, memory_config=tt_model.model_config["ROT_MAT_MM_IN1_MEMCFG"]) + cache_idxs = ttnn.to_device(cache_idxs, submesh, memory_config=ttnn.DRAM_MEMORY_CONFIG) + + ##### Compile Model ##### + logger.info("Compiling model") + profiler.start(f"compile_time") + tt_logits = tt_model(tt_inp_emb, rot_mat, prev_pos, cache_idxs=cache_idxs, mode="decode") + tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG) + tt_logits_tensors = ttnn.get_device_tensors(tt_logits) + logits_rm = ttnn.to_layout(tt_logits_tensors[0], ttnn.ROW_MAJOR_LAYOUT) + logits = ttnn.to_torch(logits_rm) + profiler.end(f"compile_time") + profiler.print() + compile_iter_time = profiler.get("compile_time") + logger.info(f"decode with compile time, single iter latency: {compile_iter_time}") + + submesh_to_metadata[submesh.get_mesh_id()] = { + "submesh": submesh, + "logits_rm": logits_rm, + "tt_model": tt_model, + "prev_pos": prev_pos, + "tt_inp_emb": tt_inp_emb, + "rot_mat": rot_mat, + "cache_idxs": cache_idxs, + } + + ##### Capture Trace ##### + logger.info("Capturing trace") + trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0) + + for submesh in submeshes: + mesh_id = submesh.get_mesh_id() + tt_model = submesh_to_metadata[mesh_id]["tt_model"] + tt_inp_emb = submesh_to_metadata[mesh_id]["tt_inp_emb"] + rot_mat = submesh_to_metadata[mesh_id]["rot_mat"] + cache_idxs = submesh_to_metadata[mesh_id]["cache_idxs"] + prev_pos = submesh_to_metadata[mesh_id]["prev_pos"] + + tt_logits = tt_model(tt_inp_emb, rot_mat, prev_pos, cache_idxs=cache_idxs, mode="decode") + tt_logits = ttnn.all_gather(tt_logits, dim=3, num_links=1, memory_config=ttnn.DRAM_MEMORY_CONFIG) + tt_logits_tensors = ttnn.get_device_tensors(tt_logits) + logits_rm = ttnn.to_layout(tt_logits_tensors[0], ttnn.ROW_MAJOR_LAYOUT) + ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0) + + ##### Execute Trace ##### + logger.info("Executing trace") + profiler.start(f"end_to_end_inference") + for i in range(n_iters): + ttnn.execute_trace(mesh_device, trace_id, blocking=False) + logits = ttnn.to_torch(logits_rm) + profiler.end(f"end_to_end_inference") + ttnn.release_trace(mesh_device, trace_id) + + profiler.print() + loop_time = profiler.get("end_to_end_inference") + iter_time = loop_time / n_iters + logger.info(f"decode cached, single iter latency: {iter_time}") + + comment = f"num_layers={n_layers}L_n_devices={n_devices}" + + prep_perf_report( + model_name=f"{llama_version}_70b_{comment}", + batch_size=batch, + inference_and_compile_time=compile_iter_time, + inference_time=iter_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments=comment, + ) + + tokens_per_s_per_user = 1 / iter_time + tokens_per_s_overall = tokens_per_s_per_user * batch * len(submeshes) + + logger.info(f"Time per iteration: {iter_time}") + logger.info(f"Tokens per s per user: {tokens_per_s_per_user}") + logger.info(f"Tokens per s overall: {tokens_per_s_overall}") + + # assert compile_time <= expected_compile_time + assert iter_time <= expected_inference_time + + +@skip_for_grayskull("Requires eth connected devices to run") +@pytest.mark.timeout(4500) +@pytest.mark.model_perf_tg +@pytest.mark.parametrize( + "llama_version", + (("llama3"),), +) +@pytest.mark.parametrize( + "generation_length, expected_compile_time, expected_inference_time, batch, seq_len, max_context_len", + ( + (32, 10000, 0.0653 + 0.01, 32, 1, 4096), + (128, 10000, 0.0655 + 0.01, 32, 1, 4096), + (2048, 10000, 0.0771 + 0.01, 32, 1, 4096), + (8192, 10000, 0.0825 + 0.01, 16, 1, 8192), + (128 * 1024, 10000, 0.0918 + 0.01, 1, 1, 128 * 1024), + ), + ids=["gen32", "gen128", "gen2k", "gen8k", "gen128k"], +) +@pytest.mark.parametrize("device_params", [{"trace_region_size": 20000000}], indirect=True) +@pytest.mark.parametrize("mesh_device", [pytest.param((8, 4), id="8x4_grid")], indirect=True) +def test_Llama_perf_hybrid_data_tensor_parallel( + mesh_device, + generation_length, + expected_compile_time, + expected_inference_time, + batch, + seq_len, + max_context_len, + llama_version, + use_program_cache, + n_layers=80, + n_devices=8, +): + model_config, ckpt_dir, tokenizer_path, cache_path = setup_llama_env( + llama_version=llama_version, + max_batch_size=batch, + max_context_len=max_context_len, + ) + + check_mesh_device(mesh_device, model_config) + mesh_device.enable_async(True) + + disable_compilation_reports() + + run_test_LlamaModel_end_to_end_hybrid_data_tensor_parallel( + mesh_device, + llama_version, + batch, + seq_len, + max_context_len, + model_config, + n_layers, + n_devices, + generation_length, + expected_compile_time, + expected_inference_time, + ckpt_dir, + tokenizer_path, + cache_path, + ) diff --git a/tests/scripts/tg/run_tg_model_perf_tests.sh b/tests/scripts/tg/run_tg_model_perf_tests.sh index 9501e79e423..d86a7a96688 100755 --- a/tests/scripts/tg/run_tg_model_perf_tests.sh +++ b/tests/scripts/tg/run_tg_model_perf_tests.sh @@ -4,6 +4,7 @@ run_tg_llm_tests() { echo "LOG_METAL: Running run_t3000_llama2_70b_tests" pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_t3000" --timeout=600 ; fail+=$? + pytest -n auto models/demos/t3000/llama2_70b/tests/test_llama_perf_decode.py -m "model_perf_tg" --timeout=600 ; fail+=$? # Merge all the generated reports env python models/perf/merge_perf_results.py; fail+=$? From 26c42ae7a3234bf28f4fe089705674b3d5d7729e Mon Sep 17 00:00:00 2001 From: Joseph Chu <122298491+cfjchu@users.noreply.github.com> Date: Wed, 9 Oct 2024 22:14:20 -0700 Subject: [PATCH 55/58] #0: Update README.md for tracking model performance (#13652) - update to use TP/DP abbreviation and track tok/s column - add performance result for llama-70b with TP,DP --- README.md | 48 ++++++++++++++++++++++++------------------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 62a6cd660aa..e038a3320d6 100644 --- a/README.md +++ b/README.md @@ -21,22 +21,24 @@ --- ## LLMs -| Model | Batch | Hardware | ttft (s) | t/s/u | Target t/s/u | Release | -|----------------------------------------------------------------------|-------|----------------------------------------------------------|------------|-------|--------------|---------------------------------------------------------------------------| -| [Falcon7B-decode](./models/demos/ttnn_falcon7b) | 32 | [e150](https://tenstorrent.com/hardware/grayskull) | | 4.2 | 4.4 | | -| [Falcon7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.07 | 16.7 | 26 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | -| [Mistral-7B](./models/demos/wormhole/mistral7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | | 9.9 | 25 | [v0.51.0-rc28](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc28) | -| [Mamba-2.8B](./models/demos/wormhole/mamba) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.04 | 12.3 | 41 | [v0.51.0-rc26](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc26) | -| [LLaMA-3.1-8B](./models/demos/wormhole/llama31_8b) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.20 | 21.4 | 23 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Falcon7B (data parallel)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.10 | 14.4 | 26 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | -| [LLaMA-2-70B - (tensor parallel)](./models/demos/t3000/llama2_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | -| [LLaMA-3.1-70B (tensor parallel)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | -| [Falcon40B (tensor parallel)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | [v0.53.0-rc2](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc2) | -| [Mixtral7Bx8 (tensor parallel)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.23 | 14.2 | 33 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | -| [Falcon7B (data parallel)](./models/demos/tg/falcon7b) |1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 0.21 | 4.4 | 26 | [v0.53.0-rc9](https://github.com/tenstorrent/tt-metal/tree/v0.53.0-rc9) | +| Model | Batch | Hardware | ttft (s) | t/s/u | Target t/s/u | t/s | Release | +|----------------------------------------------------------------------|-------|----------------------------------------------------------|----------|-------|--------------|--------|---------------------------------------------------------------------------| +| [Falcon7B-decode](./models/demos/ttnn_falcon7b) | 32 | [e150](https://tenstorrent.com/hardware/grayskull) | | 4.2 | 4.4 | 134.4 | | +| [Falcon7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.07 | 16.7 | 26 | 534.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Mistral-7B](./models/demos/wormhole/mistral7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | | 9.9 | 25 | 316.8 | [v0.51.0-rc28](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc28) | +| [Mamba-2.8B](./models/demos/wormhole/mamba) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.04 | 12.3 | 41 | 393.6 | [v0.51.0-rc26](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc26) | +| [LLaMA-3.1-8B](./models/demos/wormhole/llama31_8b) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.20 | 21.4 | 23 | 21.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Falcon7B (DP=8)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.10 | 14.4 | 26 | 3686.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [LLaMA-2-70B - (TP=8)](./models/demos/t3000/llama2_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | 483.2 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [LLaMA-3.1-70B (TP=8)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | 483.2 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Falcon40B (TP=8)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | 169.6 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Mixtral7Bx8 (TP=8)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.23 | 14.2 | 33 | 454.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Falcon7B (DP=32)](./models/demos/tg/falcon7b) | 1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 0.24 | 4.4 | 26 | 4505.6 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [LLaMA-3.1-70B (DP=4, TP=8)](./models/demos/t3000/llama3_70b) | 128 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 0.19 | 14.3 | 20 | 1835.5 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | > **Last Update:** October 7, 2024 > **Notes:** +> - TP = Tensor Parallel, DP = Data Parallel; Defines parallelization factors across multiple devices. > - The reported LLM performance is for an input sequence length (number of rows filled in the KV cache) of 128 for all models except Mamba (which can accept any sequence length). > - The t/s/u reported is the throughput of the first token generated after prefill, i.e. 1 / inter token latency. @@ -45,22 +47,20 @@ |-----------------------------------------------------------------------------|-------|----------------------------------------------------------|---------|------------|-------------| | [ResNet-50 (224x224)](./models/demos/grayskull/resnet50) | 20 | [e150](https://tenstorrent.com/hardware/grayskull) | 5,100 | 10,000 | | | [ResNet-50 (224x224)](./models/demos/wormhole/resnet50) | 16 | [n150](https://tenstorrent.com/hardware/wormhole) | 4,100 | 7,000 | | -| [ResNet-50 (224x224) (data parallel)](./models/demos/t3000/resnet50) | 128 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 32,250 | 56,000 | | -| [ResNet-50 (224x224) (data parallel)](./models/demos/tg/resnet50) | 512 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 95,900 | 224,000 | | -| [ResNet-50 (224x224) (data parallel)](./models/demos/tgg/resnet50) | 1024 | [Two Galaxies](https://tenstorrent.com/hardware/galaxy) | 128,800 | 448,000 | | +| [ResNet-50 (224x224) (DP=8)](./models/demos/t3000/resnet50) | 128 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 32,250 | 56,000 | | +| [ResNet-50 (224x224) (DP=32)](./models/demos/tg/resnet50) | 512 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 95,900 | 224,000 | | +| [ResNet-50 (224x224) (DP=64)](./models/demos/tgg/resnet50) | 1024 | [Two Galaxies](https://tenstorrent.com/hardware/galaxy) | 128,800 | 448,000 | | | [ViT](./models/demos/grayskull/vit) | 9 | [e150](https://tenstorrent.com/hardware/grayskull) | 1,360 | 2,000 | | | [ViT](./models/demos/wormhole/vit) | 8 | [n150](https://tenstorrent.com/hardware/wormhole) | 912 | 1,600 | | | [Stable Diffusion 1.4 (512x512)](./models/demos/wormhole/stable_diffusion) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.167 | 0.3 | | ## NLPs -| Model | Batch | Hardware | sen/sec | Target sen/sec | Release | -|-----------------------------------------------------|-------|----------------------------------------------------|-----------|----------------|-------------| -| [BERT-Large](./models/demos/metal_BERT_large_11/) | 12 | [e150](https://tenstorrent.com/hardware/grayskull) | 370 | 410 | | -| [BERT-Large](./models/demos/metal_BERT_large_11/) | 8 | [n150](https://tenstorrent.com/hardware/wormhole) | 270 | 400 | | -| [T5 small](.models/demos/grayskull/t5) | | [e150](https://tenstorrent.com/hardware/grayskull) | 140 | | | -| [Bloom](.models/demos/grayskull/functional_bloom) | | [e150](https://tenstorrent.com/hardware/grayskull) | 70 | | | - - +| Model | Batch | Hardware | sen/sec | Target sen/sec | Release | +|-----------------------------------------------------|-------|----------------------------------------------------|---------|----------------|---------| +| [BERT-Large](./models/demos/metal_BERT_large_11/) | 12 | [e150](https://tenstorrent.com/hardware/grayskull) | 370 | 410 | | +| [BERT-Large](./models/demos/metal_BERT_large_11/) | 8 | [n150](https://tenstorrent.com/hardware/wormhole) | 270 | 400 | | +| [T5 small](.models/demos/grayskull/t5) | | [e150](https://tenstorrent.com/hardware/grayskull) | 140 | | | +| [Bloom](.models/demos/grayskull/functional_bloom) | | [e150](https://tenstorrent.com/hardware/grayskull) | 70 | | | ## Model Updates For the latest model updates and features, please see [MODEL_UPDATES.md](models/MODEL_UPDATES.md) From 114eb487d49aee2e455eb315a75a876ecb590ce8 Mon Sep 17 00:00:00 2001 From: Joseph Chu <122298491+cfjchu@users.noreply.github.com> Date: Wed, 9 Oct 2024 22:31:58 -0700 Subject: [PATCH 56/58] #0: update README to compact model performance table (#13666) --- README.md | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index e038a3320d6..cbad5ca6cb6 100644 --- a/README.md +++ b/README.md @@ -21,20 +21,20 @@ --- ## LLMs -| Model | Batch | Hardware | ttft (s) | t/s/u | Target t/s/u | t/s | Release | -|----------------------------------------------------------------------|-------|----------------------------------------------------------|----------|-------|--------------|--------|---------------------------------------------------------------------------| -| [Falcon7B-decode](./models/demos/ttnn_falcon7b) | 32 | [e150](https://tenstorrent.com/hardware/grayskull) | | 4.2 | 4.4 | 134.4 | | -| [Falcon7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.07 | 16.7 | 26 | 534.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Mistral-7B](./models/demos/wormhole/mistral7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | | 9.9 | 25 | 316.8 | [v0.51.0-rc28](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc28) | -| [Mamba-2.8B](./models/demos/wormhole/mamba) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.04 | 12.3 | 41 | 393.6 | [v0.51.0-rc26](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc26) | -| [LLaMA-3.1-8B](./models/demos/wormhole/llama31_8b) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.20 | 21.4 | 23 | 21.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Falcon7B (DP=8)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.10 | 14.4 | 26 | 3686.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [LLaMA-2-70B - (TP=8)](./models/demos/t3000/llama2_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | 483.2 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [LLaMA-3.1-70B (TP=8)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | 483.2 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Falcon40B (TP=8)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | 169.6 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Mixtral7Bx8 (TP=8)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.23 | 14.2 | 33 | 454.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [Falcon7B (DP=32)](./models/demos/tg/falcon7b) | 1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 0.24 | 4.4 | 26 | 4505.6 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | -| [LLaMA-3.1-70B (DP=4, TP=8)](./models/demos/t3000/llama3_70b) | 128 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 0.19 | 14.3 | 20 | 1835.5 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| Model | Batch | Hardware | ttft (s) | t/s/u | Target
t/s/u | t/s | Release | +|---------------------------------------------------------------|-------|----------------------------------------------------------|----------|-------|-----------------|--------|---------------------------------------------------------------------------| +| [Falcon7B-decode](./models/demos/ttnn_falcon7b) | 32 | [e150](https://tenstorrent.com/hardware/grayskull) | | 4.2 | 4.4 | 134.4 | | +| [Falcon7B](./models/demos/wormhole/falcon7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.07 | 16.7 | 26 | 534.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Mistral-7B](./models/demos/wormhole/mistral7b) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | | 9.9 | 25 | 316.8 | [v0.51.0-rc28](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc28) | +| [Mamba-2.8B](./models/demos/wormhole/mamba) | 32 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.04 | 12.3 | 41 | 393.6 | [v0.51.0-rc26](https://github.com/tenstorrent/tt-metal/tree/v0.51.0-rc26) | +| [LLaMA-3.1-8B](./models/demos/wormhole/llama31_8b) | 1 | [n150](https://tenstorrent.com/hardware/wormhole) | 0.20 | 21.4 | 23 | 21.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Falcon7B (DP=8)](./models/demos/t3000/falcon7b) | 256 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.10 | 14.4 | 26 | 3686.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [LLaMA-2-70B - (TP=8)](./models/demos/t3000/llama2_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | 483.2 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [LLaMA-3.1-70B (TP=8)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.19 | 15.1 | 20 | 483.2 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Falcon40B (TP=8)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | 169.6 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Mixtral7Bx8 (TP=8)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 0.23 | 14.2 | 33 | 454.4 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [Falcon7B (DP=32)](./models/demos/tg/falcon7b) | 1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 0.24 | 4.4 | 26 | 4505.6 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | +| [LLaMA-3.1-70B (DP=4, TP=8)](./models/demos/t3000/llama3_70b) | 128 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 0.19 | 14.3 | 20 | 1835.5 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | > **Last Update:** October 7, 2024 > **Notes:** From 04f71b457c29f9467731c82e4df38a7358169979 Mon Sep 17 00:00:00 2001 From: namhyeong-kim <141107133+namhyeong-kim@users.noreply.github.com> Date: Thu, 10 Oct 2024 14:46:54 +0900 Subject: [PATCH 57/58] #13576: Apply unpack_to_dest and use matmul_tiles instread of reduce_tile in w-dim (#13617) * #13576: Apply unpack_to_dest to moreh_mean_h * #13576: Apply unpack_to_dest to moreh_linear_bwd's bias_bwd * #13576: Add fp32 moreh_matmul test * #13576: Use matmul_tiles instead of reduce_tile in moreh_mean_w * #13576: Do not check rank of tiled idx tensor * #13576: Use matmul_tiles instead of reduce_w in moreh_sum_w * #13576: Apply unpack_to_dest in moreh_sum_h * #13576: Fix check-black error --------- Co-authored-by: Dongjin Na --- .../operations/test_moreh_getitem.py | 12 +- .../operations/test_moreh_linear.py | 112 ++++++++++++++---- .../operations/test_moreh_matmul.py | 89 +++++++++++++- .../device/moreh_getitem_device_operation.cpp | 4 +- .../kernels/reader_moreh_bias_backward_h.cpp | 12 +- ...ar_backward_multi_core_program_factory.cpp | 16 ++- .../device/kernels/moreh_mean_w.cpp | 17 ++- .../device/kernels/reader_moreh_mean_w.cpp | 25 +--- .../device/moreh_mean_h_program_factory.cpp | 8 +- .../moreh_mean_backward_program_factory.cpp | 2 - .../device/moreh_sum_h_program_factory.cpp | 4 +- .../moreh_sum_w_impl_kernels/moreh_sum_w.cpp | 55 ++++----- .../reader_moreh_sum_w.cpp | 4 +- .../device/moreh_sum_w_program_factory.cpp | 10 +- 14 files changed, 256 insertions(+), 114 deletions(-) diff --git a/tests/ttnn/unit_tests/operations/test_moreh_getitem.py b/tests/ttnn/unit_tests/operations/test_moreh_getitem.py index 7814eed71b3..e67bdaba854 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_getitem.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_getitem.py @@ -359,7 +359,7 @@ def test_getitem_tilized_one_index(shape_index_dim, dtype, index_size, row_major else: dev_idx = ( ttnn.Tensor(idx, ttnn.int32) - .reshape([1, 1, 1, 1, index_size]) + .reshape([1, index_size]) .pad_to_tile(float("nan")) .to(ttnn.TILE_LAYOUT) .to(device) @@ -452,7 +452,7 @@ def test_getitem_tilized_two_indices(shape_index_dims, dtype, index_size, row_ma else: dev_idx = ( ttnn.Tensor(idx, ttnn.int32) - .reshape([1, 1, 1, 1, index_size]) + .reshape([1, index_size]) .pad_to_tile(float("nan")) .to(ttnn.TILE_LAYOUT) .to(device) @@ -541,7 +541,7 @@ def test_getitem_tilized_three_indices(shape_index_dims, dtype, index_size, row_ else: dev_idx = ( ttnn.Tensor(idx, ttnn.int32) - .reshape([1, 1, 1, 1, index_size]) + .reshape([1, index_size]) .pad_to_tile(float("nan")) .to(ttnn.TILE_LAYOUT) .to(device) @@ -625,7 +625,7 @@ def test_getitem_tilized_four_indices(shape_index_dims, dtype, index_size, row_m else: dev_idx = ( ttnn.Tensor(idx, ttnn.int32) - .reshape([1, 1, 1, 1, index_size]) + .reshape([1, index_size]) .pad_to_tile(float("nan")) .to(ttnn.TILE_LAYOUT) .to(device) @@ -706,7 +706,7 @@ def test_getitem_tilized_five_indices(shape_index_dims, dtype, index_size, row_m else: dev_idx = ( ttnn.Tensor(idx, ttnn.int32) - .reshape([1, 1, 1, 1, index_size]) + .reshape([1, index_size]) .pad_to_tile(float("nan")) .to(ttnn.TILE_LAYOUT) .to(device) @@ -751,7 +751,7 @@ def run_moreh_geitem_tilized_one_index(shape_index_dim, dtype, index_size, row_m else: dev_idx = ( ttnn.Tensor(idx, ttnn.int32) - .reshape([1, 1, 1, 1, index_size]) + .reshape([1, index_size]) .pad_to_tile(float("nan")) .to(ttnn.TILE_LAYOUT) .to(device) diff --git a/tests/ttnn/unit_tests/operations/test_moreh_linear.py b/tests/ttnn/unit_tests/operations/test_moreh_linear.py index 12109abba8e..2cdb0c3be32 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_linear.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_linear.py @@ -5,8 +5,8 @@ import pytest import torch import ttnn -from models.utility_functions import comp_allclose_and_pcc -from tests.ttnn.unit_tests.operations.test_moreh_matmul import get_tensors +from models.utility_functions import comp_allclose_and_pcc, skip_for_grayskull +from tests.ttnn.unit_tests.operations.test_moreh_matmul import get_tensors, get_bias_tensors from loguru import logger from tests.ttnn.unit_tests.operations.test_utils import ( get_compute_kernel_options, @@ -15,24 +15,6 @@ ) -# TODO: add this feature in get_tensors method -def get_bias_tensors(bias_shape, require_bias_grad, device): - npu_dtype = ttnn.bfloat16 - cpu_dtype = torch.bfloat16 - npu_layout = ttnn.TILE_LAYOUT - cpu_layout = ttnn.ROW_MAJOR_LAYOUT - - bias = torch.randint(-10, 10, bias_shape, dtype=cpu_dtype) - tt_bias = ttnn.Tensor(bias, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) - - tt_bias_grad = None - if require_bias_grad: - bias_grad = torch.full(bias_shape, float("nan"), dtype=cpu_dtype) - tt_bias_grad = ttnn.Tensor(bias_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) - - return tt_bias, bias, tt_bias_grad - - def moreh_linear(shapes, has_bias, has_output, compute_kernel_config, device): torch.manual_seed(3072) input_shape, weight_shape, bias_shape, output_shape = shapes @@ -267,3 +249,93 @@ def test_moreh_linear_backward_enable_cache(shapes, device, use_program_cache): num_program_cache_entries_list.append(device.num_program_cache_entries()) assert passing assert len(set(num_program_cache_entries_list)) == 1 + + +@skip_for_grayskull("GS does not support fp32") +@pytest.mark.parametrize( + "shapes", + ( + # input, weight, bias(1d or scalar), output + # GPT2-Small cases + ([8, 512, 768], [2304, 768], [1, 2304], [8, 512, 2304]), + ([8, 512, 768], [768, 768], [1, 768], [8, 512, 768]), + ([8, 512, 768], [3072, 768], [1, 3072], [8, 512, 3072]), + ), +) +def test_moreh_bias_backward_fp32(shapes, device): + torch.manual_seed(3072) + compute_kernel_fp32_config = get_compute_kernel_options(True) + compute_kernel_config = get_compute_kernel_options(False) + requires_input_grad, requires_weight_grad, requires_bias_grad = (True, False, True) + input_shape, weight_shape, bias_shape, output_shape = shapes + ( + tt_input, + tt_weight, + _, + tt_output_grad, + tt_input_grad, + _, + torch_input, + torch_weight, + torch_output_grad, + ) = get_tensors( + input_shape, weight_shape, output_shape, requires_input_grad, requires_weight_grad, False, device, False + ) + tt_bias, torch_bias, tt_bias_grad = get_bias_tensors(bias_shape, requires_bias_grad, device, False) + (_, _, _, _, tt_input_grad_fp32, _, _, _, _) = get_tensors( + input_shape, weight_shape, output_shape, requires_input_grad, requires_weight_grad, False, device, False + ) + (_, _, tt_bias_grad_fp32) = get_bias_tensors(bias_shape, requires_bias_grad, device, False) + ## tt linear backward (fp32 mode) + tt_input_grad_fp32, _, tt_bias_grad_fp32 = ttnn.operations.moreh.linear_backward( + tt_output_grad, + tt_input, + tt_weight, + are_required_outputs=(requires_input_grad, requires_weight_grad, requires_bias_grad), + bias=tt_bias, + input_grad=tt_input_grad_fp32, + weight_grad=None, + bias_grad=tt_bias_grad_fp32, + compute_kernel_config=compute_kernel_fp32_config, + ) + ## tt linear backward (bf16 mode) + tt_input_grad, _, tt_bias_grad = ttnn.operations.moreh.linear_backward( + tt_output_grad, + tt_input, + tt_weight, + are_required_outputs=(requires_input_grad, requires_weight_grad, requires_bias_grad), + bias=tt_bias, + input_grad=tt_input_grad, + weight_grad=None, + bias_grad=tt_bias_grad, + compute_kernel_config=compute_kernel_config, + ) + torch_input_fp32 = torch_input.float() + torch_weight_fp32 = torch_weight.float() + torch_bias_fp32 = torch_bias.float() + ## reference + torch_output = torch.nn.functional.linear( + torch_input_fp32.requires_grad_(requires_input_grad), + torch_weight_fp32.requires_grad_(requires_weight_grad), + torch_bias_fp32.requires_grad_(requires_bias_grad), + ) + torch_output.backward(torch_output_grad.float()) + ## test for equivalance + rtol = atol = 0.1 + tt_bias_grad_fp32_cpu = tt_bias_grad_fp32.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(bias_shape).to_torch() + tt_bias_grad_cpu = tt_bias_grad.cpu().to(ttnn.ROW_MAJOR_LAYOUT).unpad_from_tile(bias_shape).to_torch() + passing, output_pcc = comp_allclose_and_pcc( + torch_bias_fp32.grad, tt_bias_grad_fp32_cpu, pcc=0.98, rtol=rtol, atol=atol + ) + logger.debug(f"Out passing={passing}") + logger.debug(f"Output pcc={output_pcc}") + assert passing + diff_fp32 = torch.abs(torch_bias_fp32.grad - tt_bias_grad_fp32_cpu) + logger.debug(f"std={torch.std(diff_fp32)}") + logger.debug(f"mean={diff_fp32.mean()}") + logger.debug(f"topk(5) {torch.topk(diff_fp32.reshape(-1), 5)}") + diff = torch.abs(torch_bias_fp32.grad - tt_bias_grad_cpu) + logger.debug(f"std={torch.std(diff)}") + logger.debug(f"mean={diff.mean()}") + logger.debug(f"topk(5) {torch.topk(diff.reshape(-1), 5)}") + assert diff_fp32.mean() < diff.mean() diff --git a/tests/ttnn/unit_tests/operations/test_moreh_matmul.py b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py index 7c2cdf97447..34350dfcabd 100644 --- a/tests/ttnn/unit_tests/operations/test_moreh_matmul.py +++ b/tests/ttnn/unit_tests/operations/test_moreh_matmul.py @@ -48,7 +48,11 @@ def get_tensors( # tensors for backward output_grad = tt_output_grad = torch_output_grad = tt_input_grad = tt_other_grad = None if require_input_grad or require_other_grad: - output_grad = torch.randint(-2, 3, output_shape, dtype=cpu_dtype) + output_grad = ( + torch.randint(-2, 3, output_shape, dtype=cpu_dtype) + if use_randint + else torch.rand(output_shape, dtype=cpu_dtype) + ) tt_output_grad = ttnn.Tensor(output_grad, npu_dtype).pad_to_tile(float(-1)).to(npu_layout).to(device) torch_output_grad = output_grad[0][0][0][0] if is_1d else output_grad @@ -81,6 +85,24 @@ def get_tensors( ) +def get_bias_tensors(bias_shape, require_bias_grad, device, use_int=True): + npu_dtype = ttnn.bfloat16 + cpu_dtype = torch.bfloat16 + npu_layout = ttnn.TILE_LAYOUT + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + bias = ( + torch.randint(-10, 10, bias_shape, dtype=cpu_dtype) + if use_int + else torch.rand(bias_shape, dtype=cpu_dtype) * 10 - 5 + ) + tt_bias = ttnn.Tensor(bias, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) + tt_bias_grad = None + if require_bias_grad: + bias_grad = torch.full(bias_shape, float("nan"), dtype=cpu_dtype) + tt_bias_grad = ttnn.Tensor(bias_grad, npu_dtype).pad_to_tile(float("nan")).to(npu_layout).to(device) + return tt_bias, bias, tt_bias_grad + + def moreh_matmul(params, has_output, compute_kernel_config, device): torch.manual_seed(3072) input_shape, other_shape, output_shape, transpose_input, transpose_other = params @@ -200,7 +222,7 @@ def test_moreh_matmul_enable_cache(params, device, use_program_cache): assert device.num_program_cache_entries() == 2 -@skip_for_grayskull("Doesn't seem to work properly on Grayskull devices. Wormhole_b0 devices work fine.") +@skip_for_grayskull("GS does not support fp32") @pytest.mark.parametrize( "params", ( @@ -455,3 +477,66 @@ def test_moreh_matmul_1d_backward(input_shape, requires_grad, device): logger.debug(f"other_grad passing={passing}") logger.debug(f"other_grad pcc={output_pcc}") assert passing + + +@skip_for_grayskull("GS does not support fp32") +@pytest.mark.parametrize( + "params", + ( + # input, other, output shape, transpose input, other + ([31, 3100], [3100, 31], [31, 31], False, False), + ), +) +def test_moreh_matmul_with_bias_add_fp32_dest_acc(params, device): + torch.manual_seed(3072) + input_shape, other_shape, output_shape, transpose_input, transpose_other = params + tt_input, tt_other, tt_output_fp32, _, _, _, torch_input, torch_other, _ = get_tensors( + input_shape, other_shape, output_shape, False, False, False, device, use_randint=False + ) + tt_bias, torch_bias, _ = get_bias_tensors([1, 31], False, device, False) + compute_kernel_config_fp32_dest_acc = get_compute_kernel_options(True) + compute_kernel_config_bf16_dest_acc = get_compute_kernel_options(False) + torch_input = torch_input.transpose(-1, -2) if transpose_input else torch_input + torch_other = torch_other.transpose(-1, -2) if transpose_other else torch_other + # tt matmul + tt_output_fp32 = ttnn.operations.moreh.matmul( + tt_input, + tt_other, + transpose_input=transpose_input, + transpose_other=transpose_other, + output=tt_output_fp32, + bias=tt_bias, + compute_kernel_config=compute_kernel_config_fp32_dest_acc, + ) + tt_output_fp16 = ttnn.operations.moreh.matmul( + tt_input, + tt_other, + transpose_input=transpose_input, + transpose_other=transpose_other, + bias=tt_bias, + compute_kernel_config=compute_kernel_config_bf16_dest_acc, + ) + cpu_layout = ttnn.ROW_MAJOR_LAYOUT + tt_output_cpu_fp32 = tt_output_fp32.cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch() + tt_output_cpu_bf16 = tt_output_fp16.cpu().to(cpu_layout).unpad_from_tile(output_shape).to_torch() + # torch matmul (float) + torch_out = torch.matmul(torch_input.float(), torch_other.float()) + torch_bias + # test for equivalance + rtol = atol = 0.1 + passing, output_pcc = comp_allclose_and_pcc(torch_out, tt_output_cpu_fp32, pcc=0.99, rtol=rtol, atol=atol) + logger.debug(f"Out passing={passing}") + logger.debug(f"Output pcc={output_pcc}") + diff = torch.abs(torch_out - tt_output_cpu_fp32) + logger.debug(f"std={torch.std(diff)}") + logger.debug(f"mean={diff.mean()}") + logger.debug(f"topk(5) {torch.topk(diff.reshape(-1), 5)}") + assert passing + torch_out = torch.matmul(torch_input.bfloat16(), torch_other.bfloat16()) + passing, output_pcc = comp_allclose_and_pcc(torch_out, tt_output_cpu_bf16, pcc=0.99, rtol=rtol, atol=atol) + logger.debug(f"Out passing={passing}") + logger.debug(f"Output pcc={output_pcc}") + diff_fp16 = torch.abs(torch_out - tt_output_cpu_bf16) + logger.debug(f"std={torch.std(diff_fp16)}") + logger.debug(f"mean={diff_fp16.mean()}") + logger.debug(f"topk(5) {torch.topk(diff_fp16.reshape(-1), 5)}") + assert diff.mean() < diff_fp16.mean() diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp index 03acb730526..e1b157c667d 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_getitem/device/moreh_getitem_device_operation.cpp @@ -34,8 +34,8 @@ void MorehGetItemOperation::validate_inputs( auto index_layout = index_tensor.get_layout(); if (index_layout == Layout::ROW_MAJOR) { TT_FATAL(index_shape.rank() == 1, "Index tensor must be 1D for ROW_MAJOR layout!"); - } else if (index_layout == Layout::TILE) { - TT_FATAL(index_shape.rank() == 5, "Index tensor must be 5D for TILE layout!"); + } else { + // nothing } TT_FATAL( !(input_layout == Layout::ROW_MAJOR && index_layout == Layout::TILE), diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/kernels/reader_moreh_bias_backward_h.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/kernels/reader_moreh_bias_backward_h.cpp index b42f9629ad3..3263b96d40a 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/kernels/reader_moreh_bias_backward_h.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/kernels/reader_moreh_bias_backward_h.cpp @@ -2,8 +2,8 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp" #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" - void kernel_main() { ArgFetcher arg_fetcher; const uint32_t src0_addr = arg_fetcher.get_next_arg_val(); @@ -17,16 +17,12 @@ void kernel_main() { const bool do_mask_w = (arg_fetcher.get_next_arg_val() == 1); constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1; + constexpr uint32_t scaler = get_compile_time_arg_val(1); constexpr uint32_t cb_id_in0 = 0; constexpr uint32_t cb_id_scaler = 1; constexpr uint32_t cb_id_mask_h_w = 2; - union { - float f; - uint32_t u; - } scaler; - scaler.f = 1.0f; - fill_cb_with_value(cb_id_scaler, scaler.u); + generate_reduce_scaler(cb_id_scaler, scaler); if (do_mask_h || do_mask_w) { generate_mask_h_w(cb_id_mask_h_w, mask_h, mask_w); @@ -41,7 +37,7 @@ void kernel_main() { constexpr uint32_t onetile = 1; for (uint32_t wt = 0; wt < Wt_per_core; ++wt) { uint32_t read_tile_id = start_id + wt; - for (uint32_t b= 0; b < batch_num; ++b) { + for (uint32_t b = 0; b < batch_num; ++b) { cb_reserve_back(cb_id_in0, onetile); l1_write_addr_in0 = get_write_ptr(cb_id_in0); noc_async_read_tile(read_tile_id, s0, l1_write_addr_in0); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_multi_core_program_factory.cpp index 62c48d1b95b..6be91e1ba3f 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/moreh_linear_backward_multi_core_program_factory.cpp @@ -6,6 +6,7 @@ #include "moreh_linear_backward_device_operation.hpp" #include "tt_dnn/op_library/moreh_helper_functions.hpp" +#include "tt_metal/common/bfloat16.hpp" #include "tt_metal/common/work_split.hpp" #include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp" @@ -63,7 +64,7 @@ MorehBiasAddBackwardOperation::MultiCoreProgramFactory::create( //////////////////////////////////////////////////////////////////////////// const uint32_t in0_t = 2; const uint32_t in1_t = 1; - const uint32_t in2_t = (do_mask_h || do_mask_w) ? 2 : 0; // mask_h_w + const uint32_t in2_t = 2; // mask_h_w const uint32_t out0_t = 1; const uint32_t im0_t = 1; @@ -84,9 +85,10 @@ MorehBiasAddBackwardOperation::MultiCoreProgramFactory::create( //////////////////////////////////////////////////////////////////////////// // DataMovementKernel SetUp //////////////////////////////////////////////////////////////////////////// - + const ::bfloat16 bfloat_scaler_value = ::bfloat16(1.0f); + const uint32_t packed_scaler_value = pack_two_bfloat16_into_uint32({bfloat_scaler_value, bfloat_scaler_value}); const std::vector reader_compile_time_args{ - static_cast(tt::operations::primary::is_dram(output_grad))}; + static_cast(tt::operations::primary::is_dram(output_grad)), packed_scaler_value}; const std::vector writer_compile_time_args{ static_cast(tt::operations::primary::is_dram(bias_grad))}; @@ -108,8 +110,10 @@ MorehBiasAddBackwardOperation::MultiCoreProgramFactory::create( std::map compute_defines; compute_defines["REDUCE_OP"] = "PoolType::SUM"; compute_defines["REDUCE_DIM"] = "ReduceDim::REDUCE_COL"; + std::vector unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default); if (fp32_dest_acc_en) { compute_defines["FP32_DEST_ACC_EN"] = "1"; + unpack_to_dest_mode[tt::CB::c_intermed1] = UnpackToDestMode::UnpackToDestFp32; } const auto compute_kernel_file = "ttnn/cpp/ttnn/operations/moreh/moreh_linear_backward/device/kernels/moreh_bias_backward_multi_core_h.cpp"; @@ -121,7 +125,8 @@ MorehBiasAddBackwardOperation::MultiCoreProgramFactory::create( compute_defines, math_fidelity, fp32_dest_acc_en, - math_approx_mode); + math_approx_mode, + unpack_to_dest_mode); std::optional compute_kernel_2_id = std::nullopt; if (!core_group_2.ranges().empty()) { @@ -133,7 +138,8 @@ MorehBiasAddBackwardOperation::MultiCoreProgramFactory::create( compute_defines, math_fidelity, fp32_dest_acc_en, - math_approx_mode); + math_approx_mode, + unpack_to_dest_mode); } //////////////////////////////////////////////////////////////////////////// diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_w.cpp index ac560ac7d6c..ef0cbe17f92 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/moreh_mean_w.cpp @@ -4,6 +4,7 @@ #include +#include "compute_kernel_api/matmul.h" #include "compute_kernel_api/eltwise_binary.h" #include "compute_kernel_api/mask.h" #include "compute_kernel_api/reduce.h" @@ -47,13 +48,15 @@ void MAIN { if (!is_w_single_tile) { tile_regs_acquire(); - reduce_init_delta_with_dt(cb_accum_dst, cb_input, cb_scaler); for (uint32_t wt = 0; wt < Wt - 1; ++wt) { cb_wait_front(cb_input, onetile); - reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx); +#if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(cb_input, cb_scaler); +#endif + mm_init_short(cb_input, cb_scaler, false); + matmul_tiles(cb_input, cb_scaler, 0, 0, reduce_dst_idx, false); cb_pop_front(cb_input, onetile); } - reduce_revert_delta(cb_accum_dst); tile_regs_commit(); cb_reserve_back(cb_accum_dst, onetile); @@ -96,9 +99,11 @@ void MAIN { copy_tile(cb_accum_dst, 0, reduce_dst_idx); } - reduce_init_delta_with_dt(cb_out, cb_input, cb_scaler); - reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx); - reduce_revert_delta(cb_out); +#if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(cb_input, cb_scaler); +#endif + mm_init_short(cb_input, cb_scaler, false); + matmul_tiles(cb_input, cb_scaler, 0, 0, reduce_dst_idx, false); tile_regs_commit(); cb_reserve_back(cb_out, onetile); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/reader_moreh_mean_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/reader_moreh_mean_w.cpp index 4f2c714da34..934d5a0e42b 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/reader_moreh_mean_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/kernels/reader_moreh_mean_w.cpp @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_mm_scaler.hpp" #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" void kernel_main() { @@ -13,29 +14,7 @@ void kernel_main() { constexpr uint32_t scaler = get_compile_time_arg_val(1); constexpr uint32_t cb_id_in2 = tt::CB::c_in2; - cb_reserve_back(cb_id_in2, 1); - constexpr uint32_t num_zeros_reads = 2048 / MEM_ZEROS_SIZE; - uint64_t zeros_noc_addr = get_noc_addr(MEM_ZEROS_BASE); - uint32_t write_addr = get_write_ptr(cb_id_in2); - // Fill tile with zeros - for (uint32_t i = 0; i < num_zeros_reads; ++i) { - noc_async_read(zeros_noc_addr, write_addr, MEM_ZEROS_SIZE); - write_addr += MEM_ZEROS_SIZE; - } - noc_async_read_barrier(); - if constexpr (scaler != 0) { - volatile tt_l1_ptr uint32_t* ptr = reinterpret_cast(get_write_ptr(cb_id_in2)); - uint32_t idx = 0; - for (uint32_t k = 0; k < 4; ++k) { - uint32_t curr_idx = idx; - for (uint32_t j = 0; j < 8; ++j) { - ptr[curr_idx] = scaler; - curr_idx++; - } - idx += 128; - } - } - cb_push_back(cb_id_in2, 1); + generate_mm_scaler(cb_id_in2, scaler); constexpr uint32_t cb_id_mask_w = tt::CB::c_in3; #ifdef DO_MASK_W diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_h_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_h_program_factory.cpp index 87dd9ca4f98..dee9313a7e3 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_h_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean/device/moreh_mean_h_program_factory.cpp @@ -62,7 +62,6 @@ MorehMeanOperation::MorehMeanHFactory::cached_program_t MorehMeanOperation::More tt::DataFormat data_format = datatype_to_dataformat_converter(input.get_dtype()); auto fp32_dest_acc_en_data_format = fp32_dest_acc_en ? tt::DataFormat::Float32 : data_format; - uint32_t num_input_tiles = 2; uint32_t num_output_tiles = 2; CreateCircularBuffer( @@ -112,21 +111,22 @@ MorehMeanOperation::MorehMeanHFactory::cached_program_t MorehMeanOperation::More auto reduce_op = ReduceOpMath::SUM; auto reduce_dim = ReduceOpDim::H; std::map compute_defines = reduce_op_utils::get_defines(reduce_op, reduce_dim); + std::vector unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default); if (fp32_dest_acc_en) { compute_defines["FP32_DEST_ACC_EN"] = 1; + unpack_to_dest_mode[tt::CB::c_intermed0] = UnpackToDestMode::UnpackToDestFp32; } - vector compute_kernel_args_group_1 = { + std::vector compute_kernel_args_group_1 = { Ht, // Ht units_per_core_group_1, // Wt 1, // NC origin_H}; - vector compute_kernel_args_group_2 = { + std::vector compute_kernel_args_group_2 = { Ht, // Ht units_per_core_group_2, // Wt 1, // NC origin_H}; - vector unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default); auto compute_kernel_ids = CreateComputeKernel( program, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp index a8c97375ea2..97ae2b2d842 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_mean_backward/device/moreh_mean_backward_program_factory.cpp @@ -170,8 +170,6 @@ MorehMeanBackwardOperation::MorehMeanBackwardFactory::create( tt::operations::primary::ComputeKernelConfig{ .math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, - // TODO(hyungsuk): change unpack_to_dest_mode from false to fp32_dest_acc_en after fix #10337 - // .unpack_to_dest_mode = fp32_dest_acc_en, .unpack_to_dest_mode = unpack_to_dest_mode, .math_approx_mode = math_approx_mode, .defines = compute_defines}); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp index b8e13b6a297..7f579a43674 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_h_program_factory.cpp @@ -160,8 +160,10 @@ MorehSumOperation::MorehSumHFactory::cached_program_t MorehSumOperation::MorehSu 1, // NC origin_H}; - // set unpack_to_dest_mode to the same value as fp32_dest_acc_en vector unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default); + if (fp32_dest_acc_en) { + unpack_to_dest_mode[tt::CB::c_intermed0] = UnpackToDestMode::UnpackToDestFp32; + } auto reduce_compute_kernel_group_1_id = tt::tt_metal::CreateKernel( program, compute_kernel_name, diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_impl_kernels/moreh_sum_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_impl_kernels/moreh_sum_w.cpp index d93f93814aa..006c096db34 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_impl_kernels/moreh_sum_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_impl_kernels/moreh_sum_w.cpp @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 +#include "compute_kernel_api/matmul.h" #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp" namespace NAMESPACE { @@ -11,7 +12,6 @@ void MAIN { uint32_t NC = get_compile_time_arg_val(2); constexpr uint32_t origin_W = get_compile_time_arg_val(3); - auto cb_input = tt::CB::c_in0; constexpr auto cb_scaler = tt::CB::c_in2; constexpr auto cb_mask_w = tt::CB::c_in3; @@ -44,22 +44,20 @@ void MAIN { tile_regs_acquire(); for (uint32_t wt = 0; wt < Wt - 1; ++wt) { cb_wait_front(cb_input, onetile); - - #if defined FP32_DEST_ACC_EN - unpack_reconfig_data_format(cb_input, cb_scaler); - #endif - reduce_init_delta(); - reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx); - reduce_revert_delta(); +#if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(cb_input, cb_scaler); +#endif + mm_init_short(cb_input, cb_scaler, false); + matmul_tiles(cb_input, cb_scaler, 0, 0, reduce_dst_idx, false); cb_pop_front(cb_input, onetile); } tile_regs_commit(); cb_reserve_back(cb_accum_dst, onetile); tile_regs_wait(); - #if defined FP32_DEST_ACC_EN - pack_reconfig_data_format(cb_accum_dst); - #endif +#if defined FP32_DEST_ACC_EN + pack_reconfig_data_format(cb_accum_dst); +#endif pack_tile(reduce_dst_idx, cb_accum_dst); tile_regs_release(); cb_push_back(cb_accum_dst, onetile); @@ -68,9 +66,9 @@ void MAIN { if (do_mask_w) { tile_regs_acquire(); cb_wait_front(cb_input, onetile); - #if defined FP32_DEST_ACC_EN - unpack_reconfig_data_format_srca(cb_input); - #endif +#if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format_srca(cb_input); +#endif copy_tile_to_dst_init_short(cb_input); copy_tile(cb_input, 0, reduce_dst_idx); copy_tile(cb_mask_w, 0, mask_dst_idx); @@ -80,9 +78,9 @@ void MAIN { cb_reserve_back(cb_masked_input, onetile); tile_regs_wait(); - #if defined FP32_DEST_ACC_EN - pack_reconfig_data_format(cb_masked_input); - #endif +#if defined FP32_DEST_ACC_EN + pack_reconfig_data_format(cb_masked_input); +#endif pack_tile(reduce_dst_idx, cb_masked_input); tile_regs_release(); cb_push_back(cb_masked_input, onetile); @@ -94,27 +92,26 @@ void MAIN { tile_regs_acquire(); cb_wait_front(cb_input, onetile); if (!is_w_single_tile) { - #if defined FP32_DEST_ACC_EN - unpack_reconfig_data_format_srca(cb_accum_dst); - #endif +#if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format_srca(cb_accum_dst); +#endif cb_wait_front(cb_accum_dst, onetile); copy_tile_to_dst_init_short(cb_accum_dst); copy_tile(cb_accum_dst, 0, reduce_dst_idx); } - #if defined FP32_DEST_ACC_EN - unpack_reconfig_data_format(cb_input, cb_scaler); - #endif - reduce_init_delta(); - reduce_tile(cb_input, cb_scaler, 0, 0, reduce_dst_idx); - reduce_revert_delta(); +#if defined FP32_DEST_ACC_EN + unpack_reconfig_data_format(cb_input, cb_scaler); +#endif + mm_init_short(cb_input, cb_scaler, false); + matmul_tiles(cb_input, cb_scaler, 0, 0, reduce_dst_idx, false); tile_regs_commit(); cb_reserve_back(cb_out, onetile); tile_regs_wait(); - #if defined FP32_DEST_ACC_EN - pack_reconfig_data_format(cb_out); - #endif +#if defined FP32_DEST_ACC_EN + pack_reconfig_data_format(cb_out); +#endif pack_tile(reduce_dst_idx, cb_out); tile_regs_release(); cb_push_back(cb_out, onetile); diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_impl_kernels/reader_moreh_sum_w.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_impl_kernels/reader_moreh_sum_w.cpp index a40833f38e1..2b82a8908b1 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_impl_kernels/reader_moreh_sum_w.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_impl_kernels/reader_moreh_sum_w.cpp @@ -3,7 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 #include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/moreh_common.hpp" -#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_reduce_scaler.hpp" +#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_mm_scaler.hpp" void kernel_main() { uint32_t src_addr = get_arg_val(0); @@ -14,7 +14,7 @@ void kernel_main() { constexpr uint32_t scaler = get_compile_time_arg_val(1); constexpr uint32_t cb_id_in2 = 2; - generate_reduce_scaler(cb_id_in2, scaler); + generate_mm_scaler(cb_id_in2, scaler); constexpr uint32_t cb_id_mask_w = 3; #ifdef DO_MASK_W diff --git a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp index f5d5dd37adb..55e40f80d59 100644 --- a/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/moreh/moreh_sum/device/moreh_sum_w_program_factory.cpp @@ -146,15 +146,17 @@ MorehSumOperation::MorehSumWFactory::cached_program_t MorehSumOperation::MorehSu if (fp32_dest_acc_en) { reduce_defines["FP32_DEST_ACC_EN"] = "1"; } - vector compute_kernel_args_group_1 = { + std::vector compute_kernel_args_group_1 = { num_rows_per_core_group_1, // Ht Wt, // Wt 1, // NC origin_W, }; - // set unpack_to_dest_mode to the same value as fp32_dest_acc_en - vector unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default); + std::vector unpack_to_dest_mode(NUM_CIRCULAR_BUFFERS, UnpackToDestMode::Default); + if (fp32_dest_acc_en) { + unpack_to_dest_mode[tt::CB::c_intermed0] = UnpackToDestMode::UnpackToDestFp32; + } auto reduce_compute_kernel_group_1_id = tt::tt_metal::CreateKernel( program, compute_kernel_name, @@ -168,7 +170,7 @@ MorehSumOperation::MorehSumWFactory::cached_program_t MorehSumOperation::MorehSu .defines = reduce_defines}); if (!core_group_2.ranges().empty()) { - vector compute_kernel_args_group_2 = { + std::vector compute_kernel_args_group_2 = { num_rows_per_core_group_2, // Ht Wt, // Wt 1, // NC From ab725abcb4ddec8e5239337b9313f7718eb0a4e5 Mon Sep 17 00:00:00 2001 From: Naif Tarafdar <135640067+ntarafdar@users.noreply.github.com> Date: Wed, 9 Oct 2024 23:05:09 -0700 Subject: [PATCH 58/58] Attach Golden Function (already implemented ) For Reshape (accidentally unattached in previous PR) (#13614) #13613: add golden function for reshape Co-authored-by: Saad Jameel <163029024+sjameelTT@users.noreply.github.com> --- ttnn/ttnn/operations/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 8a1e2beb603..45e024ac82f 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -143,6 +143,7 @@ def _postprocess_golden_function_outputs(output, args, kwargs): ttnn.Shape([32, 64]) """ +ttnn.attach_golden_function(ttnn.reshape, golden_function=_golden_function) # TODO(arakhmati): remove this once underlying C++ code can handle non-4D shapes ttnn.register_python_operation(name="ttnn.unsqueeze_to_4D")(ttnn._ttnn.operations.core.unsqueeze_to_4D)