diff --git a/README.md b/README.md
index d97293bb9de0..ca1b108b91e0 100644
--- a/README.md
+++ b/README.md
@@ -24,10 +24,10 @@
 | Model                                                      | Batch               | End-to-end throughput [1]    | Device throughput [2]       | Target                              |
 |------------------------------------------------------------|---------------------|------------------------------|-----------------------------|-------------------------------------|
-| [ResNet-50](./models/demos/resnet) (fps)                   | 20                  | 2,850                        | 7,200                       | 10,000                              |
+| [ResNet-50](./models/demos/resnet) (fps)                   | 20                  | 4,400                        | 7,700                       | 10,000                              |
 | [BERT-Large](./models/demos/bert) (sen/s)                  | 12                  | 362                          | 406                         | 410                                 |
 | [Falcon7B-decode](./models/demos/ttnn_falcon7b) (t/s)      | 32                  | 135                          | 135                         | 140                                 |
-| [ViT](./models/demos/grayskull/vit) (fps)                  | 8                   | 480                          | 1570                        | 2000                                |
+| [ViT](./models/demos/grayskull/vit) (fps)                  | 8                   | 860                          | 1570                        | 2000                                |
 | [T5 small](./models/demos/grayskull/t5) (sen/s)            |                     | 140                          |                             |                                     |
 | [Bloom](./models/demos/grayskull/functional_bloom) (sen/s) |                     | 70                           |                             |                                     |
 | U-Net                                                      | coming soon         |                              |                             |                                     |
@@ -42,13 +42,13 @@
 >
 > All model demos in this table function on both N150 and N300 Wormhole cards, unless otherwise stated.

-| Model                                                         | Gen. Token [3]     | Batch                | End-to-end throughput [1]    | Device throughput [2]       | Target         |
-|---------------------------------------------------------------|--------------------|----------------------|------------------------------|-----------------------------|----------------|
-| [Falcon7B-decode](./models/demos/wormhole/falcon7b)           | 129th              | 32                   | 11.6 t/s/u - 371 t/s         | 15.4 t/s/u - 493 t/s        | 21 t/s/u       |
-| [Mistral-7B-decode](./models/demos/wormhole/mistral7b)        | 33rd               | 32                   | 10.9 t/s/u - 349 t/s         | 13.3 t/s/u - 426 t/s        | 21 t/s/u       |
-| [Mamba-2.8B-decode](./models/demos/mamba)                     | any                | 32                   | 9.2 t/s/u - 295 t/s          | 13.1 t/s/u - 419 t/s        | 22 t/s/u       |
-| [BERT-Large](./models/demos/metal_BERT_large_11/) (sen/s) [4] | any                | 8                    | 270                          | 340                         | 400            |
-| [Stable Diffusion 1.4](./models/demos/wormhole/stable_diffusion) 512x512 (sec/img) |  | 1                    | 8s                           | 5s                          |                |
+| Model                                                                                  | Gen. Token [3]     | Batch                | End-to-end throughput [1]    | Device throughput [2]       | Target         |
+|----------------------------------------------------------------------------------------|--------------------|----------------------|------------------------------|-----------------------------|----------------|
+| [Falcon7B-decode](./models/demos/wormhole/falcon7b)                                    | 129th              | 32                   | 11.6 t/s/u - 371 t/s         | 15.4 t/s/u - 493 t/s        | 21 t/s/u       |
+| [Mistral-7B-decode](./models/demos/wormhole/mistral7b)                                 | 33rd               | 32                   | 10.9 t/s/u - 349 t/s         | 13.3 t/s/u - 426 t/s        | 21 t/s/u       |
+| [Mamba-2.8B-decode](./models/demos/mamba)                                              | any                | 32                   | 9.2 t/s/u - 295 t/s          | 13.1 t/s/u - 419 t/s        | 22 t/s/u       |
+| [BERT-Large](./models/demos/metal_BERT_large_11/) (sen/s) [4]                          |                    | 8                    | 270                          | 340                         | 400            |
+| [Stable Diffusion 1.4](./models/demos/wormhole/stable_diffusion) 512x512 (sec/img)    |                    | 1                    | 8                            | 5                           |                |

 [1] - Observed from the host. Includes dispatch overhead and kernel execution time.
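A note on the footnotes above: [1] is throughput observed from the host, so it includes dispatch overhead as well as kernel execution time, while [2] counts device execution only. Below is a minimal sketch of that style of host-side measurement; `run_inference` and `batch_size` are hypothetical placeholders, not the repo's actual benchmark harness:

```python
import time

def host_observed_fps(run_inference, batch_size: int, iterations: int = 15) -> float:
    # Warm up once so one-time compile/cache effects are excluded.
    run_inference()
    start = time.perf_counter()
    for _ in range(iterations):
        # Timed from the host: each call includes dispatch overhead plus kernel time.
        run_inference()
    elapsed = time.perf_counter() - start
    return iterations * batch_size / elapsed
```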
diff --git a/conftest.py b/conftest.py
index 7df64b2c7505..6c617cc1e7a8 100644
--- a/conftest.py
+++ b/conftest.py
@@ -326,9 +326,6 @@ def device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0):
     except (ValueError, AttributeError):
         num_devices_requested = len(device_ids)

-    if num_devices_requested <= 1:
-        pytest.skip("Requires multiple devices to run")
-
     device_mesh = ttnn.open_device_mesh(ttnn.DeviceGrid(1, num_devices_requested), device_ids[:num_devices_requested])

     logger.debug(f"multidevice with {device_mesh.get_num_devices()} devices is created")
@@ -354,9 +351,6 @@ def pcie_device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0):
     except (ValueError, AttributeError):
         num_pcie_devices_requested = len(device_ids)

-    if num_pcie_devices_requested <= 1:
-        pytest.skip("Requires multiple devices to run")
-
     device_mesh = ttnn.open_device_mesh(
         ttnn.DeviceGrid(1, num_pcie_devices_requested), device_ids[:num_pcie_devices_requested]
     )
@@ -386,9 +380,6 @@ def t3k_device_mesh(request, silicon_arch_name, silicon_arch_wormhole_b0):
     except (ValueError, AttributeError):
         num_devices_requested = len(device_ids)

-    if num_devices_requested <= 1:
-        pytest.skip("Requires multiple devices to run")
-
     device_mesh = ttnn.open_device_mesh(ttnn.DeviceGrid(1, num_devices_requested), device_ids[:num_devices_requested])

     logger.debug(f"multidevice with {device_mesh.get_num_devices()} devices is created")
diff --git a/dockerfile/ubuntu-20.04-x86.Dockerfile b/dockerfile/ubuntu-20.04-amd64.Dockerfile
similarity index 56%
rename from dockerfile/ubuntu-20.04-x86.Dockerfile
rename to dockerfile/ubuntu-20.04-amd64.Dockerfile
index bdb5cb7d8697..a5ca82f1d762 100644
--- a/dockerfile/ubuntu-20.04-x86.Dockerfile
+++ b/dockerfile/ubuntu-20.04-amd64.Dockerfile
@@ -1,4 +1,4 @@
-# Second stage: the actual image
+# TT-METAL UBUNTU 20.04 AMD64 DOCKERFILE
 FROM ubuntu:20.04
 ARG DEBIAN_FRONTEND=noninteractive
@@ -25,16 +25,19 @@ RUN /bin/bash /opt/tt_metal_infra/scripts/docker/install_test_deps.sh ${GTEST_VE
 COPY /scripts /opt/tt_metal_infra/scripts
 COPY build_metal.sh /scripts/build_metal.sh

-# ENV TT_METAL_INFRA_DIR=/opt/tt_metal_infra
-# ENV PYTHON_ENV_DIR=${TT_METAL_INFRA_DIR}/tt-metal/python_env
-# RUN python3 -m venv $PYTHON_ENV_DIR
+# Set up environment variables for the Python virtualenv and install TT-Metal Python deps
+ENV TT_METAL_INFRA_DIR=/opt/tt_metal_infra
+ENV PYTHON_ENV_DIR=${TT_METAL_INFRA_DIR}/tt-metal/python_env
+RUN python3 -m venv $PYTHON_ENV_DIR
+ENV PATH="$PYTHON_ENV_DIR/bin:$PATH"

-# COPY /docs/requirements-docs.txt ${TT_METAL_INFRA_DIR}/tt-metal/docs/.
-# COPY /tt_metal/python_env/* ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/.
-# ENV PATH="$PYTHON_ENV_DIR/bin:$PATH"
-# RUN python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu \
-# && python3 -m pip install setuptools wheel
+# Copy the requirements files from the tt-metal docs and python_env folders
+COPY /docs/requirements-docs.txt ${TT_METAL_INFRA_DIR}/tt-metal/docs/.
+COPY /tt_metal/python_env/* ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/.
+RUN python3 -m pip config set global.extra-index-url https://download.pytorch.org/whl/cpu \ + && python3 -m pip install setuptools wheel -# RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/requirements-dev.txt +RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/tt_metal/python_env/requirements-dev.txt +RUN python3 -m pip install -r ${TT_METAL_INFRA_DIR}/tt-metal/docs/requirements-docs.txt CMD ["tail", "-f", "/dev/null"] diff --git a/models/demos/mamba/demo/demo.py b/models/demos/mamba/demo/demo.py index fb95f1ececda..e798f2973348 100644 --- a/models/demos/mamba/demo/demo.py +++ b/models/demos/mamba/demo/demo.py @@ -28,13 +28,8 @@ def get_tt_metal_model( from models.demos.mamba.tt import model_config reference_model = get_cpu_reference_model(version, batch_size=batch_size) - if cache_dir: - cache_path = model_config.get_weights_cache_path(version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch_size, reference_model.args.d_model) - model = MambaTT(reference_model, device, config, tt_cache_path=cache_path) + model = MambaTT(reference_model, device, config, tt_cache_path=cache_dir) return model @@ -89,6 +84,7 @@ def run_mamba_demo( assert batch_size == len(prompts), "32 prompts are required" logger.info(f"Running Mamba demo (weights='{model_version}') with batch={batch_size}") + logger.info(f"Using tensor cache at '{cache_dir}'") model = get_tt_metal_model(model_version, device, cache_dir, batch_size) @@ -129,8 +125,18 @@ def run_mamba_demo( @pytest.mark.parametrize( - "max_gen_len", - ([100]), + "model_version, max_gen_len", + ( + ( + "state-spaces/mamba-2.8b-slimpj", + 100, + ), + ), ) -def test_demo(user_input, device, use_program_cache, max_gen_len): - return run_mamba_demo(prompts=user_input, device=device, generated_sequence_length=max_gen_len) +def test_demo(user_input, device, use_program_cache, get_tt_cache_path, model_version, max_gen_len): + return run_mamba_demo( + prompts=user_input, + device=device, + cache_dir=get_tt_cache_path(model_version), + generated_sequence_length=max_gen_len, + ) diff --git a/models/demos/mamba/tests/test_full_model.py b/models/demos/mamba/tests/test_full_model.py index afbdca353e80..c0a5fac3c6c9 100644 --- a/models/demos/mamba/tests/test_full_model.py +++ b/models/demos/mamba/tests/test_full_model.py @@ -46,9 +46,9 @@ def run_inference( model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], num_layers: int, iterations: int, + cache_dir: Optional[str], ): torch.manual_seed(10) @@ -64,13 +64,8 @@ def run_inference( with torch.no_grad(): reference_output = mamba_model_pytorch(input_ids) - if cache_dir: - cache_path = model_config.get_weights_cache_path(model_version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch, reference_model.args.d_model) - mamba_model_tt = MambaTT(reference_model, device, config, tt_cache_path=cache_path, num_layers=num_layers) + mamba_model_tt = MambaTT(reference_model, device, config, tt_cache_path=cache_dir, num_layers=num_layers) for _ in range(iterations): tt_output = mamba_model_tt(input_ids) @@ -87,13 +82,12 @@ def run_inference( @skip_for_grayskull("Not supported on Grayskull") @pytest.mark.parametrize( - "model_version, batch, pcc, cache_dir, num_layers, iterations", + "model_version, batch, pcc, num_layers, iterations", ( ( "state-spaces/mamba-2.8b", 32, - 0.985, - None, + 0.98, 64, 1, ), @@ -102,14 +96,23 @@ def run_inference( def test_inference( device: 
ttnn.Device, use_program_cache, + get_tt_cache_path, model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], num_layers: int, iterations: int, ): - run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations) + run_inference( + device, + use_program_cache, + model_version, + batch, + pcc, + num_layers, + iterations, + cache_dir=get_tt_cache_path(model_version), + ) @skip_for_grayskull("Not supported on Grayskull") @@ -120,11 +123,20 @@ def test_inference( def test_device_perf( device: ttnn.Device, use_program_cache, + get_tt_cache_path, iterations, model_version="state-spaces/mamba-2.8b", batch=32, pcc=0.97, - cache_dir=None, num_layers=1, ): - run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations) + run_inference( + device, + use_program_cache, + model_version, + batch, + pcc, + num_layers, + iterations, + cache_dir=get_tt_cache_path(model_version), + ) diff --git a/models/demos/mamba/tests/test_full_model_loop.py b/models/demos/mamba/tests/test_full_model_loop.py deleted file mode 100644 index 532e9f509cf5..000000000000 --- a/models/demos/mamba/tests/test_full_model_loop.py +++ /dev/null @@ -1,22 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import ttnn - -from models.demos.mamba.tests.test_full_model import run_inference -from models.utility_functions import skip_for_grayskull - - -@skip_for_grayskull("Not supported on Grayskull") -def test_inference_loop( - device: ttnn.Device, - use_program_cache, - model_version="state-spaces/mamba-2.8b", - batch=32, - pcc=0.88, - cache_dir=None, - num_layers=64, - iterations=10, -): - run_inference(device, use_program_cache, model_version, batch, pcc, cache_dir, num_layers, iterations) diff --git a/models/demos/mamba/tests/test_mamba_block.py b/models/demos/mamba/tests/test_mamba_block.py index 0589e551d2ea..8d118a26b26b 100644 --- a/models/demos/mamba/tests/test_mamba_block.py +++ b/models/demos/mamba/tests/test_mamba_block.py @@ -10,7 +10,6 @@ from models.demos.mamba.tt.full_model import TtTensorLoader from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName from models.demos.mamba.tt.mamba_block import TtMambaBlock -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer from models.demos.mamba.tt import model_config from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_allclose, @@ -30,13 +29,12 @@ def forward(self, x): @pytest.mark.parametrize( - "model_version, batch, pcc, cache_dir", + "model_version, batch, pcc", ( ( "state-spaces/mamba-2.8b", 32, 0.99, - None, ), ), ) @@ -46,7 +44,6 @@ def test_mamba_block_inference( model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], ): torch.manual_seed(0) @@ -63,19 +60,11 @@ def test_mamba_block_inference( residual_block = reference_model.layers[LAYER_NUM] assert not isinstance(residual_block, torch.Tensor), "Expected torch.Module" - if cache_dir: - cache_path = model_config.get_weights_cache_path(model_version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch, d_model) - loader = TtTensorLoader(reference_model.state_dict(), device, tt_cache_path=cache_path) - transformer = MambaSsmBlockTransformer( - device, batch, reference_model.args.d_inner, reference_model.args.d_state * 2 - ) + loader = TtTensorLoader(reference_model.state_dict(), device) - model = 
TtMambaBlock(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM), transformer) + model = TtMambaBlock(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM)) tt_input = input.view(1, 1, batch, d_model) tt_input = ttnn.to_device( ttnn.from_torch(tt_input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), diff --git a/models/demos/mamba/tests/test_mamba_demo.py b/models/demos/mamba/tests/test_mamba_demo.py index d14b07571eb3..21a8ed6734b9 100644 --- a/models/demos/mamba/tests/test_mamba_demo.py +++ b/models/demos/mamba/tests/test_mamba_demo.py @@ -7,8 +7,15 @@ @pytest.mark.parametrize( - "user_input, max_gen_len", - ((["Hello World"], 2),), + "user_input, model_version, max_gen_len", + ((["Hello World"], "state-spaces/mamba-2.8b-slimpj", 2),), ) -def test_demo(user_input, device, use_program_cache, max_gen_len): - return run_mamba_demo(prompts=user_input, device=device, generated_sequence_length=max_gen_len, display=False) +def test_demo(user_input, model_version, device, use_program_cache, get_tt_cache_path, max_gen_len): + return run_mamba_demo( + prompts=user_input, + model_version=model_version, + device=device, + generated_sequence_length=max_gen_len, + display=False, + cache_dir=get_tt_cache_path(model_version), + ) diff --git a/models/demos/mamba/tests/test_mamba_perf.py b/models/demos/mamba/tests/test_mamba_perf.py index 1563a29d00bc..e83e3ac4976c 100644 --- a/models/demos/mamba/tests/test_mamba_perf.py +++ b/models/demos/mamba/tests/test_mamba_perf.py @@ -27,7 +27,14 @@ ((32, 10, 12.5, 0.40),), # Issue 7816 Compile time ) def test_mamba_e2e_perf( - device, batch, iterations, expected_compile_time, expected_inference_time, use_program_cache, reset_seeds + device, + batch, + iterations, + expected_compile_time, + expected_inference_time, + use_program_cache, + reset_seeds, + get_tt_cache_path, ): model_version = "state-spaces/mamba-2.8b-slimpj" display_decoded_seq = False @@ -46,7 +53,7 @@ def test_mamba_e2e_perf( profiler.end("pytorch_ref_model_setup") profiler.start("tt_model_setup") - tt_model = get_tt_metal_model(model_version, device, cache_dir=None, batch_size=batch) + tt_model = get_tt_metal_model(model_version, device, cache_dir=get_tt_cache_path(model_version), batch_size=batch) profiler.end("tt_model_setup") sequences: torch.Tensor = tokenizer(prompts, return_tensors="pt", padding=True).input_ids diff --git a/models/demos/mamba/tests/test_mamba_ssm.py b/models/demos/mamba/tests/test_mamba_ssm.py index 43d5b66ac3ef..bc489d5b7be9 100644 --- a/models/demos/mamba/tests/test_mamba_ssm.py +++ b/models/demos/mamba/tests/test_mamba_ssm.py @@ -10,7 +10,6 @@ from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName from models.demos.mamba.tt.full_model import TtTensorLoader from models.demos.mamba.tt.mamba_one_step_ssm import TtMambaSSM -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer from models.demos.mamba.tt import model_config from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( comp_allclose, @@ -30,13 +29,12 @@ def forward(self, x): @pytest.mark.parametrize( - "model_version, batch, pcc, cache_dir", + "model_version, batch, pcc", ( ( "state-spaces/mamba-2.8b", 32, 0.99, - None, ), ), ) @@ -46,7 +44,6 @@ def test_mamba_ssm_inference( model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], ): torch.manual_seed(0) @@ -63,19 +60,11 @@ def test_mamba_ssm_inference( residual_block = reference_model.layers[LAYER_NUM] assert not 
isinstance(residual_block, torch.Tensor), "Expected torch.Module" - if cache_dir: - cache_path = model_config.get_weights_cache_path(model_version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch, reference_model.args.d_model) - loader = TtTensorLoader(reference_model.state_dict(), device, tt_cache_path=cache_path) - transformer = MambaSsmBlockTransformer( - device, batch, reference_model.args.d_inner, reference_model.args.d_state * 2 - ) + loader = TtTensorLoader(reference_model.state_dict(), device) - model = TtMambaSSM(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM), transformer) + model = TtMambaSSM(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM)) tt_input = input.view(1, 1, batch, d_in) tt_input = ttnn.to_device( ttnn.from_torch(tt_input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), diff --git a/models/demos/mamba/tests/test_residual_block.py b/models/demos/mamba/tests/test_residual_block.py index 16e521c70717..005eba21ed13 100644 --- a/models/demos/mamba/tests/test_residual_block.py +++ b/models/demos/mamba/tests/test_residual_block.py @@ -7,7 +7,7 @@ from loguru import logger from typing import Optional import ttnn -from models.demos.mamba.tt.full_model import TtTensorLoader, MambaSsmBlockTransformer +from models.demos.mamba.tt.full_model import TtTensorLoader from models.demos.mamba.reference.decode_model import MambaDecode, MambaPretrainedModelName from models.demos.mamba.tt.residual_block import TtResidualBlock from models.demos.mamba.tt import model_config @@ -29,13 +29,12 @@ def forward(self, x): @pytest.mark.parametrize( - "model_version, batch, pcc, cache_dir", + "model_version, batch, pcc", ( ( "state-spaces/mamba-2.8b", 32, 0.99, - None, ), ), ) @@ -45,7 +44,6 @@ def test_mamba_residual_block_inference( model_version: MambaPretrainedModelName, batch: int, pcc: float, - cache_dir: Optional[str], ): torch.manual_seed(0) @@ -62,19 +60,11 @@ def test_mamba_residual_block_inference( residual_block = reference_model.layers[LAYER_NUM] assert not isinstance(residual_block, torch.Tensor), "Expected torch.Module" - if cache_dir: - cache_path = model_config.get_weights_cache_path(model_version, cache_dir) - else: - cache_path = None - config = model_config.create_model_config(batch, d_model) - loader = TtTensorLoader(reference_model.state_dict(), device, tt_cache_path=cache_path) - transformer = MambaSsmBlockTransformer( - device, batch, reference_model.args.d_inner, reference_model.args.d_state * 2 - ) + loader = TtTensorLoader(reference_model.state_dict(), device) - model = TtResidualBlock(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM), transformer) + model = TtResidualBlock(reference_model.args, device, config, loader.get_tensor_loader(LAYER_NUM)) tt_input = input.view(1, 1, batch, d_model) tt_input = ttnn.to_device( ttnn.from_torch(tt_input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), diff --git a/models/demos/mamba/tests/test_transforms.py b/models/demos/mamba/tests/test_transforms.py deleted file mode 100644 index 0e94ec769081..000000000000 --- a/models/demos/mamba/tests/test_transforms.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
- -# SPDX-License-Identifier: Apache-2.0 - -import torch -import pytest - -import ttnn -import tt_lib as ttl - -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer -from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( - comp_allclose, - comp_pcc, -) - -N = 32 -HIDDEN_SIZE = 2560 - - -@pytest.mark.parametrize( - "batch, pcc", - ( - ( - 32, - 0.99, - ), - ), -) -def test_mamba_ssm_block_repeat_interleave( - device: ttnn.Device, - use_program_cache, - batch: int, - pcc: float, -): - input = torch.rand(1, 1, batch, HIDDEN_SIZE * 2) - - expected = torch.repeat_interleave(input, N, dim=3) - - transformer = MambaSsmBlockTransformer(device, batch, HIDDEN_SIZE * 2, N) - input = ttnn.to_device( - ttnn.from_torch(input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), - device=device, - memory_config=ttnn.L1_MEMORY_CONFIG, - ) - actual = transformer.repeat_interleave( - input, - memory_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) - - assert list(actual.get_legacy_shape()) == [1, 1, batch, 2 * HIDDEN_SIZE * N] - - actual = ttnn.to_torch(actual) - passing_pcc, output_pcc = comp_pcc(actual, expected, 0.9999) - assert passing_pcc - - -@pytest.mark.parametrize( - "batch, pcc", - ( - ( - 32, - 0.99, - ), - ), -) -def test_mamba_ssm_block_repeat( - device: ttnn.Device, - batch: int, - pcc: float, - use_program_cache, -): - input = torch.rand(1, 1, batch, N) - - # (1, 1, B, n) -> (1, 1, B, hidden * 2 * n) - expected = input.repeat((1, 1, 1, HIDDEN_SIZE * 2)) - - transformer = MambaSsmBlockTransformer(device, batch, HIDDEN_SIZE * 2, N) - input = ttnn.to_device( - ttnn.from_torch(input, layout=ttnn.TILE_LAYOUT, dtype=ttnn.bfloat16), - device=device, - memory_config=ttnn.L1_MEMORY_CONFIG, - ) - actual = transformer.repeat( - input, - memory_config=ttl.tensor.MemoryConfig(ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1), - ) - - assert list(actual.get_legacy_shape()) == [1, 1, batch, 2 * HIDDEN_SIZE * N] - - actual = ttnn.to_torch(actual) - passing_pcc, output_pcc = comp_pcc(actual, expected, 0.9999) - assert passing_pcc diff --git a/models/demos/mamba/tt/full_model.py b/models/demos/mamba/tt/full_model.py index 0c3c3438ac97..509eb6ff6d32 100644 --- a/models/demos/mamba/tt/full_model.py +++ b/models/demos/mamba/tt/full_model.py @@ -4,6 +4,7 @@ import torch import ttnn +import tt_lib as ttl from loguru import logger @@ -11,7 +12,6 @@ from typing import Callable, Optional from models.demos.mamba.tt.residual_block import TtResidualBlock -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer class TtTensorLoader: @@ -64,7 +64,12 @@ def load_tt_tensor( class MambaTT(torch.nn.Module): def __init__( - self, reference_model, device: ttnn.Device, configs, tt_cache_path: Optional[str] = None, num_layers=None + self, + reference_model, + device: ttnn.Device, + configs, + tt_cache_path: Optional[str] = None, + num_layers=None, ): super().__init__() self.args = reference_model.args @@ -81,13 +86,9 @@ def __init__( self.embedding = reference_model.embedding loader = TtTensorLoader(reference_model.state_dict(), self.device, tt_cache_path=tt_cache_path) - transformer = MambaSsmBlockTransformer( - self.device, self.args.batch_size, self.args.d_inner, configs["latent_size"] - ) self.layers = [ - TtResidualBlock(self.args, device, configs, loader.get_tensor_loader(i), transformer) - for i in range(self.num_layers) + TtResidualBlock(self.args, device, configs, loader.get_tensor_loader(i)) for i in 
range(self.num_layers) ] load_fn = loader.get_tensor_loader() @@ -100,6 +101,11 @@ def __init__( lambda x: x.transpose(-1, -2), tt_dtype=ttnn.bfloat16, ) + self.compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( + math_fidelity=ttl.tensor.MathFidelity.HiFi2, + math_approx_mode=False, + fp32_dest_acc_en=True, + ) def forward(self, x): assert len(x.shape) == 2, f"Mamba expects inputs to be rank 2 (was {len(x.shape)})" @@ -114,7 +120,7 @@ def forward(self, x): device=self.device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG, - dtype=ttnn.bfloat16, + dtype=self.configs["dtype"]["activations"], ) for layer in self.layers: @@ -134,7 +140,8 @@ def forward(self, x): self.lm_head_weights, memory_config=ttnn.L1_MEMORY_CONFIG, use_1d_systolic_array=True, - core_grid=ttnn.CoreGrid(y=7, x=8), + compute_kernel_config=self.compute_kernel_config, + dtype=self.configs["dtype"]["activations"], ) x = ttnn.to_torch(x).to(torch.float32) # (1, 1, B, E) diff --git a/models/demos/mamba/tt/mamba_block.py b/models/demos/mamba/tt/mamba_block.py index 5dd3ab55ec30..c2fd778f8ea9 100644 --- a/models/demos/mamba/tt/mamba_block.py +++ b/models/demos/mamba/tt/mamba_block.py @@ -10,11 +10,10 @@ from models.demos.mamba.reference.args import ModelArgs from models.demos.mamba.tt.mamba_one_step_ssm import TtMambaSSM -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer class TtMambaBlock(torch.nn.Module): - def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transformer: MambaSsmBlockTransformer): + def __init__(self, args: ModelArgs, device, configs, load_fn: Callable): super().__init__() self.device = device @@ -76,15 +75,13 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transfor ) ) - self.tt_ssm = TtMambaSSM(self.args, self.device, configs, load_fn, transformer) + self.tt_ssm = TtMambaSSM(self.args, self.device, configs, load_fn) self.compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( math_fidelity=ttl.tensor.MathFidelity.HiFi3, math_approx_mode=False, fp32_dest_acc_en=True, ) - self.core_grid_row = 4 - self.core_grid_col = 8 def forward(self, x): assert len(x.shape) == 4, "Mamba block expects inputs to be rank 4" @@ -97,7 +94,7 @@ def forward(self, x): memory_config=ttnn.L1_MEMORY_CONFIG, compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, - core_grid=ttnn.CoreGrid(y=4, x=8), + dtype=self.configs["dtype"]["activations"], ) # shift the states leftward @@ -112,24 +109,38 @@ def forward(self, x): # do the convolution conv1d_wt = ttnn.to_memory_config(self.conv1d_weights[0], memory_config=self.configs["sharded_d"]) conv_state = ttnn.to_memory_config(self.conv_states[0], memory_config=self.configs["sharded_d"]) - conv_accumulator = ttnn.mul(conv_state, conv1d_wt, memory_config=self.configs["sharded_d"]) + conv_accumulator = ttnn.mul( + conv_state, conv1d_wt, memory_config=self.configs["sharded_d"], dtype=self.configs["dtype"]["activations"] + ) ttnn.deallocate(conv1d_wt) ttnn.deallocate(conv_state) for i in range(1, 4): conv1d_wt = ttnn.to_memory_config(self.conv1d_weights[i], memory_config=self.configs["sharded_d"]) conv_state = ttnn.to_memory_config(self.conv_states[i], memory_config=self.configs["sharded_d"]) - prod = ttnn.mul(conv_state, conv1d_wt, memory_config=self.configs["sharded_d"]) + prod = ttnn.mul( + conv_state, + conv1d_wt, + memory_config=self.configs["sharded_d"], + dtype=self.configs["dtype"]["activations"], + ) ttnn.deallocate(conv1d_wt) ttnn.deallocate(conv_state) - conv_out = 
ttnn.add(conv_accumulator, prod, memory_config=self.configs["sharded_d"]) + conv_out = ttnn.add( + conv_accumulator, + prod, + memory_config=self.configs["sharded_d"], + dtype=self.configs["dtype"]["activations"], + ) ttnn.deallocate(conv_accumulator) ttnn.deallocate(prod) conv_accumulator = conv_out conv1d_bias = ttnn.to_memory_config(self.conv1d_bias, memory_config=self.configs["sharded_d"]) - conv_out_with_bias = ttnn.add(conv_out, conv1d_bias, memory_config=self.configs["sharded_d"]) + conv_out_with_bias = ttnn.add( + conv_out, conv1d_bias, memory_config=self.configs["sharded_d"], dtype=self.configs["dtype"]["activations"] + ) ttnn.deallocate(conv_out) ttnn.deallocate(conv1d_bias) @@ -143,16 +154,21 @@ def forward(self, x): residual_connection, self.mlp_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG, - core_grid=ttnn.CoreGrid(y=4, x=8), compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, + dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(residual_connection) residual_with_silu = ttnn.silu(residual, memory_config=ttnn.L1_MEMORY_CONFIG) ttnn.deallocate(residual) - out = ttnn.mul(ssm_output, residual_with_silu, memory_config=ttnn.L1_MEMORY_CONFIG) + out = ttnn.mul( + ssm_output, + residual_with_silu, + memory_config=ttnn.L1_MEMORY_CONFIG, + dtype=self.configs["dtype"]["activations"], + ) ttnn.deallocate(residual_with_silu) ttnn.deallocate(ssm_output) @@ -160,9 +176,9 @@ def forward(self, x): out, self.out_proj_weights, memory_config=ttnn.L1_MEMORY_CONFIG, - core_grid=ttnn.CoreGrid(y=4, x=8), compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, + dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(out) diff --git a/models/demos/mamba/tt/mamba_one_step_ssm.py b/models/demos/mamba/tt/mamba_one_step_ssm.py index 5cf769e75aed..833af10c2696 100644 --- a/models/demos/mamba/tt/mamba_one_step_ssm.py +++ b/models/demos/mamba/tt/mamba_one_step_ssm.py @@ -9,15 +9,12 @@ from typing import Callable from models.demos.mamba.reference.args import ModelArgs -from models.demos.mamba.tt.transforms import MambaSsmBlockTransformer class TtMambaSSM(torch.nn.Module): - def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transformer: MambaSsmBlockTransformer): + def __init__(self, args: ModelArgs, device, configs, load_fn: Callable): super().__init__() - self.transformer = transformer - self.device = device self.args = args @@ -116,6 +113,7 @@ def forward(self, x): compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col), + dtype=self.configs["dtype"]["activations"], ) delta_t1 = ttnn.linear( @@ -126,10 +124,16 @@ def forward(self, x): compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col), + dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(delta_t0) - delta_t2 = ttnn.softplus(delta_t1, beta=1.0, threshold=20.0, memory_config=ttnn.L1_MEMORY_CONFIG) + delta_t2 = ttnn.softplus( + delta_t1, + beta=1.0, + threshold=20.0, + memory_config=ttnn.L1_MEMORY_CONFIG, + ) ttnn.deallocate(delta_t1) # calculate abar @@ -140,6 +144,7 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(abar0) @@ -154,7 +159,9 @@ def forward(self, x): # multiply abar and hidden_state hidden_state0 = 
ttnn.to_memory_config(self.tt_hidden_state, memory_config=ttnn.L1_MEMORY_CONFIG) - amulh0 = ttnn.mul(abar2, hidden_state0, memory_config=ttnn.L1_MEMORY_CONFIG) + amulh0 = ttnn.mul( + abar2, hidden_state0, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.configs["dtype"]["activations"] + ) ttnn.deallocate(abar2) ttnn.deallocate(hidden_state0) @@ -166,6 +173,7 @@ def forward(self, x): compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col), + dtype=self.configs["dtype"]["activations"], ) # bbar @@ -175,6 +183,7 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(delta_t2) ttnn.deallocate(B0) @@ -186,13 +195,16 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) # deallocate bbar ttnn.deallocate(bbar0) # add amulh and bmulx - hidden_state1 = ttnn.add(amulh0, bmulx0, memory_config=ttnn.L1_MEMORY_CONFIG) + hidden_state1 = ttnn.add( + amulh0, bmulx0, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.configs["dtype"]["activations"] + ) ttnn.deallocate(self.tt_hidden_state) self.tt_hidden_state = ttnn.to_memory_config(hidden_state1, memory_config=ttnn.DRAM_MEMORY_CONFIG) ttnn.deallocate(amulh0) @@ -206,6 +218,7 @@ def forward(self, x): compute_kernel_config=self.compute_kernel_config, use_1d_systolic_array=True, core_grid=ttnn.CoreGrid(y=self.core_grid_row, x=self.core_grid_col), + dtype=self.configs["dtype"]["activations"], ) # b,n # c * hidden_state @@ -215,6 +228,7 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(hidden_state1) ttnn.deallocate(C0) @@ -225,16 +239,17 @@ def forward(self, x): output_mem_config=ttl.tensor.MemoryConfig( ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.L1 ), + output_dtype=self.configs["dtype"]["activations"], ) ttnn.deallocate(C1) # x * D D = ttnn.to_memory_config(self.D, memory_config=ttnn.L1_MEMORY_CONFIG) - xD = ttnn.mul(x, D, memory_config=ttnn.L1_MEMORY_CONFIG) + xD = ttnn.mul(x, D, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.configs["dtype"]["activations"]) ttnn.deallocate(x) # add xD and x - output = ttnn.add(xD, C2, memory_config=ttnn.L1_MEMORY_CONFIG) + output = ttnn.add(xD, C2, memory_config=ttnn.L1_MEMORY_CONFIG, dtype=self.configs["dtype"]["activations"]) ttnn.deallocate(C2) ttnn.deallocate(xD) diff --git a/models/demos/mamba/tt/model_config.py b/models/demos/mamba/tt/model_config.py index ac6e30a9c50a..3823034d3131 100644 --- a/models/demos/mamba/tt/model_config.py +++ b/models/demos/mamba/tt/model_config.py @@ -34,6 +34,7 @@ def create_model_config(batch_size, hidden_size): block_w=(hidden_size // (col * row)) // 32, inplace=False, ) + configs["dtype"] = {"activations": ttnn.bfloat8_b} return configs diff --git a/models/demos/mamba/tt/residual_block.py b/models/demos/mamba/tt/residual_block.py index a1cf33f2d70a..ff80dc199ef8 100644 --- a/models/demos/mamba/tt/residual_block.py +++ b/models/demos/mamba/tt/residual_block.py @@ -10,11 +10,10 @@ from models.demos.mamba.reference.args import ModelArgs from models.demos.mamba.tt.mamba_block import TtMambaBlock -from models.demos.mamba.tt.transforms 
import MambaSsmBlockTransformer class TtResidualBlock(torch.nn.Module): - def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transformer: MambaSsmBlockTransformer): + def __init__(self, args: ModelArgs, device, configs, load_fn: Callable): super().__init__() self.device = device @@ -24,7 +23,7 @@ def __init__(self, args: ModelArgs, device, configs, load_fn: Callable, transfor rms_norm_weight_name = "norm.weight" self.rms_norm_weights = load_fn(rms_norm_weight_name) - self.tt_mamba_block = TtMambaBlock(self.args, self.device, configs, load_fn, transformer) + self.tt_mamba_block = TtMambaBlock(self.args, self.device, configs, load_fn) def forward(self, x): assert len(x.shape) == 4, "Mamba residual block expects inputs to be rank 4" @@ -43,4 +42,4 @@ def forward(self, x): ttnn.deallocate(rms_norm_weights) mamba_x = self.tt_mamba_block(mamba_x) - return ttnn.add(residual, mamba_x) + return ttnn.add(residual, mamba_x, dtype=self.configs["dtype"]["activations"]) diff --git a/models/demos/mamba/tt/transforms.py b/models/demos/mamba/tt/transforms.py deleted file mode 100644 index 8978da096c24..000000000000 --- a/models/demos/mamba/tt/transforms.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. - -# SPDX-License-Identifier: Apache-2.0 - -import ttnn -import tt_lib as ttl -import torch - - -class MambaSsmBlockTransformer: - def __init__(self, device, batch_size, hidden_size, latent_size): - self.device = device - self.batch_size = batch_size - self.hidden_size = hidden_size - self.latent_size = latent_size - repeat_interleave_mask = torch.ones(1, 1, batch_size, latent_size) - self.repeat_interleave_mask = ttnn.from_torch( - repeat_interleave_mask, - layout=ttnn.TILE_LAYOUT, - device=device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=ttnn.bfloat16, - ) - - repeat_mask = torch.ones(1, 1, batch_size, hidden_size) - self.repeat_mask = ttnn.from_torch( - repeat_mask, - layout=ttnn.TILE_LAYOUT, - device=device, - memory_config=ttnn.DRAM_MEMORY_CONFIG, - dtype=ttnn.bfloat16, - ) - - def repeat_interleave(self, x, memory_config): - """ - This function implements an SSM-specific repeat_interleave operation needed to transform - the SSM block input (X) from (B, 2E) to (B, 2EN) so that it can be multiplied with delta*B. - - """ - assert x.shape == ( - 1, - 1, - self.batch_size, - self.hidden_size, - ), f"Expected repeat_interleave input to be (1, 1, B, 2E) (was {x.shape})" - return ttl.operations.primary.transformers.ssm_eltwise_mul( - self.repeat_interleave_mask, x, output_mem_config=memory_config - ) - - def repeat(self, x, memory_config): - """ - This function implements an SSM-specific repeat operation needed to transform the C - value from (B, N) to (B, 2EN) where N is the latent size (32) and E is the - up project size (2560). 
- """ - assert x.shape == ( - 1, - 1, - self.batch_size, - self.latent_size, - ), f"Expected repeat input to be (1, 1, B, N) (was {x.shape})" - return ttl.operations.primary.transformers.ssm_eltwise_mul(x, self.repeat_mask, output_mem_config=memory_config) diff --git a/models/demos/resnet/tests/test_metal_resnet50.py b/models/demos/resnet/tests/test_metal_resnet50.py index b24297caab86..ad332a641c2b 100644 --- a/models/demos/resnet/tests/test_metal_resnet50.py +++ b/models/demos/resnet/tests/test_metal_resnet50.py @@ -8,7 +8,7 @@ import pytest import tt_lib -from models.utility_functions import is_e75, skip_for_wormhole_b0 +from models.utility_functions import is_e75, skip_for_wormhole_b0, divup from models.demos.resnet.tt.metalResnetBlock50 import ResNet, Bottleneck from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import ( @@ -117,26 +117,107 @@ } -@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") -@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) -@pytest.mark.parametrize("batch_size", [1, 2, 16, 20], ids=["batch_1", "batch_2", "batch_16", "batch_20"]) -@pytest.mark.parametrize( - "weights_dtype", - [tt_lib.tensor.DataType.BFLOAT16, tt_lib.tensor.DataType.BFLOAT8_B], - ids=["weights_BFLOAT16", "weights_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "activations_dtype", - [tt_lib.tensor.DataType.BFLOAT16, tt_lib.tensor.DataType.BFLOAT8_B], - ids=["activations_BFLOAT16", "activations_BFLOAT8_B"], -) -@pytest.mark.parametrize( - "math_fidelity", - [tt_lib.tensor.MathFidelity.HiFi4, tt_lib.tensor.MathFidelity.HiFi2, tt_lib.tensor.MathFidelity.LoFi], - ids=["HiFi4", "HiFi2", "LoFi"], -) -def test_run_resnet50_inference( - device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +def run_model(device, tt_image, tt_resnet50): + tt_output = tt_resnet50(tt_image) + return tt_output.cpu(blocking=True) + + +def run_2cq_model(device, tt_image, tt_resnet50): + input_shape = tt_image.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_image.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_image.shape, tt_image.dtype, tt_image.layout, device, sharded_mem_config_DRAM + ) + op_event = tt_lib.device.CreateEvent() + write_event = tt_lib.device.CreateEvent() + # Initialize the op event so we can write + tt_lib.device.RecordEvent(device, 0, op_event) + + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_image, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) + + # Test overlapping write + outputs = [] + for iter in range(0, 2): + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_image, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + outputs.append(tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=False)) + tt_lib.device.Synchronize(device) + return outputs[1] + + +def run_trace_model(device, tt_image, tt_resnet50): + input_shape = 
tt_image.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_image.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_image.shape, tt_image.dtype, tt_image.layout, device, sharded_mem_config_DRAM + ) + tt_lib.tensor.write_tensor(tt_image, tt_image_res) + + # Compile + tt_resnet50(tt_image_res) + # Trace + tid = tt_lib.device.BeginTraceCapture(device, 0, 1500000) + tt_output_res = tt_resnet50(tt_image_res) + tt_lib.device.EndTraceCapture(device, 0, tid) + + tt_lib.tensor.write_tensor(tt_image, tt_image_res) + tt_lib.device.ReplayTrace(device, 0, tid, True) + + # Done with the trace, can deallocate the buffers now. + tt_lib.device.ReleaseTrace(device, tid) + + return tt_output_res.cpu(blocking=True) + + +def run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_fn, ): if is_e75(device): pytest.skip("Resnet50 is not supported on E75") @@ -159,8 +240,6 @@ def test_run_resnet50_inference( with torch.no_grad(): torch.manual_seed(1234) - tt_lib.device.EnableMemoryReports() - torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) torch_resnet50.eval() @@ -185,17 +264,8 @@ def test_run_resnet50_inference( torch_output = torch_resnet50(image).unsqueeze(1).unsqueeze(1) tt_image = tt_resnet50.preprocessing(image) - tt_output = tt_resnet50(tt_image) - tt_output = tt_output.cpu().to_torch().to(torch.float) - - # # run again to measure end to end perf - # start_time = datetime.now() - # tt_output = tt_resnet50(image) - # end_time = datetime.now() - # diff = end_time - start_time - # logger.info("End to end time (microseconds))", diff.microseconds) - # throughput_fps = (float) (1000000 / diff.microseconds) - # logger.info("Throughput (fps)", throughput_fps) + tt_output = run_fn(device, tt_image, tt_resnet50) + tt_output = tt_output.to_torch().to(torch.float) _, _, _, info = get_atol_rtol_pcc(torch_output, tt_output) logger.info(info) @@ -239,6 +309,72 @@ def test_run_resnet50_inference( [tt_lib.tensor.MathFidelity.HiFi4, tt_lib.tensor.MathFidelity.HiFi2, tt_lib.tensor.MathFidelity.LoFi], ids=["HiFi4", "HiFi2", "LoFi"], ) +def test_run_resnet50_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def 
test_run_resnet50_2cqs_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_2cq_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) @pytest.mark.parametrize("enable_async", [True, False]) def test_run_resnet50_trace_inference( device, @@ -250,101 +386,17 @@ def test_run_resnet50_trace_inference( imagenet_sample_input, enable_async, ): - if is_e75(device): - pytest.skip("Resnet50 is not supported on E75") device.enable_async(enable_async) - if batch_size > 8 and ( - activations_dtype != tt_lib.tensor.DataType.BFLOAT8_B or weights_dtype != tt_lib.tensor.DataType.BFLOAT8_B - ): - pytest.skip("Batch > 8 must be run fully bfp8") - if batch_size <= 2: - pytest.skip("batch 1 and 2 are not supported with sharded data") - image1 = imagenet_sample_input - image = image1 - model_config = { - "MATH_FIDELITY": math_fidelity, - "WEIGHTS_DTYPE": weights_dtype, - "ACTIVATIONS_DTYPE": activations_dtype, - } - for i in range(batch_size - 1): - image = torch.cat((image, image1), dim=0) - with torch.no_grad(): - torch.manual_seed(1234) - - tt_lib.device.EnableMemoryReports() - - torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) - torch_resnet50.eval() - - state_dict = torch_resnet50.state_dict() - storage_in_dram = False - sharded = False - if batch_size >= 8: - sharded = True - # run once to compile ops - tt_resnet50 = ResNet( - Bottleneck, - [3, 4, 6, 3], - device=device, - state_dict=state_dict, - base_address="", - fold_batchnorm=True, - storage_in_dram=storage_in_dram, - batch_size=batch_size, - model_config=model_config, - sharded=sharded, - ) - - torch_output = torch_resnet50(image).unsqueeze(1).unsqueeze(1) - interleaved_mem_config_DRAM = tt_lib.tensor.MemoryConfig( - memory_layout=tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=tt_lib.tensor.BufferType.DRAM, - ) - - tt_image_res = tt_resnet50.preprocessing(image).to(device, interleaved_mem_config_DRAM) - # Compile - tt_resnet50(tt_image_res) - # Trace - tid = tt_lib.device.BeginTraceCapture(device, 0, 1334880) - tt_output_res = tt_resnet50(tt_image_res) - tt_lib.device.EndTraceCapture(device, 0, tid) + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_trace_model, + ) - tt_lib.device.ReplayTrace(device, 0, tid, True) - - tt_output = tt_output_res.cpu().to_torch().to(torch.float) - - # # run again to measure end to end perf - # start_time = datetime.now() - # tt_output = tt_resnet50(image) - # end_time = datetime.now() - # diff = end_time - start_time - # logger.info("End to end time (microseconds))", diff.microseconds) - # throughput_fps = (float) (1000000 / diff.microseconds) - # logger.info("Throughput (fps)", throughput_fps) 
- - _, _, _, info = get_atol_rtol_pcc(torch_output, tt_output) - logger.info(info) - - valid_pcc = 1.0 - if batch_size >= 8: - valid_pcc = golden_pcc[batch_size][ - (model_config["MATH_FIDELITY"], model_config["WEIGHTS_DTYPE"], model_config["ACTIVATIONS_DTYPE"]) - ] - else: - if model_config["ACTIVATIONS_DTYPE"] == tt_lib.tensor.DataType.BFLOAT8_B: - if model_config["MATH_FIDELITY"] == tt_lib.tensor.MathFidelity.LoFi: - valid_pcc = 0.87 - else: - valid_pcc = 0.94 - else: - if model_config["MATH_FIDELITY"] == tt_lib.tensor.MathFidelity.LoFi: - valid_pcc = 0.93 - else: - valid_pcc = 0.982 - passing_pcc, _ = comp_pcc(torch_output, tt_output, pcc=valid_pcc) - assert passing_pcc - # assert passing # fails because of torch.allclose - # Done with the trace, can deallocate the buffers now. - tt_lib.device.ReleaseTrace(device, tid) device.enable_async(False) diff --git a/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py b/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py new file mode 100644 index 000000000000..6bb3147c6d32 --- /dev/null +++ b/models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py @@ -0,0 +1,42 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tt_lib + +from models.demos.resnet.tests.test_metal_resnet50 import run_resnet50_inference, run_2cq_model +from models.utility_functions import skip_for_wormhole_b0 + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def test_run_resnet50_2cqs_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_2cq_model, + ) diff --git a/models/demos/resnet/tests/test_metal_resnet50_performant.py b/models/demos/resnet/tests/test_metal_resnet50_performant.py new file mode 100644 index 000000000000..cbd266c568c9 --- /dev/null +++ b/models/demos/resnet/tests/test_metal_resnet50_performant.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import tt_lib + +from models.demos.resnet.tests.test_metal_resnet50 import run_resnet50_inference, run_model, run_trace_model +from models.utility_functions import skip_for_wormhole_b0 + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +def test_run_resnet50_inference( + device, use_program_cache, batch_size, weights_dtype, activations_dtype, math_fidelity, imagenet_sample_input +): + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_model, + ) + + +@skip_for_wormhole_b0("This test is not supported on WHB0, please use the TTNN version.") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.parametrize("batch_size", [20], ids=["batch_20"]) +@pytest.mark.parametrize( + "weights_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["weights_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "activations_dtype", + [tt_lib.tensor.DataType.BFLOAT8_B], + ids=["activations_BFLOAT8_B"], +) +@pytest.mark.parametrize( + "math_fidelity", + [tt_lib.tensor.MathFidelity.LoFi], + ids=["LoFi"], +) +@pytest.mark.parametrize("enable_async", [True, False]) +def test_run_resnet50_trace_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + enable_async, +): + device.enable_async(enable_async) + + run_resnet50_inference( + device, + use_program_cache, + batch_size, + weights_dtype, + activations_dtype, + math_fidelity, + imagenet_sample_input, + run_trace_model, + ) + + device.enable_async(False) diff --git a/models/demos/resnet/tests/test_perf_accuracy_resnet.py b/models/demos/resnet/tests/test_perf_accuracy_resnet.py index 722000caea57..6c719ebbf5b9 100644 --- a/models/demos/resnet/tests/test_perf_accuracy_resnet.py +++ b/models/demos/resnet/tests/test_perf_accuracy_resnet.py @@ -84,6 +84,7 @@ def run_perf_resnet( tt_output = tt_output.cpu().to_torch().to(torch.float) profiler.end(first_key) del tt_output + return enable_persistent_kernel_cache() diff --git a/models/demos/resnet/tests/test_perf_resnet.py b/models/demos/resnet/tests/test_perf_resnet.py index f7bc7368ed2b..a93c82876c97 100644 --- a/models/demos/resnet/tests/test_perf_resnet.py +++ b/models/demos/resnet/tests/test_perf_resnet.py @@ -9,9 +9,7 @@ import pytest import tt_lib -from models.utility_functions import is_e75 -from models.utility_functions import profiler -from models.utility_functions import disable_persistent_kernel_cache, skip_for_wormhole_b0 +from models.utility_functions import is_e75, profiler, divup, disable_persistent_kernel_cache, skip_for_wormhole_b0 from models.perf.perf_utils import prep_perf_report from loguru import logger @@ -24,13 +22,145 @@ } +def run_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations): + profiler.start("compile") + _ = 
tt_resnet50(tt_inputs).cpu(blocking=True) + profiler.end("compile") + tt_lib.device.DumpDeviceProfiler(device) + + for iter in range(0, num_warmup_iterations): + _ = tt_resnet50(tt_inputs).cpu(blocking=True) + tt_lib.device.DumpDeviceProfiler(device) + + outputs = [] + profiler.start(f"run") + for iter in range(0, num_measurement_iterations): + outputs.append(tt_resnet50(tt_inputs).cpu(blocking=False)) + tt_lib.device.Synchronize(device) + profiler.end(f"run") + tt_lib.device.DumpDeviceProfiler(device) + + +def run_2cq_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations): + input_shape = tt_inputs.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_inputs.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_inputs.shape, tt_inputs.dtype, tt_inputs.layout, device, sharded_mem_config_DRAM + ) + op_event = tt_lib.device.CreateEvent() + write_event = tt_lib.device.CreateEvent() + # Initialize the op event so we can write + tt_lib.device.RecordEvent(device, 0, op_event) + + profiler.start("compile") + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) + profiler.end("compile") + tt_lib.device.DumpDeviceProfiler(device) + + for iter in range(0, num_warmup_iterations): + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + _ = tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=True) + tt_lib.device.DumpDeviceProfiler(device) + + outputs = [] + profiler.start(f"run") + for iter in range(0, num_measurement_iterations): + tt_lib.device.WaitForEvent(device, 1, op_event) + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1) + tt_lib.device.RecordEvent(device, 1, write_event) + outputs.append(tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=False)) + tt_lib.device.Synchronize(device) + profiler.end(f"run") + tt_lib.device.DumpDeviceProfiler(device) + + +def run_trace_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations): + input_shape = tt_inputs.get_legacy_shape() + shard_spec = tt_lib.tensor.ShardSpec( + tt_lib.tensor.CoreRangeSet( + { + tt_lib.tensor.CoreRange( + tt_lib.tensor.CoreCoord(0, 0), + tt_lib.tensor.CoreCoord(7, 0), + ) + } + ), + [ + divup(tt_inputs.volume() // input_shape[3], 8), + input_shape[3], + ], + tt_lib.tensor.ShardOrientation.ROW_MAJOR, + False, + ) + sharded_mem_config_DRAM = tt_lib.tensor.MemoryConfig( + tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.DRAM, shard_spec + ) + tt_image_res = tt_lib.tensor.allocate_tensor_on_device( + tt_inputs.shape, tt_inputs.dtype, tt_inputs.layout, device, sharded_mem_config_DRAM + ) + # Compile + profiler.start("compile") + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) + tt_resnet50(tt_image_res).cpu(blocking=True) + profiler.end("compile") + tt_lib.device.DumpDeviceProfiler(device) + + # Capture + tid = 
tt_lib.device.BeginTraceCapture(device, 0, 1500000) + tt_output_res = tt_resnet50(tt_image_res) + tt_lib.device.EndTraceCapture(device, 0, tid) + tt_lib.device.DumpDeviceProfiler(device) + + for iter in range(0, num_warmup_iterations): + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) + tt_lib.device.ReplayTrace(device, 0, tid, False) + _ = tt_output_res.cpu(blocking=True) + tt_lib.device.DumpDeviceProfiler(device) + + outputs = [] + profiler.start(f"run") + for iter in range(0, num_measurement_iterations): + tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) + tt_lib.device.ReplayTrace(device, 0, tid, False) + outputs.append(tt_output_res.cpu(blocking=False)) + tt_lib.device.Synchronize(device) + profiler.end(f"run") + tt_lib.device.DumpDeviceProfiler(device) + + def run_perf_resnet( batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, + model_version, ): + if is_e75(device): + pytest.skip("Resnet is not supported on E75") disable_persistent_kernel_cache() if batch_size <= 2: pytest.skip("Batch size 1 and 2 are not supported with sharded data") @@ -69,6 +199,10 @@ def run_perf_resnet( model_config=model_config, sharded=sharded, ) + tt_lib.device.Synchronize(device) + + num_warmup_iterations = 5 + num_measurement_iterations = 15 with torch.no_grad(): profiler.start(cpu_key) @@ -76,36 +210,24 @@ def run_perf_resnet( profiler.end(cpu_key) tt_inputs = tt_resnet50.preprocessing(inputs) - warmup_end = 5 - for iter in range(0, warmup_end): - profiler.start(f"{iter}_key") - _ = tt_resnet50(tt_inputs).cpu(blocking=True) - profiler.end(f"{iter}_key") - tt_lib.device.DumpDeviceProfiler(device) - - num_warm_iterations = 15 - warm_start = warmup_end - warm_end = warm_start + num_warm_iterations - - outputs = [] - profiler.start(f"run") - for iter in range(warm_start, warm_end): - outputs.append(tt_resnet50(tt_inputs).cpu(blocking=False)) - tt_lib.device.Synchronize(device) - profiler.end(f"run") - tt_lib.device.DumpDeviceProfiler(device) + if "resnet50_2cqs" in model_version: + run_2cq_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations) + elif "resnet50_trace" in model_version: + run_trace_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations) + elif "resnet50" in model_version: + run_model(device, tt_inputs, tt_resnet50, num_warmup_iterations, num_measurement_iterations) + else: + assert False, f"Model version to run {model_version} not found" - # enable_persistent_kernel_cache() - - first_iter_time = profiler.get(f"{0}_key") + first_iter_time = profiler.get(f"compile") # ensuring inference time fluctuations is not noise - inference_time_avg = profiler.get("run") / num_warm_iterations + inference_time_avg = profiler.get("run") / num_measurement_iterations cpu_time = profiler.get(cpu_key) compile_time = first_iter_time - inference_time_avg prep_perf_report( - model_name=f"resnet50_batch_size{batch_size}", + model_name=f"{model_version}_batch_size{batch_size}", batch_size=batch_size, inference_and_compile_time=first_iter_time, inference_time=inference_time_avg, @@ -115,8 +237,8 @@ def run_perf_resnet( inference_time_cpu=cpu_time, ) - logger.info(f"resnet50 {comments} inference time (avg): {inference_time_avg}") - logger.info(f"resnet50 compile time: {compile_time}") + logger.info(f"{model_name} {comments} inference time (avg): {inference_time_avg}") + logger.info(f"{model_name} compile time: {compile_time}") @skip_for_wormhole_b0(reason_str="Not tested on single WH") @@ -125,10 +247,8 
@@ def run_perf_resnet( @pytest.mark.parametrize( "batch_size, expected_inference_time, expected_compile_time", ( - (1, 0.001, 1), - (2, 0.001, 1), - (16, 0.007, 7), - (20, 0.007, 7), + (16, 0.007, 16), + (20, 0.007, 16), ), ) def test_perf_bare_metal( @@ -143,145 +263,16 @@ def test_perf_bare_metal( pytest.skip("Resnet is not supported on E75") run_perf_resnet( - batch_size, - expected_inference_time, - expected_compile_time, - hf_cat_image_sample_input, - device, - ) - - -def run_perf_resnet_trace( - batch_size, - expected_inference_time, - expected_compile_time, - hf_cat_image_sample_input, - device, -): - disable_persistent_kernel_cache() - if batch_size <= 2: - pytest.skip("Batch size 1 and 2 are not supported with sharded data") - first_key = f"first_iter_batchsize{batch_size}" - second_key = f"second_iter_batchsize{batch_size}" - cpu_key = f"ref_key_batchsize{batch_size}" - model_name = "microsoft/resnet-50" - - image = hf_cat_image_sample_input - image_processor = AutoImageProcessor.from_pretrained(model_name) - inputs = image_processor(image, return_tensors="pt") - - inputs = inputs["pixel_values"] - comments = f"{list(inputs.shape)[-2]}x{list(inputs.shape)[-1]}_batchsize{batch_size}" - - inputs1 = inputs - for i in range(batch_size - 1): - inputs = torch.cat((inputs, inputs1), dim=0) - - torch_resnet50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1) - torch_resnet50.eval() - - state_dict = torch_resnet50.state_dict() - sharded = False - if batch_size >= 8: - sharded = True - tt_resnet50 = ResNet( - Bottleneck, - [3, 4, 6, 3], - device=device, - state_dict=state_dict, - base_address="", - fold_batchnorm=True, - storage_in_dram=False, - batch_size=batch_size, - model_config=model_config, - sharded=sharded, + batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50" ) - with torch.no_grad(): - profiler.start(cpu_key) - logits = torch_resnet50(inputs) - profiler.end(cpu_key) - - tt_inputs = tt_resnet50.preprocessing(inputs) - interleaved_mem_config_DRAM = tt_lib.tensor.MemoryConfig( - memory_layout=tt_lib.tensor.TensorMemoryLayout.INTERLEAVED, - buffer_type=tt_lib.tensor.BufferType.DRAM, - ) - tt_image_res = tt_inputs.to(device, interleaved_mem_config_DRAM) - # Compile - profiler.start(f"{0}_key") - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) - tt_resnet50(tt_image_res).cpu(blocking=True) - profiler.end(f"{0}_key") - tt_lib.device.DumpDeviceProfiler(device) - - # Capture - tid = tt_lib.device.BeginTraceCapture(device, 0, 1334880) - tt_output_res = tt_resnet50(tt_image_res) - tt_lib.device.EndTraceCapture(device, 0, tid) - tt_lib.device.DumpDeviceProfiler(device) - - warmup_end = 6 - for iter in range(1, warmup_end): - profiler.start(f"{iter}_key") - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) - tt_lib.device.ReplayTrace(device, 0, tid, False) - _ = tt_output_res.cpu(blocking=True) - profiler.end(f"{iter}_key") - tt_lib.device.DumpDeviceProfiler(device) - - num_warm_iterations = 15 - warm_start = warmup_end - warm_end = warm_start + num_warm_iterations - - outputs = [] - profiler.start(f"run") - for iter in range(warm_start, warm_end): - tt_lib.tensor.write_tensor(tt_inputs, tt_image_res) - tt_lib.device.ReplayTrace(device, 0, tid, False) - outputs.append(tt_output_res.cpu(blocking=False)) - tt_lib.device.Synchronize(device) - profiler.end(f"run") - tt_lib.device.DumpDeviceProfiler(device) - - # enable_persistent_kernel_cache() - - first_iter_time = profiler.get(f"{0}_key") - - # ensuring inference time 
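
(Editorial aside.) The thresholds parametrized above translate directly into performance floors: the measured loop must average under expected_inference_time per batch, and compile_time (first iteration minus the steady-state average) must stay under expected_compile_time seconds. Back-of-envelope, using the values above:

batch_size, expected_inference_time = 20, 0.007
fps_floor = batch_size / expected_inference_time  # ~2857 images/s end to end
compile_budget_s = 16                             # compile_time = first_iter_time - inference_time_avg
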
fluctuations is not noise - inference_time_avg = profiler.get("run") / num_warm_iterations - - cpu_time = profiler.get(cpu_key) - compile_time = first_iter_time - inference_time_avg - prep_perf_report( - model_name=f"resnet50_trace_batch_size{batch_size}", - batch_size=batch_size, - inference_and_compile_time=first_iter_time, - inference_time=inference_time_avg, - expected_compile_time=expected_compile_time, - expected_inference_time=expected_inference_time, - comments=comments, - inference_time_cpu=cpu_time, - ) - - logger.info(f"resnet50 {comments} inference time (avg): {inference_time_avg}") - logger.info(f"resnet50 compile time: {compile_time}") - - tt_lib.device.ReleaseTrace(device, tid) - - assert inference_time_avg < expected_inference_time, f"resnet50 {comments} inference is too slow" - assert compile_time < expected_compile_time, f"resnet50 {comments} compilation is too slow" - @skip_for_wormhole_b0(reason_str="Not tested on single WH") @pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) @pytest.mark.models_performance_bare_metal @pytest.mark.parametrize( "batch_size, expected_inference_time, expected_compile_time", - ( - (16, 0.04, 25), - (20, 0.04, 25), - ), + ((20, 0.008, 16),), ) @pytest.mark.parametrize("enable_async", [True, False]) def test_perf_trace_bare_metal( @@ -293,14 +284,14 @@ def test_perf_trace_bare_metal( hf_cat_image_sample_input, enable_async, ): - if is_e75(device): - pytest.skip("Resnet is not supported on E75") device.enable_async(enable_async) - run_perf_resnet_trace( + mode = "async" if enable_async else "sync" + run_perf_resnet( batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, + f"resnet50_trace_{mode}", ) device.enable_async(False) diff --git a/models/demos/resnet/tests/test_perf_resnet_2cqs.py b/models/demos/resnet/tests/test_perf_resnet_2cqs.py new file mode 100644 index 000000000000..eddbc1bf4ed7 --- /dev/null +++ b/models/demos/resnet/tests/test_perf_resnet_2cqs.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
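
(Editorial aside.) A minimal sketch of the two-command-queue handshake that run_2cq_model above and the new 2CQ perf test below exercise; the event and tensor APIs are the ones used in the hunks above, and the loop body is illustrative:

# CQ 1 streams inputs while CQ 0 runs the model; two events form the handshake.
op_event = tt_lib.device.CreateEvent()
write_event = tt_lib.device.CreateEvent()
tt_lib.device.RecordEvent(device, 0, op_event)              # allow the first write
outputs = []
for _ in range(num_measurement_iterations):
    tt_lib.device.WaitForEvent(device, 1, op_event)         # CQ1: wait until last input is consumed
    tt_lib.tensor.write_tensor(tt_inputs, tt_image_res, 1)  # write next input on CQ 1
    tt_lib.device.RecordEvent(device, 1, write_event)       # CQ1: signal input ready
    # forward() waits on write_event before sharding the input and records
    # op_event once it is consumed, per the metalResnetBlock50.py hunk below
    outputs.append(tt_resnet50(tt_image_res, write_event, op_event).cpu(blocking=False))
tt_lib.device.Synchronize(device)
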
+ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from models.demos.resnet.tests.test_perf_resnet import run_perf_resnet +from models.utility_functions import skip_for_wormhole_b0 + + +@skip_for_wormhole_b0(reason_str="Not tested on single WH") +@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768, "num_hw_cqs": 2}], indirect=True) +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize( + "batch_size, expected_inference_time, expected_compile_time", + ((20, 0.0055, 16),), +) +def test_perf_2cqs_bare_metal( + device, + use_program_cache, + batch_size, + expected_inference_time, + expected_compile_time, + hf_cat_image_sample_input, +): + run_perf_resnet( + batch_size, expected_inference_time, expected_compile_time, hf_cat_image_sample_input, device, "resnet50_2cqs" + ) diff --git a/models/demos/resnet/tt/metalResnetBlock50.py b/models/demos/resnet/tt/metalResnetBlock50.py index 16f8fb01ffb1..32e3f913c314 100644 --- a/models/demos/resnet/tt/metalResnetBlock50.py +++ b/models/demos/resnet/tt/metalResnetBlock50.py @@ -2101,7 +2101,7 @@ def preprocessing_with_fold(self, x: torch.Tensor) -> tt_lib.tensor: return x - def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: + def forward(self, x: tt_lib.tensor, write_event=None, op_event=None) -> tt_lib.tensor: if not self.sharded: original_A_cl_host_shape = x.get_legacy_shape() x = x.reshape( @@ -2116,7 +2116,7 @@ def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: original_A_cl_host_shape[2], original_A_cl_host_shape[3], ) - elif x.storage_type() != tt_lib.tensor.StorageType.DEVICE: + else: x_shape = x.get_legacy_shape() shard_spec = tt_lib.tensor.ShardSpec( self.shard_grid, @@ -2130,21 +2130,16 @@ def forward(self, x: tt_lib.tensor) -> tt_lib.tensor: mem_config = tt_lib.tensor.MemoryConfig( tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec ) - x = x.to(self.device, mem_config) - else: - shard_spec = tt_lib.tensor.ShardSpec( - self.shard_grid, - [ - x.get_legacy_shape()[2] // self.first_conv_num_cores_nhw, - x.get_legacy_shape()[3], - ], - tt_lib.tensor.ShardOrientation.ROW_MAJOR, - False, - ) - mem_config = tt_lib.tensor.MemoryConfig( - tt_lib.tensor.TensorMemoryLayout.HEIGHT_SHARDED, tt_lib.tensor.BufferType.L1, shard_spec - ) - x = tt_lib.tensor.interleaved_to_sharded(x, mem_config) + if write_event is not None: + tt_lib.device.WaitForEvent(self.device, 0, write_event) + if x.storage_type() != tt_lib.tensor.StorageType.DEVICE: + x = x.to(self.device, mem_config) + elif x.memory_config().is_sharded(): + x = tt_lib.tensor.reshard(x, mem_config) + else: + x = tt_lib.tensor.interleaved_to_sharded(x, mem_config) + if op_event is not None: + tt_lib.device.RecordEvent(self.device, 0, op_event) x = self.conv1(x) # Relu is fused with conv1 diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py index c4428ed36366..3db26429c2e6 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_mlp.py @@ -69,7 +69,6 @@ def test_mixtral_mlp_inference(t3k_device_mesh, use_program_cache, reset_seeds): layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(t3k_device_mesh), ) - tt_input = ttnn.to_device(tt_input, t3k_device_mesh) tt_output = tt_model(tt_input) tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ConcatMeshToTensor(t3k_device_mesh, dim=0))[0] diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py 
b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py index 043666dd8ce9..174ff0c5b235 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py @@ -44,10 +44,10 @@ def forward(self, x): @pytest.mark.parametrize( "generation_start_pos, expected_compile_time, expected_inference_time", ( - (32, 150, 7.5), - (128, 150, 7.5), - (1024, 150, 7.5), - (2048, 150, 7.5), + (32, 150, 0.025), + (128, 150, 0.025), + (1024, 150, 0.025), + (2048, 150, 0.025), ), ) def test_mixtral_model_perf( diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py index b50abc7a3e98..6557af40fab2 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_rms_norm.py @@ -55,7 +55,7 @@ def test_mistral_rms_norm_inference(t3k_device_mesh, use_program_cache, reset_se layout=ttnn.TILE_LAYOUT, mesh_mapper=ReplicateTensorToMesh(t3k_device_mesh), ) - tt_input = ttnn.to_device(tt_input, t3k_device_mesh) + tt_output = tt_model(tt_input) tt_output_torch = ttnn.to_torch(tt_output, mesh_composer=ConcatMeshToTensor(t3k_device_mesh, dim=0))[0] passing, pcc_message = comp_pcc(reference_output, tt_output_torch) diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py index 332db2bbfb03..4b10f62a6ad6 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py @@ -81,7 +81,6 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype): cache_file_name=cache_name(f"wqkv_multidevice_4d"), ) - self.wqkv = ttnn.to_device(self.wqkv, self.device_mesh) self.wo = ttnn.as_tensor( torch.transpose( self.state_dict[wo_str], @@ -91,15 +90,13 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype): .unsqueeze(0) .unsqueeze(0), device=self.device_mesh, - mesh_mapper=ShardTensorToMesh(self.device_mesh, dim=-2), + mesh_mapper=ReplicateTensorToMesh(self.device_mesh), dtype=self.dtype, memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"], layout=self.model_config["ATTN_W_LAYOUT_TILE"], - cache_file_name=cache_name(f"wo_multidevice4d"), + cache_file_name=cache_name(f"wo_multidevice4d_H"), ) - self.wo = ttnn.to_device(self.wo, self.device_mesh) - cache_k = torch.zeros( ( self.n_kv_heads, @@ -130,22 +127,8 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype): for lp in layer_past ] - self.layer_past = [ttnn.to_device(lp, self.device_mesh) for lp in self.layer_past] - self.scale = self.head_dim**-0.5 - reduce_mask_torch = torch.zeros(1, 1, self.max_batch_size, self.max_batch_size * 8) - for i in range(self.max_batch_size): - reduce_mask_torch[:, :, i, range(i, self.max_batch_size * 8, self.max_batch_size)] = 1 - self.reduce_mask = ttnn.from_torch( - reduce_mask_torch, - device=self.device_mesh, - mesh_mapper=ReplicateTensorToMesh(self.device_mesh), - dtype=ttnn.bfloat8_b, - layout=ttnn.TILE_LAYOUT, - ) - - self.reduce_mask = ttnn.to_device(self.reduce_mask, self.device_mesh) self.compute_kernel = self.model_args.get_compute_kernel_config() self.compute_kernel_attn = self.model_args.get_compute_kernel_attn_config() @@ -306,16 +289,19 @@ def forward( ) attn_output_1B4D.deallocate(True) - # attn_output_11BH = ttnn.experimental.tensor.sharded_to_interleaved( - # attn_output_11BH, output_mem_config=ttnn.L1_MEMORY_CONFIG - # ) + attn_output_11BH = 
ttnn.experimental.tensor.sharded_to_interleaved( + attn_output_11BH, output_mem_config=ttnn.L1_MEMORY_CONFIG + ) ### # Output matmul ### + # All gather + dense_outputs_11BH_gathered = ttnn.all_gather(attn_output_11BH, dim=3, num_links=1) - dense_out_11BH = ttnn.experimental.operations.primary.matmul( - attn_output_11BH, + # return the sum of the outputs + dense_outputs_11BH = ttnn.experimental.operations.primary.matmul( + dense_outputs_11BH_gathered, wo, output_mem_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"], # compute_with_storage_grid_size=(8, 8), @@ -323,10 +309,6 @@ def forward( compute_kernel_config=self.compute_kernel, output_dtype=ttnn.bfloat8_b, ) - attn_output_11BH.deallocate(True) - # All gather - dense_outputs_11BH = ttnn.all_gather(dense_out_11BH, dim=2, num_links=1) - # return the sum of the outputs - dense_outputs_11BH = ttnn.experimental.operations.primary.matmul(self.reduce_mask, dense_outputs_11BH) + dense_outputs_11BH_gathered.deallocate(True) return dense_outputs_11BH diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py index 83e35f0a0aaa..d3cb5e9f677c 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_common.py @@ -81,7 +81,6 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, current_pos, sliding_window, device_ memory_config=ttnn.L1_MEMORY_CONFIG, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - xs_1SBH = ttnn.to_device(xs_1SBH, device_mesh) # Attention mask padded_layer_past_len = min(nearest_32(current_pos + 1), sliding_window) @@ -108,7 +107,7 @@ def prepare_inputs_ttnn(x_bsh, hidden_size, current_pos, sliding_window, device_ memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - attn_mask = ttnn.to_device(attn_mask, device_mesh) + ATTN_MASK_MEMCFG = ttnn.create_sharded_memory_config( shape=(32, padded_layer_past_len), core_grid=ttnn.CoreGrid(y=4, x=8), @@ -137,7 +136,6 @@ def prepare_rotation_mat_ttnn(head_dim, max_seq_len, device_mesh): ) for rot_mat_i in rot_mat ] - rot_mats = [ttnn.to_device(rot_mat, device_mesh) for rot_mat in rot_mats] return rot_mats @@ -178,7 +176,6 @@ def cache_attention(device_mesh, state_dict, model_args, rot_emb_matrix_list, se memory_config=ttnn.L1_MEMORY_CONFIG, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - attention_inputs = ttnn.to_device(attention_inputs, device_mesh) tt_attn = TtMixtralAttention( device_mesh, @@ -201,7 +198,7 @@ def cache_attention(device_mesh, state_dict, model_args, rot_emb_matrix_list, se memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - attn_mask = ttnn.to_device(attn_mask, device_mesh) + ATTN_MASK_MEMCFG = ttnn.create_sharded_memory_config( shape=(32, padded_layer_past_len), core_grid=ttnn.CoreGrid(y=4, x=8), diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py index 665ef5d9fd30..f3c2002d4d83 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_mlp.py @@ -43,11 +43,8 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtypes): ) self.w1 = as_tensor("w1") - self.w1 = ttnn.to_device(self.w1, device_mesh) self.w2 = as_tensor("w2") - self.w2 = ttnn.to_device(self.w2, device_mesh) self.w3 = as_tensor("w3") - self.w3 = ttnn.to_device(self.w3, device_mesh) def forward(self, x: ttnn.Tensor) -> ttnn.Tensor: """ diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py 
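
(Editorial aside on the attention rework above.) The old path projected each device's head shard against a row-sharded wo, all-gathered the partial products on dim 2, and summed them with a precomputed reduce-mask matmul; the new path all-gathers the attention output on the hidden dimension first and projects once against a replicated wo, so the reduce mask (and its construction in __init__) disappears. A sketch of the new data flow, with names from the hunk:

# Gather hidden-dim shards from all devices, then one full matmul per device.
gathered = ttnn.all_gather(attn_output_11BH, dim=3, num_links=1)
dense_out = ttnn.experimental.operations.primary.matmul(
    gathered,
    wo,                                   # replicated via ReplicateTensorToMesh
    output_mem_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"],
    compute_kernel_config=self.compute_kernel,
    output_dtype=ttnn.bfloat8_b,
)
gathered.deallocate(True)
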
b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py index 598f9663bc08..6664ad227e2e 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_moe.py @@ -48,7 +48,6 @@ def __init__(self, device_mesh, state_dict, experts, args, layer_num, dtype): device=self.device_mesh, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - self.reduce_mask = ttnn.to_device(self.reduce_mask, device_mesh) self.expert_mask_11BB = ttnn.from_torch( torch.cat([torch.full((1, 1, 32, 32), fill_value=i + 1) for i in range(8)], dim=3), dtype=ttnn.uint16, diff --git a/models/demos/t3000/mixtral8x7b/tt/mixtral_rms_norm.py b/models/demos/t3000/mixtral8x7b/tt/mixtral_rms_norm.py index 4c29ee50ae0d..4957d4d6d1ef 100644 --- a/models/demos/t3000/mixtral8x7b/tt/mixtral_rms_norm.py +++ b/models/demos/t3000/mixtral8x7b/tt/mixtral_rms_norm.py @@ -88,7 +88,6 @@ def __init__( cache_file_name=cache_name, mesh_mapper=ReplicateTensorToMesh(device_mesh), ) - self.weight = ttnn.to_device(self.weight, device_mesh) def forward(self, x: ttnn.Tensor, out_sharded=False) -> ttnn.Tensor: x = ttnn.experimental.tensor.interleaved_to_sharded( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py index 1eb0382ce263..a3ebb92457d8 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_attention.py @@ -154,6 +154,7 @@ def test_falcon_attention( configuration.max_position_embeddings, model_config, parameters=parameters, + core_grid=device_mesh.get_devices()[0].core_grid, ) tt_out, tt_layer_present = tt_FalconAttention_model( diff --git a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py index 6301284023c5..192babe1f3e9 100644 --- a/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py +++ b/models/demos/ttnn_falcon7b/tests/multi_chip/test_falcon_mlp.py @@ -52,6 +52,7 @@ def torch_model(): @pytest.mark.parametrize( "device_mesh", [ + 1, 2, ], indirect=True, diff --git a/models/demos/ttnn_falcon7b/tt/falcon_attention.py b/models/demos/ttnn_falcon7b/tt/falcon_attention.py index 63fb859b7599..51921c0c45c8 100644 --- a/models/demos/ttnn_falcon7b/tt/falcon_attention.py +++ b/models/demos/ttnn_falcon7b/tt/falcon_attention.py @@ -24,6 +24,7 @@ def __init__( max_position_embeddings: int = 2048, model_config=None, parameters=None, + core_grid=None, ): super().__init__() self.hidden_size = hidden_size @@ -49,11 +50,7 @@ def __init__( ) self.scalar = 1 / math.sqrt(self.head_dim) - - if is_wormhole_b0(): - self.core_grid = ttnn.CoreGrid(y=7, x=8) - else: - self.core_grid = ttnn.CoreGrid(y=9, x=12) + self.core_grid = core_grid def __call__( self, @@ -165,7 +162,9 @@ def __call__( attn_weights = ttnn.experimental.operations.primary.transformers.attn_matmul( query_layer, key_layer_transposed, - compute_with_storage_grid_size=ttnn.experimental.tensor.CoreCoord(8, 7), + compute_with_storage_grid_size=ttnn.experimental.tensor.CoreCoord( + self.core_grid.x, self.core_grid.y + ), output_mem_config=self.model_config["PRE_SOFTMAX_MM_OUTPUT_MEMCFG"], output_dtype=self.model_config["PRE_SOFTMAX_MM_OUTPUT_DTYPE"], # Must be BFLOAT16 ) @@ -228,7 +227,9 @@ def __call__( attn_output = ttnn.experimental.operations.primary.transformers.attn_matmul( attn_weights, value_layer, - compute_with_storage_grid_size=ttnn.experimental.tensor.CoreCoord(8, 7), + 
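# (Editorial aside.) The compute grid is now derived from the attached device
# instead of the removed is_wormhole_b0() branch. Note the old code hardcoded
# CoreCoord(8, 7) at both attn_matmul call sites, which matched the WH grid
# (x=8, y=7) but not the 12x9 grid configured for GS in __init__; threading
# core_grid through, e.g.
#   core_grid = device.get_devices()[0].core_grid
#   ttnn.experimental.tensor.CoreCoord(core_grid.x, core_grid.y)
# keeps the matmul grid consistent with whichever device it runs on.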
compute_with_storage_grid_size=ttnn.experimental.tensor.CoreCoord( + self.core_grid.x, self.core_grid.y + ), output_mem_config=self.model_config["POST_SOFTMAX_MM_OUTPUT_MEMCFG"], output_dtype=self.model_config["POST_SOFTMAX_MM_OUTPUT_DTYPE"], # Must be BFLOAT16 ) diff --git a/models/demos/ttnn_falcon7b/tt/falcon_decoder.py b/models/demos/ttnn_falcon7b/tt/falcon_decoder.py index fed5b893129e..045011db439f 100644 --- a/models/demos/ttnn_falcon7b/tt/falcon_decoder.py +++ b/models/demos/ttnn_falcon7b/tt/falcon_decoder.py @@ -31,6 +31,7 @@ def __init__( max_position_embeddings=config.max_position_embeddings, model_config=model_config, parameters=parameters.self_attention, + core_grid=device.get_devices()[0].core_grid, ) self.mlp = TtFalconMLP(model_config, parameters=parameters.mlp) diff --git a/scripts/docker/build_docker_image.sh b/scripts/docker/build_docker_image.sh index 82df50664e43..39c01283fbf5 100755 --- a/scripts/docker/build_docker_image.sh +++ b/scripts/docker/build_docker_image.sh @@ -5,5 +5,5 @@ TT_METAL_DOCKER_IMAGE_TAG=${1:-ubuntu-20.04-amd64:latest} TT_METAL_HOME=$(git rev-parse --show-toplevel) ( cd ${TT_METAL_HOME} || exit - docker build -f dockerfile/ubuntu-20.04-x86.Dockerfile -t ${TT_METAL_DOCKER_IMAGE_TAG} . + docker build -f dockerfile/ubuntu-20.04-amd64.Dockerfile -t ${TT_METAL_DOCKER_IMAGE_TAG} . ) \ No newline at end of file diff --git a/tests/scripts/run_cpp_fd2_tests.sh b/tests/scripts/run_cpp_fd2_tests.sh index 9d9d8b61445b..84134ef5e6ca 100755 --- a/tests/scripts/run_cpp_fd2_tests.sh +++ b/tests/scripts/run_cpp_fd2_tests.sh @@ -59,11 +59,15 @@ run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 1 -i 5 -x -spre" # Smoke Test run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 1 -i 5 -x -spre -sdis" # Smoke Test +run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 2 -i 5 -x -spre -sdis" # Random Test +run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 6 -i 5 -x -spre -sdis" # Host Test if [[ $ARCH_NAME == "wormhole_b0" ]]; then # packetized path used only on multi-chip WH run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 0 -i 5 -spre -sdis -packetized_en" # TrueSmoke Test with packetized path run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 1 -i 5 -spre -sdis -packetized_en" # Smoke Test with packetized path + run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 2 -i 5 -spre -sdis -packetized_en" # Random Test with packetized path + run_test "./build/test/tt_metal/perf_microbenchmark/dispatch/test_prefetcher -t 6 -i 5 -spre -sdis -packetized_en" # Host Test with packetized path fi diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index 23cc2d0d0ba3..e535e635d451 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -17,7 +17,9 @@ run_perf_models_other() { env pytest models/demos/ttnn_falcon7b/tests -m $test_marker - env pytest models/demos/resnet/tests -m $test_marker + # Separate calls since we can't mix switching between number of cqs + env pytest models/demos/resnet/tests/test_perf_resnet.py -m $test_marker + env pytest models/demos/resnet/tests/test_perf_resnet_2cqs.py -m $test_marker env pytest tests/ttnn/integration_tests/whisper/test_performance.py -m $test_marker diff --git 
a/tests/scripts/single_card/nightly/run_gs_only.sh b/tests/scripts/single_card/nightly/run_gs_only.sh index 9973f35b7bda..36ed969d4a04 100755 --- a/tests/scripts/single_card/nightly/run_gs_only.sh +++ b/tests/scripts/single_card/nightly/run_gs_only.sh @@ -11,6 +11,6 @@ echo "Running model nightly tests for GS only" env pytest models/demos/metal_BERT_large_11/tests/test_demo.py -env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_inference[LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0] +env pytest models/demos/resnet/tests/test_metal_resnet50_performant.py -env pytest models/demos/resnet/tests/test_metal_resnet50.py::test_run_resnet50_trace_inference -k "LoFi-activations_BFLOAT8_B-weights_BFLOAT8_B-batch_20-device_params0" +env pytest models/demos/resnet/tests/test_metal_resnet50_2cqs_performant.py diff --git a/tests/scripts/single_card/nightly/run_wh_b0_only.sh b/tests/scripts/single_card/nightly/run_wh_b0_only.sh index 163ed499c4a3..5af44887070a 100755 --- a/tests/scripts/single_card/nightly/run_wh_b0_only.sh +++ b/tests/scripts/single_card/nightly/run_wh_b0_only.sh @@ -14,13 +14,12 @@ SLOW_MATMULS=1 WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml env pytest tes env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b/tests/ci/test_falcon_end_to_end_prefill.py +env pytest models/demos/mamba/tests/test_benchmarks.py +env pytest models/demos/mamba/tests/test_reference_model.py env pytest models/demos/mamba/tests/test_mamba_ssm.py env pytest models/demos/mamba/tests/test_mamba_block.py env pytest models/demos/mamba/tests/test_residual_block.py -env pytest models/demos/mamba/tests/test_full_model_loop.py -env pytest models/demos/mamba/tests/test_benchmarks.py -env pytest models/demos/mamba/tests/test_reference_model.py -env pytest models/demos/mamba/tests/test_transforms.py +env pytest models/demos/mamba/tests/test_full_model.py env pytest models/demos/mamba/tests/test_mamba_demo.py env pytest models/demos/wormhole/mistral7b/tests/test_mistral_embedding.py diff --git a/tests/scripts/t3000/run_t3000_model_perf_tests.sh b/tests/scripts/t3000/run_t3000_model_perf_tests.sh index abff688f6487..c8fc186f9bca 100755 --- a/tests/scripts/t3000/run_t3000_model_perf_tests.sh +++ b/tests/scripts/t3000/run_t3000_model_perf_tests.sh @@ -22,7 +22,7 @@ run_t3000_mixtral_tests() { echo "LOG_METAL: Running run_t3000_mixtral_tests" - env pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py::test_mixtral_model_perf[wormhole_b0-True-2048-150-7.5] -m "model_perf_t3000" + env pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py::test_mixtral_model_perf[wormhole_b0-True-2048-150-0.025] -m "model_perf_t3000" # Record the end time end_time=$(date +%s) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py index 4d70e6b70d6f..d6bf0b1ab5f4 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/op_map.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/op_map.py @@ -803,6 +803,14 @@ "tt_op": tt_lib_ops.where, "pytorch_op": pytorch_ops.where, }, + "eltwise-where-optional": { + "tt_op": tt_lib_ops.where_optional, + "pytorch_op": pytorch_ops.where, + }, + "eltwise-where-scalar-optional": { + "tt_op": tt_lib_ops.where_scalar_optional, + "pytorch_op": pytorch_ops.where_scalar, + }, "where-bw": { "tt_op": tt_lib_ops.where_bw, "pytorch_op": pytorch_ops.where_bw, diff --git 
a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py index 4ddb18dde5ce..fee9e99be3c3 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_ternary.py @@ -4,6 +4,7 @@ import pytest import torch +import random from functools import partial from math import pi @@ -36,3 +37,48 @@ def test_run_eltwise_where_test(input_shapes, device, function_level_defaults): comparison_func, device, ) + + +@pytest.mark.parametrize("input_shapes", shapes) +def test_run_eltwise_where_test_optional(input_shapes, device, function_level_defaults): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_randint, low=-100, high=+100), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-5, high=+5), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-10, high=+10), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-1, high=+1), torch.float32), + ] + comparison_func = partial(comparison_funcs.comp_pcc) + run_single_pytorch_test( + "eltwise-where-optional", + [input_shapes[0], input_shapes[0], input_shapes[0], input_shapes[0]], + datagen_func, + comparison_func, + device, + ) + + +shapes_scalar = ( + [[1, 1, 32, 32], [1, 1, 32, 32]], # Single core + [[1, 1, 320, 384], [1, 1, 320, 384]], # Multi core + [[1, 3, 320, 384], [1, 3, 320, 384]], # Multi core +) + + +@pytest.mark.parametrize("input_shapes", shapes_scalar) +def test_run_eltwise_where_scalar_optional(input_shapes, device, function_level_defaults): + datagen_func = [ + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_randint, low=-100, high=+100), torch.float32), + generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=-1, high=+1), torch.float32), + ] + test_args = list(generation_funcs.gen_default_dtype_layout_device(input_shapes))[0] + test_args.update({"scalar_true": random.uniform(0.5, 75.5), "scalar_false": random.uniform(0.5, 95.5)}) + + comparison_func = partial(comparison_funcs.comp_pcc) + run_single_pytorch_test( + "eltwise-where-scalar-optional", + input_shapes, + datagen_func, + comparison_func, + device, + test_args, + ) diff --git a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py index 8a588493e48a..1b0f4c27a1a9 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/pytorch_ops.py @@ -96,6 +96,12 @@ def where(x, y, z, *args, **kwargs): return torch.where(x > 0, y, z) +def where_scalar(x, *args, **kwargs): + y = kwargs.pop("scalar_true") + z = kwargs.pop("scalar_false") + return torch.where(x > 0, y, z) + + def where_bw(x, y, z, w, *args, **kwargs): grad_data = x in_data = y @@ -1331,8 +1337,13 @@ def eltwise_identity(x, *args, **kwargs): return x -def eltwise_typecast(x, *args, **kwargs): - return torch.relu(x.to(torch.int32)) # due to no uint32 support +def eltwise_typecast(x, *args, tt_output_dtype, **kwargs): + if tt_output_dtype[0] == ttl.tensor.DataType.UINT16: + return torch.clamp(x.to(torch.int32), min=0, max=65535) # due to no uint16 support + elif tt_output_dtype[0] == ttl.tensor.DataType.UINT32: + return torch.relu(x.to(torch.int32)) # due to no uint32 support + 
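# (Editorial aside.) Per the comments above, the reference emulates unsigned
# casts with signed int32 because torch lacks these unsigned dtypes: UINT16
# saturates into [0, 65535] via clamp, while UINT32 only clips negatives via
# relu. For example:
#   torch.clamp(torch.tensor([-3, 70000]).to(torch.int32), 0, 65535) -> [0, 65535]
#   torch.relu(torch.tensor([-3, 70000]).to(torch.int32))            -> [0, 70000]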
else: + return x def eltwise_rdiv(x, *args, **kwargs): diff --git a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py index d7c116b794b9..b9dac18fd1b5 100644 --- a/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py +++ b/tests/tt_eager/python_api_testing/sweep_tests/tt_lib_ops.py @@ -1518,6 +1518,28 @@ def where(x, y, z, device, dtype, layout, input_mem_config, output_mem_config, * return tt2torch_tensor(t3) +@setup_host_and_device +def where_optional(x, y, z, out, device, dtype, layout, input_mem_config, output_mem_config, **kwargs): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t1 = setup_tt_tensor(y, device, layout[1], input_mem_config[1], dtype[1]) + t2 = setup_tt_tensor(z, device, layout[2], input_mem_config[2], dtype[2]) + t3 = setup_tt_tensor(out, device, layout[3], input_mem_config[3], dtype[3]) + ttl.tensor.where(t0, t1, t2, output_mem_config=output_mem_config, output_tensor=t3) + + return tt2torch_tensor(t3) + + +@setup_host_and_device +def where_scalar_optional( + x, out, device, dtype, layout, input_mem_config, output_mem_config, scalar_true, scalar_false, **kwargs +): + t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) + t3 = setup_tt_tensor(out, device, layout[1], input_mem_config[1], dtype[1]) + ttl.tensor.where(t0, scalar_true, scalar_false, output_mem_config=output_mem_config, output_tensor=t3) + + return tt2torch_tensor(t3) + + @setup_host_and_device def eltwise_div_unary( x, @@ -2192,13 +2214,14 @@ def eltwise_typecast( *args, device, dtype, + tt_output_dtype, layout, input_mem_config, output_mem_config, **kwargs, ): t0 = setup_tt_tensor(x, device, layout[0], input_mem_config[0], dtype[0]) - t1 = ttl.tensor.eltwise_typecast(t0, output_mem_config=output_mem_config) + t1 = ttl.tensor.eltwise_typecast(t0, tt_output_dtype[0], output_mem_config=output_mem_config) return tt2torch_tensor(t1) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py index 769adb144b10..5d6a12971ef2 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_all_gather.py @@ -224,7 +224,6 @@ def test_all_gather_on_t3000_post_commit_looping( [ (4, 2, [4, 1, 33, 256], 0, ttl.tensor.Layout.ROW_MAJOR), (8, 1, [8, 1, 33, 256], 0, ttl.tensor.Layout.ROW_MAJOR), - # (8, 1, [8, 1, 256, 32], 0, ttl.tensor.Layout.TILE), (8, 1, [8, 8, 256, 384], 1, ttl.tensor.Layout.ROW_MAJOR), (4, 2, [8, 8, 256, 384], 1, ttl.tensor.Layout.ROW_MAJOR), (4, 2, [8, 8, 256, 384], 1, ttl.tensor.Layout.TILE), @@ -259,6 +258,8 @@ def test_all_gather_on_t3000_post_commit_looping( (8, 1, [1, 1, 1024, 256], 3, ttl.tensor.Layout.TILE), (8, 1, [1, 1, 256, 2048], 2, ttl.tensor.Layout.TILE), (8, 1, [1, 1, 256, 8192], 2, ttl.tensor.Layout.TILE), # double on reduction dim for 8 chip + (8, 1, [8, 1, 256, 32], 0, ttl.tensor.Layout.TILE), + (8, 1, [8, 8, 128, 4096], 1, ttl.tensor.Layout.TILE), ], ) @pytest.mark.parametrize( @@ -424,6 +425,11 @@ def test_line_all_gather_on_t3000_post_commit( ([8, 8, 256, 384], 3, ttl.tensor.Layout.TILE), ([8, 8, 256, 768], 3, ttl.tensor.Layout.ROW_MAJOR), ([8, 8, 256, 768], 3, ttl.tensor.Layout.TILE), + ([8, 8, 1024, 4096], 1, ttl.tensor.Layout.TILE), + ([8, 8, 2048, 4096], 1, ttl.tensor.Layout.TILE), + ([8, 8, 128, 4096], 1, ttl.tensor.Layout.ROW_MAJOR), + ([8, 8, 1024, 4096], 1, 
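# (Editorial aside.) The optional-output "where" variants added above thread a
# preallocated tensor through the op instead of allocating a new one, e.g.:
#   ttl.tensor.where(cond_t, true_t, false_t,
#                    output_mem_config=mem_cfg, output_tensor=out_t)
#   ttl.tensor.where(cond_t, scalar_true, scalar_false,
#                    output_mem_config=mem_cfg, output_tensor=out_t)
# so the sweeps can validate in-place writes against the same torch.where
# reference as the allocating form.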
ttl.tensor.Layout.ROW_MAJOR), + ([8, 8, 2048, 4096], 1, ttl.tensor.Layout.ROW_MAJOR), # Only for BFP8B # ([1, 1, 640, 32768], 3, ttl.tensor.Layout.TILE), # MLP AllGather. Llama 2 decode attn, mlp. Llama2, Falcon 40B decode mlp attn diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py index 0f5e1bb50e3f..ed50144d26a5 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_matmul_dram_sharded.py @@ -279,3 +279,187 @@ def test_matmul_in1_dram_sharded_with_program_cache( ttl.tensor.Tensor(py_dummy_tensor, in0_dtype).to(ttl.tensor.Layout.TILE).to(device, mem_config) ) assert device.num_program_cache_entries() == 3 + + +def run_test_matmul_in1_dram_sharded_mm_chain( + device, + in0_sharded, + out_sharded, + in1_in_dram, + M, + K, + N, + fidelity, + has_bias, + activation, + grid_size, + in0_dtype, + in1_dtype, + out_dtype, + function_level_defaults, + use_program_cache, +): + if is_grayskull() and (N == 4096 or K == 32768): + pytest.skip("Skipping too large tensor test on Grayskull") + + if is_grayskull(): + N_padded = N + num_banks = 8 + else: + N_padded = pad_to_dram_banks(N) + num_banks = 12 + + in0_shape = [1, 1, M, K] + in1_shape = [1, 1, K, N] + in1_shard_shape = [K, N_padded // num_banks] + num_cores = grid_size[0] * grid_size[1] + + in0_block_h = M // 32 + in0_block_w = K // num_cores // 32 + out_block_h = M // 32 + out_block_w = N // num_cores // 32 + + out_subblock_h, out_subblock_w, _ = find_max_subblock(out_block_h, out_block_w) + + logger.debug("N_padded " + str(N_padded)) + logger.debug("in0 block h w " + str(in0_block_h * 32) + " " + str(in0_block_w * 32)) + logger.debug("in1 block h w " + str(in0_block_w * 32) + " " + str(out_block_w * 32)) + logger.debug("out block h w " + str(out_block_h * 32) + " " + str(out_block_w * 32)) + logger.debug("out subblock h w " + str(out_subblock_h * 32) + " " + str(out_subblock_w * 32)) + + sharded_mem_config = ttl.tensor.MemoryConfig( + memory_layout=ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, + buffer_type=ttl.tensor.BufferType.L1, + ) + + in0 = torch.randn(in0_shape).bfloat16().float() + in1 = torch.randn(in1_shape).bfloat16().float() + + in0_shard_grid = (grid_size[0] - 1, grid_size[1] - 1) + in0_shard_shape = [M, int(in0_block_w * 32)] + in0_shard_grid = ttl.tensor.CoreRangeSet({ttl.tensor.CoreRange(ttl.tensor.CoreCoord(0, 0), in0_shard_grid)}) + in0_shard_spec = ttl.tensor.ShardSpec(in0_shard_grid, in0_shard_shape, ttl.tensor.ShardOrientation.ROW_MAJOR, False) + in0_mem_config = ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, ttl.tensor.BufferType.L1, in0_shard_spec + ) + in0_t = torch2tt_tensor(in0, device, tt_memory_config=in0_mem_config, tt_dtype=in0_dtype) + + in1_shard_grid = ttl.tensor.CoreCoord(device.dram_grid_size().x - 1, device.dram_grid_size().y - 1) + in1_shard_grid = ttl.tensor.CoreRangeSet({ttl.tensor.CoreRange(ttl.tensor.CoreCoord(0, 0), in1_shard_grid)}) + in1_shard_spec = ttl.tensor.ShardSpec(in1_shard_grid, in1_shard_shape, ttl.tensor.ShardOrientation.ROW_MAJOR, False) + in1_mem_config = ttl.tensor.MemoryConfig( + ttl.tensor.TensorMemoryLayout.WIDTH_SHARDED, ttl.tensor.BufferType.DRAM, in1_shard_spec + ) + in1_t = torch2tt_tensor(in1, device, tt_memory_config=in1_mem_config, tt_dtype=in1_dtype) + + program_config = 
ttl.operations.primary.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig( + in0_block_w=in0_block_w // 4, + out_subblock_h=out_subblock_h, + out_subblock_w=out_subblock_w, + per_core_M=out_block_h, + per_core_N=out_block_w, + fuse_batch=True, + fused_activation=None, + ) + + if is_grayskull(): + compute_kernel_config = ttl.tensor.GrayskullComputeKernelConfig( + math_fidelity=fidelity, + math_approx_mode=True, + ) + else: + compute_kernel_config = ttl.tensor.WormholeComputeKernelConfig( + math_fidelity=fidelity, + math_approx_mode=True, + fp32_dest_acc_en=True, + packer_l1_acc=True, + ) + + # 1st mm + output_t = ttl.operations.primary.matmul( + in0_t, + in1_t, + program_config=program_config, + output_mem_config=sharded_mem_config, + output_dtype=out_dtype, + compute_kernel_config=compute_kernel_config, + ) + + for _ in range(200): + output_t = ttl.operations.primary.matmul( + in0_t, + in1_t, + program_config=program_config, + output_mem_config=sharded_mem_config, + output_dtype=out_dtype, + compute_kernel_config=compute_kernel_config, + ) + + output_t = output_t.cpu().to(ttl.tensor.Layout.ROW_MAJOR) + + pt_out = in0 @ in1 + + tt_out = tt2torch_tensor(output_t) + + print(tt_out) + print(pt_out) + + passing, output = comp_pcc(pt_out, tt_out) + logger.info(output) + assert True + + +@pytest.mark.parametrize( + "fidelity", + [ + ttl.tensor.MathFidelity.HiFi2, + ], + ids=[ + "HiFi2", + ], +) +@pytest.mark.parametrize( + "has_bias", + [ + False, + ], + ids=["no_bias"], +) +@pytest.mark.parametrize( + "in0_dtype, in1_dtype, out_dtype", + [ + (ttl.tensor.DataType.BFLOAT16, ttl.tensor.DataType.BFLOAT8_B, ttl.tensor.DataType.BFLOAT16), + ], +) +def test_matmul_in1_dram_sharded_with_mm_chain( + device, + fidelity, + has_bias, + in0_dtype, + in1_dtype, + out_dtype, + function_level_defaults, + use_program_cache, +): + M = 32 + K = 4096 + N = 4096 + grid_size = (8, 2) + run_test_matmul_in1_dram_sharded_mm_chain( + device, + True, + True, + True, + M, + K, + N, + fidelity, + has_bias, + None, + grid_size, + in0_dtype, + in1_dtype, + out_dtype, + function_level_defaults, + use_program_cache, + ) diff --git a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py index 08f033f23288..f7f615b66a7a 100644 --- a/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py +++ b/tests/tt_eager/python_api_testing/unit_testing/misc/test_moreh_adamw.py @@ -17,22 +17,7 @@ from loguru import logger -@pytest.mark.parametrize( - "shape", - ( - (1, 1, 32, 32), # single - (12, 6, 64, 64), # multi tile - ), -) -@pytest.mark.parametrize("lr", [0.0, 1e-2]) -@pytest.mark.parametrize("betas", ((0.9, 0.999), (0.5, 0.555))) -@pytest.mark.parametrize("eps", [1e-06, 1e-08]) -@pytest.mark.parametrize("weight_decay", [0.0, 0.3]) -@pytest.mark.parametrize("amsgrad", [True, False]) -@pytest.mark.parametrize("step", [1, 2, 8]) -def test_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device): - torch.manual_seed(0) - +def run_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device): N = shape[0] C = shape[1] H = shape[2] @@ -205,3 +190,38 @@ def forward(self, x): whole_passing &= passing assert whole_passing + + +@pytest.mark.parametrize( + "shape", + [ + [1, 1, 32, 32], # single + [12, 6, 64, 64], # multi tile + ], +) +@pytest.mark.parametrize("lr", [0.0, 1e-2]) +@pytest.mark.parametrize("betas", ((0.9, 0.999), (0.5, 0.555))) +@pytest.mark.parametrize("eps", [1e-06, 1e-08]) 
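# (Editorial aside.) The callback variant defined below reruns the same
# optimizer step twice under use_program_cache, so the second pass should be
# served from the program cache rather than recompiled. A stronger
# (hypothetical) assertion in the same style as the matmul cache test above:
#   run_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device)
#   n = device.num_program_cache_entries()
#   run_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device)
#   assert device.num_program_cache_entries() == n   # no recompilation on rerun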
+@pytest.mark.parametrize("weight_decay", [0.0, 0.3]) +@pytest.mark.parametrize("amsgrad", [True, False]) +@pytest.mark.parametrize("step", [1, 2, 8]) +def test_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device): + torch.manual_seed(0) + + run_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device) + + +@pytest.mark.parametrize( + "shape", + [[1, 1, 32, 32]], # single +) +@pytest.mark.parametrize("lr", [1e-2]) +@pytest.mark.parametrize("betas", [[0.9, 0.999], [0.5, 0.555]]) +@pytest.mark.parametrize("eps", [1e-08]) +@pytest.mark.parametrize("weight_decay", [0.3]) +@pytest.mark.parametrize("amsgrad", [True, False]) +@pytest.mark.parametrize("step", [8]) +def test_moreh_adamw_callback(shape, lr, betas, eps, weight_decay, amsgrad, step, device, use_program_cache): + torch.manual_seed(0) + for _ in range(2): + run_moreh_adamw(shape, lr, betas, eps, weight_decay, amsgrad, step, device) diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp index 9ae0f0adffba..9f94b540aafa 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/3_pcie_transfer/kernels/pull_from_pcie.cpp @@ -17,7 +17,7 @@ void kernel_main() { volatile tt_l1_ptr uint32_t* done_address = reinterpret_cast(L1_UNRESERVED_BASE); while (done_address[0] == 0) { - uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y), pcie_read_ptr); + uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y, NOC_INDEX), pcie_read_ptr); noc_async_read(host_src_addr, L1_UNRESERVED_BASE, read_sizeB); pcie_read_ptr += read_sizeB; if (pcie_read_ptr > pcie_base + pcie_sizeB) { diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h index 249f6bc0974c..d6a0344f9ddf 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/common.h @@ -372,7 +372,8 @@ inline bool DeviceData::validate_one_core(Device *device, bool DeviceData::validate_host(std::unordered_set &validated_cores, const one_core_data_t& host_data) { - log_info(tt::LogTest, "Validating data from hugepage"); + uint32_t size_bytes = host_data.data.size() * sizeof(uint32_t); + log_info(tt::LogTest, "Validating {} bytes from hugepage", size_bytes); bool failed = false; @@ -383,7 +384,7 @@ bool DeviceData::validate_host(std::unordered_set &validated_cores, bool done = false; for (int data_index = 0; data_index < host_data.data.size(); data_index++) { validated_cores.insert(this->host_core); - if (host_data.data[data_index] != results[host_data_index] && fail_count < 5000) { + if (host_data.data[data_index] != results[host_data_index] && fail_count < 20) { if (!failed) { log_fatal(tt::LogTest, "Data mismatch - First 20 host data failures: [idx] expected->read"); } @@ -754,6 +755,8 @@ inline void gen_bare_dispatcher_unicast_write_cmd(Device *device, cmd.write_linear.length = length; cmd.write_linear.num_mcast_dests = 0; + TT_FATAL((cmd.write_linear.addr & (16 - 1)) == 0); // XXXXX L1_ALIGNMENT16 + add_bare_dispatcher_cmd(cmds, cmd); } diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp index 76571678b520..bc5e958996ff 100644 
--- a/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/dispatch/test_prefetcher.cpp @@ -470,6 +470,26 @@ void gen_dram_write_cmd(Device *device, add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); } +void gen_wait_and_stall_cmd(Device *device, + vector& prefetch_cmds, + vector& cmd_sizes) { + + vector dispatch_cmds; + + CQDispatchCmd wait; + wait.base.cmd_id = CQ_DISPATCH_CMD_WAIT; + wait.wait.barrier = true; + wait.wait.notify_prefetch = true; + wait.wait.wait = true; + wait.wait.addr = dispatch_wait_addr_g; + wait.wait.count = 0; + add_bare_dispatcher_cmd(dispatch_cmds, wait); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + + vector empty_payload; // don't give me grief, it is just a test + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_STALL, empty_payload); +} + // This is pretty much a blit: copies from worker core's start of data back to the end of data void gen_linear_read_cmd(Device *device, vector& prefetch_cmds, @@ -482,6 +502,9 @@ void gen_linear_read_cmd(Device *device, vector dispatch_cmds; const uint32_t bank_id = 0; // No interleaved pages here. + // Stall because we are reading data that was previously written + gen_wait_and_stall_cmd(device, prefetch_cmds, cmd_sizes); + gen_bare_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, length); add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE_NOFLUSH, dispatch_cmds); @@ -498,25 +521,7 @@ void gen_linear_read_cmd(Device *device, for (uint32_t i = 0; i < length_words; i++) { device_data.push_one(worker_core, device_data.at(worker_core, bank_id, offset + i)); } -} - -void gen_wait_and_stall_cmd(Device *device, - vector& prefetch_cmds, - vector& cmd_sizes) { - - vector dispatch_cmds; - - CQDispatchCmd wait; - wait.base.cmd_id = CQ_DISPATCH_CMD_WAIT; - wait.wait.barrier = true; - wait.wait.notify_prefetch = true; - wait.wait.addr = dispatch_wait_addr_g; - wait.wait.count = 0; - add_bare_dispatcher_cmd(dispatch_cmds, wait); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - - vector empty_payload; // don't give me grief, it is just a test - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_STALL, empty_payload); + device_data.pad(worker_core, bank_id, 16); // XXXX L1_ALIGNMENT } void gen_dispatcher_delay_cmd(Device *device, @@ -633,37 +638,64 @@ void gen_host_test(Device *device, vector& cmd_sizes, DeviceData& device_data) { - constexpr uint32_t data_size = 614400; + constexpr uint32_t max_data_size = DEVICE_DATA_SIZE; // Read data from a worker so we can get reasonable BW measurements // TODO: extend the DRAM mechanism for pre-fill to workers vectordata; - for (uint32_t i = 0; i < data_size / sizeof(uint32_t); i++) { + for (uint32_t i = 0; i < max_data_size / sizeof(uint32_t); i++) { data.push_back(i); } CoreCoord phys_worker_core = device->worker_core_from_logical_core(first_worker_g); llrt::write_hex_vec_to_core(device->id(), phys_worker_core, data, l1_buf_base_g); tt::Cluster::instance().l1_barrier(device->id()); - for (int count = 0; count < 50; count++) { + for (int count = 1; count < 100; count++) { + uint32_t data_size_words = std::rand() % ((max_data_size / 100 / sizeof(uint32_t)) * count) + 1; + uint32_t data_size_bytes = data_size_words * sizeof(uint32_t); + std::vector dispatch_cmds; - gen_bare_dispatcher_host_write_cmd(dispatch_cmds, data_size); + 
gen_bare_dispatcher_host_write_cmd(dispatch_cmds, data_size_bytes); add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE_NOFLUSH, dispatch_cmds); auto prior_end = prefetch_cmds.size(); - add_prefetcher_linear_read_cmd(device, prefetch_cmds, cmd_sizes, first_worker_g, l1_buf_base_g, data_size); + add_prefetcher_linear_read_cmd(device, prefetch_cmds, cmd_sizes, first_worker_g, l1_buf_base_g, data_size_bytes); uint32_t new_size = (prefetch_cmds.size() - prior_end) * sizeof(uint32_t); cmd_sizes.push_back(new_size >> dispatch_constants::PREFETCH_Q_LOG_MINSIZE); + // write host writes the command back to the host for (auto datum : dispatch_cmds) { device_data.push_one(device_data.get_host_core(), 0, datum); } - for (auto datum : data) { + for (int i = 0; i < data_size_words; i++) { + uint32_t datum = data[i]; device_data.push_one(device_data.get_host_core(), 0, datum); } pad_host_data(device_data); } } +void gen_rnd_linear_cmd(Device *device, + vector& prefetch_cmds, + vector& cmd_sizes, + DeviceData& device_data, + CoreCoord worker_core) { + + vector dispatch_cmds; + + // Hmm, how big a size to test? + int max_linear_cmd_read_size = 20 * dispatch_buffer_page_size_g; // XXXXX 10 * + uint32_t size = std::rand() % max_linear_cmd_read_size; + size &= ~(sizeof(uint32_t) - 1); + uint32_t offset = std::rand() % dispatch_buffer_page_size_g; + offset = (offset >> 2) << 2; + device_data.relevel(CoreType::WORKER); // XXXXX shouldn't be needed + if (device_data.size_at(worker_core, 0) * sizeof(uint32_t) < max_linear_cmd_read_size + offset) { + // Not enough data yet, just bail on this cmd + return; + } + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, size, offset); +} + void gen_rnd_dram_paged_cmd(Device *device, vector& prefetch_cmds, vector& cmd_sizes, @@ -762,6 +794,11 @@ void gen_rnd_test(Device *device, CoreCoord worker_core(first_worker_g.x + x, first_worker_g.y + y); switch (cmd) { + case CQ_PREFETCH_CMD_RELAY_LINEAR: + // TODO: disabled for now + // test issue w/ handling re-leveling of results data after paged commands + //gen_rnd_linear_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core); + break; case CQ_PREFETCH_CMD_RELAY_PAGED: gen_rnd_dram_paged_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core); break; @@ -896,6 +933,23 @@ void gen_smoke_test(Device *device, gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, 8448); add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + // Check some hard page alignment sizes + dispatch_cmds.resize(0); + gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, dispatch_buffer_page_size_g); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + + dispatch_cmds.resize(0); + gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + + dispatch_cmds.resize(0); + gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, 2 * dispatch_buffer_page_size_g); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + + dispatch_cmds.resize(0); + gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, 2 * dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, 
CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + // Merge 4 commands in the FetchQ dispatch_cmds.resize(0); gen_dispatcher_unicast_write_cmd(device, dispatch_cmds, worker_core, device_data, 112); @@ -991,6 +1045,10 @@ void gen_smoke_test(Device *device, // These tests copy data from earlier tests so can't run first gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, 32); gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, 65 * 1024); + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, dispatch_buffer_page_size_g); + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, 2 * dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + gen_linear_read_cmd(device, prefetch_cmds, cmd_sizes, device_data, worker_core, 2 * dispatch_buffer_page_size_g); // Test wait/stall gen_dispatcher_delay_cmd(device, prefetch_cmds, cmd_sizes, 1024 * 1024); @@ -1002,30 +1060,47 @@ void gen_smoke_test(Device *device, // Test host if (!use_dram_exec_buf_g) { - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, 32); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); - - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, 36); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); - - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, 1024); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); - - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); - - dispatch_cmds.resize(0); - gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, 16384); - add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); - pad_host_data(device_data); + for (int multiplier = 1; multiplier <= 3; multiplier++) { + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * 32); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * 36); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * 1024); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g - 2 * sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g - sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, 
cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g + sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + + dispatch_cmds.resize(0); + gen_dispatcher_host_write_cmd(dispatch_cmds, device_data, multiplier * dispatch_buffer_page_size_g + sizeof(CQDispatchCmd)); + add_prefetcher_cmd(prefetch_cmds, cmd_sizes, CQ_PREFETCH_CMD_RELAY_INLINE, dispatch_cmds); + pad_host_data(device_data); + } } // Test Paged DRAM Write and Read. FIXME - Needs work - hits asserts. diff --git a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen_tx.cpp b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen_tx.cpp index d43c6ba8ca21..a698fba95cd9 100644 --- a/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen_tx.cpp +++ b/tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen_tx.cpp @@ -2,10 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 +// clang-format off #include "dataflow_api.h" #include "debug/dprint.h" #include "tt_metal/impl/dispatch/kernels/packet_queue.hpp" #include "tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen.hpp" +// clang-format on constexpr uint32_t src_endpoint_id = get_compile_time_arg_val(0); constexpr uint32_t num_dest_endpoints = get_compile_time_arg_val(1); @@ -14,7 +16,7 @@ static_assert(is_power_of_2(num_dest_endpoints), "num_dest_endpoints must be a p constexpr uint32_t queue_start_addr_words = get_compile_time_arg_val(2); constexpr uint32_t queue_size_words = get_compile_time_arg_val(3); -constexpr uint32_t queue_size_bytes = queue_size_words*PACKET_WORD_SIZE_BYTES; +constexpr uint32_t queue_size_bytes = queue_size_words * PACKET_WORD_SIZE_BYTES; static_assert(is_power_of_2(queue_size_words), "queue_size_words must be a power of 2"); @@ -27,15 +29,13 @@ constexpr uint32_t remote_rx_x = get_compile_time_arg_val(6); constexpr uint32_t remote_rx_y = get_compile_time_arg_val(7); constexpr uint32_t remote_rx_queue_id = get_compile_time_arg_val(8); -constexpr DispatchRemoteNetworkType - tx_network_type = - static_cast(get_compile_time_arg_val(9)); +constexpr DispatchRemoteNetworkType tx_network_type = + static_cast(get_compile_time_arg_val(9)); constexpr uint32_t test_results_addr_arg = get_compile_time_arg_val(10); constexpr uint32_t test_results_size_bytes = get_compile_time_arg_val(11); -tt_l1_ptr uint32_t* const test_results = - reinterpret_cast(test_results_addr_arg); +tt_l1_ptr uint32_t* const test_results = reinterpret_cast(test_results_addr_arg); constexpr uint32_t prng_seed = get_compile_time_arg_val(12); @@ -64,10 +64,8 @@ constexpr packet_output_queue_state_t* output_queue_ptr = &output_queue; input_queue_rnd_state_t input_queue_rnd_state; - // generates packets with ranom size and payload on the input side inline bool input_queue_handler() { - if (input_queue_rnd_state.all_packets_done()) { return true; } @@ -80,19 +78,15 @@ inline bool input_queue_handler() { // Each call to input_queue_handler initializes only up to the end // of the queue buffer, so we don't need to 
handle wrapping. uint32_t byte_wr_addr = input_queue_ptr->get_queue_wptr_addr_bytes(); - uint32_t words_to_init = std::min(free_words, - input_queue_ptr->get_queue_words_before_wptr_wrap()); + uint32_t words_to_init = std::min(free_words, input_queue_ptr->get_queue_words_before_wptr_wrap()); uint32_t words_initialized = 0; while (words_initialized < words_to_init) { if (input_queue_rnd_state.all_packets_done()) { break; - } - else if (!input_queue_rnd_state.packet_active()) { - input_queue_rnd_state.next_packet_rnd(num_dest_endpoints, - dest_endpoint_start_id, - max_packet_size_words, - total_data_words); + } else if (!input_queue_rnd_state.packet_active()) { + input_queue_rnd_state.next_packet_rnd( + num_dest_endpoints, dest_endpoint_start_id, max_packet_size_words, total_data_words); tt_l1_ptr dispatch_packet_header_t* header_ptr = reinterpret_cast(byte_wr_addr); @@ -105,46 +99,54 @@ inline bool input_queue_handler() { words_initialized++; input_queue_rnd_state.curr_packet_words_remaining--; byte_wr_addr += PACKET_WORD_SIZE_BYTES; - } - else { + } else { uint32_t words_remaining = words_to_init - words_initialized; uint32_t num_words = std::min(words_remaining, input_queue_rnd_state.curr_packet_words_remaining); uint32_t start_val = (input_queue_rnd_state.packet_rnd_seed & 0xFFFF0000) + (input_queue_rnd_state.curr_packet_size_words - input_queue_rnd_state.curr_packet_words_remaining); - fill_packet_data(reinterpret_cast(byte_wr_addr), - num_words, - start_val); + fill_packet_data(reinterpret_cast(byte_wr_addr), num_words, start_val); words_initialized += num_words; input_queue_rnd_state.curr_packet_words_remaining -= num_words; - byte_wr_addr += num_words*PACKET_WORD_SIZE_BYTES; + byte_wr_addr += num_words * PACKET_WORD_SIZE_BYTES; } } input_queue_ptr->advance_queue_local_wptr(words_initialized); return false; } - void kernel_main() { - zero_l1_buf(test_results, test_results_size_bytes); test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_STARTED; test_results[PQ_TEST_MISC_INDEX] = 0xff000000; - test_results[PQ_TEST_MISC_INDEX+1] = 0xcc000000 | src_endpoint_id; + test_results[PQ_TEST_MISC_INDEX + 1] = 0xcc000000 | src_endpoint_id; noc_init(); - zero_l1_buf(reinterpret_cast(queue_start_addr_words*PACKET_WORD_SIZE_BYTES), - queue_size_words); + zero_l1_buf( + reinterpret_cast(queue_start_addr_words * PACKET_WORD_SIZE_BYTES), queue_size_words); input_queue_rnd_state.init(prng_seed, src_endpoint_id); - input_queue_ptr->init(input_queue_id, queue_start_addr_words, queue_size_words, - // remote_x, remote_y, remote_queue_id, remote_update_network_type: - 0, 0, 0, DispatchRemoteNetworkType::NONE); - - output_queue_ptr->init(output_queue_id, remote_rx_queue_start_addr_words, remote_rx_queue_size_words, - remote_rx_x, remote_rx_y, remote_rx_queue_id, tx_network_type, - input_queue_ptr, 1); + input_queue_ptr->init( + input_queue_id, + queue_start_addr_words, + queue_size_words, + // remote_x, remote_y, remote_queue_id, remote_update_network_type: + 0, + 0, + 0, + DispatchRemoteNetworkType::NONE); + + output_queue_ptr->init( + output_queue_id, + remote_rx_queue_start_addr_words, + remote_rx_queue_size_words, + remote_rx_x, + remote_rx_y, + remote_rx_queue_id, + tx_network_type, + input_queue_ptr, + 1); if (!wait_all_src_dest_ready(NULL, 0, output_queue_ptr, 1, timeout_cycles)) { test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_TIMEOUT; @@ -172,7 +174,8 @@ void kernel_main() { bool all_packets_initialized = input_queue_handler(); if (input_queue_ptr->get_curr_packet_valid()) { bool 
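// forward_data_from_input() now takes input_queue.get_end_of_cmd() alongside
// the full_packet_sent out-flag; presumably this lets the output queue mark
// the final words of a command as it forwards them, matching the updated
// packet_queue.hpp interface this test kernel links against.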
full_packet_sent; - uint32_t curr_data_words_sent = output_queue_ptr->forward_data_from_input(input_queue_id, full_packet_sent); + uint32_t curr_data_words_sent = output_queue_ptr->forward_data_from_input( + input_queue_id, full_packet_sent, input_queue.get_end_of_cmd()); data_words_sent += curr_data_words_sent; progress_timestamp = (curr_data_words_sent > 0) ? get_timestamp_32b() : progress_timestamp; } else if (all_packets_initialized) { @@ -208,18 +211,17 @@ void kernel_main() { set_64b_result(test_results, data_words_sent, PQ_TEST_WORD_CNT_INDEX); set_64b_result(test_results, cycles_elapsed, PQ_TEST_CYCLES_INDEX); set_64b_result(test_results, iter, PQ_TEST_ITER_INDEX); - set_64b_result(test_results, total_data_words, PQ_TEST_MISC_INDEX+4); - set_64b_result(test_results, num_packets, PQ_TEST_MISC_INDEX+6); + set_64b_result(test_results, total_data_words, PQ_TEST_MISC_INDEX + 4); + set_64b_result(test_results, num_packets, PQ_TEST_MISC_INDEX + 6); if (!timeout) { test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_PASS; test_results[PQ_TEST_MISC_INDEX] = 0xff00004; } else { test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_TIMEOUT; - set_64b_result(test_results, words_flushed, PQ_TEST_MISC_INDEX+10); + set_64b_result(test_results, words_flushed, PQ_TEST_MISC_INDEX + 10); // these calls lead to code size issues? // input_queue_ptr->dprint_object(); // output_queue_ptr->dprint_object(); } - } diff --git a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp index 05c4a338ff59..ac8945a4d6d7 100644 --- a/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp +++ b/tests/tt_metal/tt_metal/test_kernels/dataflow/unit_tests/command_queue/pcie_write_16b.cpp @@ -11,7 +11,7 @@ void kernel_main() { constexpr uint32_t base_pcie_dst_address = get_compile_time_arg_val(1); constexpr uint32_t num_16b_writes = get_compile_time_arg_val(2); - uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y)) << 32; + uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(PCIE_NOC_X, PCIE_NOC_Y, NOC_INDEX)) << 32; uint32_t l1_src_address = base_l1_src_address; uint32_t pcie_dst_address = base_pcie_dst_address; diff --git a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp index 2c963a0796d0..1281b2414ef8 100644 --- a/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp +++ b/tests/tt_metal/tt_metal/unit_tests_fast_dispatch/streams/test_autonomous_relay_streams.cpp @@ -656,7 +656,7 @@ TEST_F(CommandQueueFixture, TestAutonomousRelayStreams) { } std::srand(0); - uint32_t num_loop_iterations = 10; + uint32_t num_loop_iterations = 2; uint32_t num_messages_to_send = 1'000'000; uint32_t tx_rx_stream_buffer_size_bytes = 16 * 1024; uint32_t relay_stream_buffer_size_bytes = 16 * 1024; @@ -733,7 +733,7 @@ TEST_F(CommandQueueFixture, TestAutonomousRelayStreamsSmallPackets) { return; } -TEST_F(CommandQueueFixture, TestAutonomousRelayStreamsLoopingShort) { +TEST_F(CommandQueueFixture, DISABLED_TestAutonomousRelayStreamsLoopingShort) { auto arch = tt::get_arch_from_string(tt::test_utils::get_env_arch_name()); auto num_devices = tt::tt_metal::GetNumAvailableDevices(); if (arch == tt::ARCH::GRAYSKULL) { diff --git 
a/tests/ttnn/integration_tests/bert/test_performance.py b/tests/ttnn/integration_tests/bert/test_performance.py index 034df32b53d3..e29b0a44329e 100644 --- a/tests/ttnn/integration_tests/bert/test_performance.py +++ b/tests/ttnn/integration_tests/bert/test_performance.py @@ -59,7 +59,7 @@ def get_expected_times(bert): return { ttnn_bert: (0.1, 0.1), ttnn_optimized_bert: (5.5, 0.07), - ttnn_optimized_sharded_bert: (5.2, 0.07), + ttnn_optimized_sharded_bert: (5.5, 0.07), }[bert] diff --git a/tests/ttnn/integration_tests/mistral/test_mistral_attention.py b/tests/ttnn/integration_tests/mistral/test_mistral_attention.py index efc2dc36a8bb..c3a516d12df9 100644 --- a/tests/ttnn/integration_tests/mistral/test_mistral_attention.py +++ b/tests/ttnn/integration_tests/mistral/test_mistral_attention.py @@ -2,6 +2,8 @@ # SPDX-License-Identifier: Apache-2.0 +import pytest + import torch import ttnn import tt_lib @@ -19,6 +21,7 @@ from tests.ttnn.utils_for_testing import assert_with_pcc +@pytest.mark.skip(reason="https://github.com/tenstorrent/tt-metal/issues/9076") @skip_for_wormhole_b0() def test_mistral_attention_inference(model_location_generator, device, reset_seeds): model_path = model_location_generator("mistral-7B-v0.1", model_subdir="Mistral") diff --git a/tests/ttnn/integration_tests/whisper/test_performance.py b/tests/ttnn/integration_tests/whisper/test_performance.py index b88669f43d9d..41c559c5ef04 100644 --- a/tests/ttnn/integration_tests/whisper/test_performance.py +++ b/tests/ttnn/integration_tests/whisper/test_performance.py @@ -17,7 +17,7 @@ def get_expected_times(functional_whisper): return { - ttnn_functional_whisper: (10.5, 4.16), + ttnn_functional_whisper: (11, 4.16), ttnn_optimized_functional_whisper: (1.2, 1.35), }[functional_whisper] diff --git a/tests/ttnn/unit_tests/operations/test_math.py b/tests/ttnn/unit_tests/operations/test_math.py index c1cf8198b436..1fc6f66619e6 100644 --- a/tests/ttnn/unit_tests/operations/test_math.py +++ b/tests/ttnn/unit_tests/operations/test_math.py @@ -7,6 +7,8 @@ import torch import ttnn +import tt_lib +from models.utility_functions import is_grayskull from tests.ttnn.utils_for_testing import assert_with_pcc from models.utility_functions import torch_random @@ -69,6 +71,60 @@ def test_lgamma(device, h, w): run_math_unary_test(device, h, w, ttnn.lgamma, torch.lgamma, pcc=0.999) +@pytest.mark.parametrize("h", [32]) +@pytest.mark.parametrize("w", [32]) +@pytest.mark.parametrize("output_dtype", [ttnn.DataType.BFLOAT16, ttnn.DataType.UINT16, ttnn.DataType.UINT32]) +def test_eq(device, h, w, output_dtype): + if is_grayskull() and output_dtype in (ttnn.DataType.UINT32, ttnn.DataType.UINT16): + pytest.skip("GS does not support fp32/uint32/uint16 data types") + + torch.manual_seed(0) + + same = 50 + torch_input_tensor_a = torch.rand((h, w), dtype=torch.bfloat16) + torch_input_tensor_a[0, 0] = same + torch_input_tensor_a[0, 1] = same + torch_input_tensor_a[0, 2] = same + + torch_input_tensor_b = torch.rand((h, w), dtype=torch.bfloat16) + torch_input_tensor_b[0, 0] = same + torch_input_tensor_b[0, 1] = same + torch_input_tensor_b[0, 2] = same + + torch_output_tensor = torch.eq(torch_input_tensor_a, torch_input_tensor_b) + + input_tensor_a = ttnn.from_torch( + torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + input_tensor_b = ttnn.from_torch( + torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.L1_MEMORY_CONFIG + ) + + pages_before = ttnn._ttnn.reports.get_buffer_pages() + 
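# The buffer-page checks below verify allocation behavior rather than values:
# eq() without an output_tensor should leave exactly one more entry in
# get_buffer_pages() (the freshly allocated output), while the
# preallocated-output call later in the test should leave the count unchanged.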
output_tensor = ttnn.eq(input_tensor_a, input_tensor_b, dtype=output_dtype)
+    assert output_tensor.get_dtype() == output_dtype
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages()) - 1
+    output_tensor = ttnn.to_torch(output_tensor)
+    assert_with_pcc(torch_output_tensor, output_tensor, 0.999)
+
+    # EQ with a preallocated output tensor
+    output_tensor_preallocated_bfloat16 = ttnn.ones(
+        [h, w], ttnn.DataType.BFLOAT16, ttnn.TILE_LAYOUT, device, ttnn.L1_MEMORY_CONFIG
+    )
+    output_tensor_preallocated = output_tensor_preallocated_bfloat16
+    # There is no good way to create a uint16 tensor in ttnn/torch, so we create bfloat16 and typecast to the target
+    if output_dtype != ttnn.DataType.BFLOAT16:
+        output_tensor_preallocated = tt_lib.tensor.typecast(
+            output_tensor_preallocated_bfloat16, output_dtype, ttnn.L1_MEMORY_CONFIG
+        )
+
+    pages_before = ttnn._ttnn.reports.get_buffer_pages()
+    ttnn.eq(input_tensor_a, input_tensor_b, dtype=output_dtype, output_tensor=output_tensor_preallocated)
+    assert len(pages_before) == len(ttnn._ttnn.reports.get_buffer_pages())
+    torch_output_tensor_preallocated = ttnn.to_torch(output_tensor_preallocated)
+    assert_with_pcc(torch_output_tensor, torch_output_tensor_preallocated, 0.999)
+
+
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
 def test_log10(device, h, w):
diff --git a/tests/ttnn/unit_tests/test_multi_device.py b/tests/ttnn/unit_tests/test_multi_device.py
index c8b7386279dc..501840cfe5ce 100644
--- a/tests/ttnn/unit_tests/test_multi_device.py
+++ b/tests/ttnn/unit_tests/test_multi_device.py
@@ -159,6 +159,8 @@ def test_multi_device_replicate(device_mesh, shape, layout, memory_config):
 def test_ttnn_multi_device_all_gather(pcie_device_mesh):
     """Multidevice API test for ttnn.all_gather CCL operation"""
+    if pcie_device_mesh.get_num_devices() <= 1:
+        pytest.skip("Requires multiple devices to run")
     full_tensor = torch.rand((1, 1, 32, 32 * pcie_device_mesh.get_num_devices()), dtype=torch.bfloat16)
 
     ttnn_tensor = ttnn.from_torch(full_tensor, mesh_mapper=ShardTensorToMesh(pcie_device_mesh, dim=3))
diff --git a/tests/ttnn/unit_tests/test_multi_device_async.py b/tests/ttnn/unit_tests/test_multi_device_async.py
index 2f5cc0e82529..35a3bf71a5bd 100644
--- a/tests/ttnn/unit_tests/test_multi_device_async.py
+++ b/tests/ttnn/unit_tests/test_multi_device_async.py
@@ -278,8 +278,8 @@ def test_multi_device_explicit_dealloc(pcie_device_mesh):
     """Multidevice API: Ensure that deallocating multi-device tensors works as expected"""
     from ttnn import ShardTensorToMesh, ConcatMeshToTensor, ReplicateTensorToMesh
 
-    for device in pcie_device_mesh.get_device_ids():
-        pcie_device_mesh.get_device(device).enable_async(True)
+    if pcie_device_mesh.get_num_devices() <= 1:
+        pytest.skip("Requires multiple devices to run")
     # Create input tensors that cause OOM during op execution
     # Explicitly deallocate buffers after each op to ensure we don't run OOM.
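Each multi-device test here now guards itself: it checks pcie_device_mesh.get_num_devices() <= 1 and calls pytest.skip directly, so single-device runs skip these cases per test rather than exercising multi-device paths that need at least two devices. In test_multi_device_explicit_dealloc the guard also takes the place of the explicit enable_async toggling, which is dropped in this hunk and the next.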
@@ -311,9 +311,6 @@ def test_multi_device_explicit_dealloc(pcie_device_mesh): ttnn_output_tensor, mesh_composer=ConcatMeshToTensor(pcie_device_mesh, dim=0) ) - for device in pcie_device_mesh.get_device_ids(): - pcie_device_mesh.get_device(device).enable_async(False) - @pytest.mark.parametrize("scalar", [3]) @pytest.mark.parametrize("size", [64]) diff --git a/tests/ttnn/unit_tests/test_multi_device_trace.py b/tests/ttnn/unit_tests/test_multi_device_trace.py index e75279713485..aa350b6d1e76 100644 --- a/tests/ttnn/unit_tests/test_multi_device_trace.py +++ b/tests/ttnn/unit_tests/test_multi_device_trace.py @@ -16,6 +16,9 @@ @pytest.mark.parametrize("use_all_gather", [True, False]) @pytest.mark.parametrize("enable_async", [True, False]) def test_multi_device_single_trace(pcie_device_mesh, shape, use_all_gather, enable_async): + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("This test requires multiple devices") + # Trace requires program cache to be enabled for device_id in pcie_device_mesh.get_device_ids(): pcie_device_mesh.get_device(device_id).enable_async(enable_async) @@ -103,6 +106,9 @@ def test_multi_device_multi_trace(pcie_device_mesh, shape, use_all_gather, enabl if shape == (1, 1, 32, 32) or shape == (1, 3, 512, 512) or shape == (1, 3, 32, 32): pytest.skip("This configuration is not working with all-gather") + if pcie_device_mesh.get_num_devices() <= 1: + pytest.skip("This test requires multiple devices") + # Trace requires program cache to be enabled for device_id in pcie_device_mesh.get_device_ids(): pcie_device_mesh.get_device(device_id).enable_async(enable_async) diff --git a/tt_eager/tensor/tensor.cpp b/tt_eager/tensor/tensor.cpp index 9f28d1035671..cdcb1b2e93e0 100644 --- a/tt_eager/tensor/tensor.cpp +++ b/tt_eager/tensor/tensor.cpp @@ -563,6 +563,8 @@ Tensor Tensor::to(Layout target_layout, DeviceMesh* device_mesh) const { auto& worker = workers[worker_index]; worker->push_work([*this, tensor_modified_layout, target_layout, worker, worker_index]() mutable { TT_ASSERT( + this->storage_type() == StorageType::OWNED || + this->storage_type() == StorageType::BORROWED|| this->storage_type() == StorageType::MULTI_DEVICE_HOST && "to(layout) must be called on host tensors with MULTI_DEVICE_HOST_STORAGE when multiple workers " "are specified"); diff --git a/tt_eager/tensor/tensor.hpp b/tt_eager/tensor/tensor.hpp index e60d7a77ef4c..fedbf54cb42e 100644 --- a/tt_eager/tensor/tensor.hpp +++ b/tt_eager/tensor/tensor.hpp @@ -334,7 +334,7 @@ struct Tensor { return buffer->device(); } else if (this->storage_type() == tt::tt_metal::StorageType::MULTI_DEVICE) { auto &storage = std::get(this->get_storage()); - return storage.get_buffer_for_device_id(0)->device(); + return this->get_workers().at(0); } else { TT_THROW("Cannot get the device from a tensor with host storage"); } diff --git a/tt_eager/tensor/tensor_utils.cpp b/tt_eager/tensor/tensor_utils.cpp index d85efa6c9f88..a5bf1dba55f8 100644 --- a/tt_eager/tensor/tensor_utils.cpp +++ b/tt_eager/tensor/tensor_utils.cpp @@ -383,15 +383,20 @@ bool is_cpu_tensor(const Tensor& tensor) { bool is_device_tensor(const Tensor& tensor) { return tensor.storage_type() == StorageType::DEVICE; } Tensor get_device_tensor(const Tensor& multi_device_tensor, const int device_id) { - const auto& tensor_storage = std::get(multi_device_tensor.get_storage()); - if (tensor_storage.has_buffer_for_device_id(device_id)) { - return Tensor{ - DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, - multi_device_tensor.get_legacy_shape(), - 
multi_device_tensor.get_dtype(), - multi_device_tensor.get_layout()}; + if (std::holds_alternative(multi_device_tensor.get_storage())) { + const auto& tensor_storage = std::get(multi_device_tensor.get_storage()); + if (tensor_storage.has_buffer_for_device_id(device_id)) { + return Tensor{ + DeviceStorage{tensor_storage.get_buffer_for_device_id(device_id)}, + multi_device_tensor.get_legacy_shape(), + multi_device_tensor.get_dtype(), + multi_device_tensor.get_layout()}; + } + } else if (std::holds_alternative(multi_device_tensor.get_storage())) { + return multi_device_tensor; } - TT_THROW("Device not found in multi-device tensor"); + + TT_THROW("User is trying to access a device tensor that is not on device."); } Tensor get_device_tensor(const Tensor& multi_device_tensor, const Device* device) { diff --git a/tt_eager/tensor/types.hpp b/tt_eager/tensor/types.hpp index c60ca89118c5..dc0a421c6f1a 100644 --- a/tt_eager/tensor/types.hpp +++ b/tt_eager/tensor/types.hpp @@ -504,16 +504,17 @@ struct MultiDeviceHostStorage { inline const MemoryConfig memory_config() const { std::lock_guard lock(buffer_mtx); - if (this->buffers.at(0).get() == nullptr) { + auto first_device_id = this->ordered_device_ids.at(0); + if (this->buffers.at(first_device_id).get() == nullptr) { TT_THROW("MemoryConfig can only be obtained if the buffer is not null"); } std::optional shard_spec = std::nullopt; - if (is_sharded(this->buffers.at(0)->buffer_layout())) { - shard_spec = this->buffers.at(0)->shard_spec().tensor_shard_spec; + if (is_sharded(this->buffers.at(first_device_id)->buffer_layout())) { + shard_spec = this->buffers.at(first_device_id)->shard_spec().tensor_shard_spec; } return MemoryConfig{ - .memory_layout = this->buffers.at(0)->buffer_layout(), - .buffer_type = this->buffers.at(0)->buffer_type(), + .memory_layout = this->buffers.at(first_device_id)->buffer_layout(), + .buffer_type = this->buffers.at(first_device_id)->buffer_type(), .shard_spec = shard_spec}; } diff --git a/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp b/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp index 32debd44f72b..964e67305b16 100644 --- a/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp +++ b/tt_eager/tt_dnn/op_library/all_gather/all_gather_op.hpp @@ -47,7 +47,7 @@ class AllGatherConfig { erisc_handshake_address(round_up(eth_l1_mem::address_map::ERISC_L1_UNRESERVED_BASE, 16)), topology(topology), - enable_bidirectional(/*false*/topology == all_gather_op::Topology::Ring && dim != 0 && dim != 1), + enable_bidirectional(topology == all_gather_op::Topology::Ring), input_is_dram(input_tensor.buffer()->buffer_type() == BufferType::DRAM), output_is_dram(output_tensor.buffer()->buffer_type() == BufferType::DRAM), diff --git a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp index 2c7f486cd81d..9ffcba874adc 100644 --- a/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/all_gather/multi_core/all_gather_op_multi_core.cpp @@ -318,6 +318,8 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& TT_ASSERT(rem_pages < pages_per_chunk || num_full_chunks == 0); TT_ASSERT(rem_pages <= max_pages_per_chunk); std::vector num_full_chunks_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), num_full_chunks / all_gather_config.get_num_eth_buffers_per_edm()); + std::vector 
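// These two per-channel vectors (one slot per EDM ethernet buffer) record
// whether a channel's buffer can be shrunk and, if so, the largest packet it
// will ever carry; they are filled in for both the sharded and interleaved
// paths below and later fed to the EDM builders via
// set_max_message_size_bytes().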
is_channel_shrinkable(all_gather_config.get_num_eth_buffers_per_edm(), false); + std::vector largest_packets_per_channel(all_gather_config.get_num_eth_buffers_per_edm(), 0); std::vector rem_pages_per_worker(all_gather_config.get_num_eth_buffers_per_edm(), 0); { uint32_t worker_idx = 0; @@ -355,10 +357,22 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& ); uint32_t max_shards_per_eth_buffer = std::min(all_gather_config.get_eth_buffer_size() / input_tensor_shard_arg_generator.args_struct.shard_size_in_bytes, input_tensor_shard_arg_generator.args_struct.num_dest_cores); TT_ASSERT(max_shards_per_eth_buffer > 0, "Codepath needs further generalization to support computing multiple sends per shard. Shard size: {}", input_tensor_shard_arg_generator.args_struct.shard_size_in_bytes); + log_info(tt::LogOp, "max_shards_per_eth_buffer: {}", max_shards_per_eth_buffer); num_full_chunks_per_worker.at(b) = input_tensor_shard_arg_generator.args_struct.num_dest_cores < max_shards_per_eth_buffer ? 1 : input_tensor_shard_arg_generator.args_struct.num_dest_cores / max_shards_per_eth_buffer; rem_pages_per_worker.at(b) = max_shards_per_eth_buffer > input_tensor_shard_arg_generator.args_struct.num_dest_cores ? 0 : input_tensor_shard_arg_generator.args_struct.num_dest_cores - (num_full_chunks_per_worker.at(b) * max_shards_per_eth_buffer); TT_ASSERT(rem_pages_per_worker.at(b) == 0 || input_tensor_shard_arg_generator.args_struct.num_dest_cores >= num_full_chunks_per_worker.at(b) * max_shards_per_eth_buffer); TT_ASSERT(input_tensor_shard_arg_generator.args_struct.num_dest_cores == rem_pages_per_worker.at(b) + num_full_chunks_per_worker.at(b) * max_shards_per_eth_buffer); + + uint32_t full_chunk_size_bytes = max_shards_per_eth_buffer * input_tensor_shard_arg_generator.args_struct.shard_size_in_bytes; + bool shrinkable = num_full_chunks_per_worker.at(b) == 1 && all_gather_config.get_eth_buffer_size() > full_chunk_size_bytes; + is_channel_shrinkable.at(b) = shrinkable; + largest_packets_per_channel.at(b) = shrinkable ? full_chunk_size_bytes : all_gather_config.get_eth_buffer_size(); + } + } else { + for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { + bool shrinkable = num_full_chunks_per_worker.at(b) == 0; + is_channel_shrinkable.at(b) = shrinkable; + largest_packets_per_channel.at(b) = shrinkable ? 
rem_pages_per_worker.at(b) * input_page_size : all_gather_config.get_eth_buffer_size(); } } for(uint32_t b = 0; b < all_gather_config.get_num_eth_buffers_per_edm(); ++b) { @@ -412,6 +426,11 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& log_trace(tt::LogOp, "Adding sender EDM channel"); EriscDatamoverBuilder::ChannelBufferInterface const& sender_channel_buffer_info = sender_edm_builder.add_sender_channel(sender_worker_writer_semaphore_addr, clockwise_link_buffer_num_messages_to_send.at(b), sender_worker_coords); + if (is_channel_shrinkable.at(b)) { + TT_ASSERT(largest_packets_per_channel.at(b) > 0); + log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b); + sender_edm_builder.set_max_message_size_bytes(sender_channel_buffer_info.channel, largest_packets_per_channel.at(b)); + } sender_eth_sem_addrs.push_back(sender_channel_buffer_info.eth_semaphore_l1_address); sender_eth_buffer_addrs.push_back(sender_channel_buffer_info.eth_buffer_l1_address); } @@ -422,6 +441,11 @@ operation::ProgramWithCallbacks all_gather_multi_core_with_workers(const Tensor& log_trace(tt::LogOp, "Adding receiver EDM channel"); EriscDatamoverBuilder::ChannelBufferInterface const& receiver_channel_buffer_info = receiver_edm_builder.add_receiver_channel(receiver_worker_semaphore_addr, counter_clockwise_link_buffer_num_messages_to_send.at(b), receiver_worker_coords); + if (is_channel_shrinkable.at(b)) { + TT_ASSERT(largest_packets_per_channel.at(b) > 0); + log_trace(tt::LogOp, "\tsetting channel_max_size to {} for channel {}", largest_packets_per_channel.at(b), b); + receiver_edm_builder.set_max_message_size_bytes(receiver_channel_buffer_info.channel, largest_packets_per_channel.at(b)); + } receiver_eth_sem_addrs.push_back(receiver_channel_buffer_info.eth_semaphore_l1_address); receiver_eth_buffer_addrs.push_back(receiver_channel_buffer_info.eth_buffer_l1_address); } diff --git a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp index 3100d466520d..29cbae919476 100644 --- a/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/bmm_op.cpp @@ -1041,7 +1041,7 @@ void Matmul::validate( // subbblock constraint TT_FATAL(program_config.out_subblock_w == per_core_N || program_config.out_subblock_h == 1); // tensor in1 - TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED); + TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::WIDTH_SHARDED); } else if constexpr (std::is_same_v) { if (input_tensor_a.memory_config().is_sharded()) { auto tensor_a_memory_layout = input_tensor_a.memory_config().memory_layout; diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp b/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp index ca37e340a703..ede0790e5ea6 100644 --- a/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp @@ -68,6 +68,14 @@ inline void reblock_and_untilize( } void MAIN { + // RUNTIME ARGS + #ifdef MATMUL_DRAM_SHARDED + const bool is_worker_core = get_arg_val(0) == 1; + // if not worker core, skip + if (not is_worker_core) { + return; + } + #endif constexpr uint32_t in0_block_w = get_compile_time_arg_val(0); // inner block size in tiles constexpr uint32_t in0_num_subblocks = 
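// With MATMUL_DRAM_SHARDED defined, this compute kernel is now launched on
// the full rectangular bounding-box grid rather than only on worker cores,
// so runtime arg 0 acts as a participation flag: filler cores read 0 in the
// #ifdef block above and return before touching any circular buffers.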
get_compile_time_arg_val(1); // outer row block size (in inner row blocks) diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp index 109def5e9bcd..bbe72e1f48ab 100644 --- a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp @@ -31,7 +31,11 @@ void kernel_main() { constexpr uint32_t num_storage_cores = num_blocks / num_blocks_per_shard; // RUNTIME ARGS - const bool is_worker_core = get_arg_val(0) == 1; + const uint32_t worker_core_type = get_arg_val(0); + // if not worker core, skip + if (worker_core_type == 0) { + return; + } const uint32_t sender_id = get_arg_val(1); volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(2)); volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(2 + num_storage_cores)); @@ -71,7 +75,7 @@ void kernel_main() { uint32_t local_read_addr = get_read_ptr(cb_id_in2); - if (not is_worker_core) { + if (worker_core_type == 1) { // mcast sender + no compute for (uint32_t i = 0; i < num_blocks_per_shard; ++i) { const uint32_t block_id = sender_block_id + i; @@ -101,7 +105,8 @@ void kernel_main() { local_read_addr += in0_block_size_bytes; } - } else { + } else if (worker_core_type == 2) { // mcast sender + compute + for(uint32_t block = 0; block < num_blocks; ++block) { const uint32_t block_id = block / num_blocks_per_shard; @@ -138,5 +143,27 @@ void kernel_main() { cb_push_back(cb_id_in0, in0_block_num_tiles); } + } else { // mcast receiver + compute + + for(uint32_t block = 0; block < num_blocks; ++block) { + const uint32_t block_id = block / num_blocks_per_shard; + + // get the mcast sender noc + uint64_t in0_mcast_sender_semaphore_noc_addr = get_noc_addr(in0_mcast_sender_noc_x[block_id], in0_mcast_sender_noc_y[block_id], in0_mcast_sender_semaphore_addr); + + // Operand 0 + cb_reserve_back(cb_id_in0, in0_block_num_tiles); + + // Set in0 semaphore value to INVALID + noc_semaphore_set(in0_mcast_receiver_semaphore_addr_ptr, INVALID); + + // Atomic increment source core counter + noc_semaphore_inc(in0_mcast_sender_semaphore_noc_addr, 1); + + // wait on in0 semaphore value to become VALID (set by mcast sender after it multicasts data) + noc_semaphore_wait(in0_mcast_receiver_semaphore_addr_ptr, VALID); + + cb_push_back(cb_id_in0, in0_block_num_tiles); + } } } diff --git a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp index 5bde1c06534c..0546a8db1c0e 100644 --- a/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp @@ -9,17 +9,23 @@ void kernel_main() { // RUNTIME ARGS - const uint32_t in1_tensor_addr = get_arg_val(0); + const bool is_worker_core = get_arg_val(0) == 1; + // if not worker core, skip + if (not is_worker_core) { + return; + } + + const uint32_t in1_tensor_addr = get_arg_val(1); #ifdef FUSE_BIAS - const uint32_t in3_tensor_addr = get_arg_val(1); + const uint32_t in3_tensor_addr = get_arg_val(2); #endif - const uint32_t dram_bank_id = get_arg_val(2); - const uint32_t 
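// Because the is_worker_core flag is now runtime arg 0, every subsequent in1
// reader arg shifts up by one (in1_tensor_addr 0 -> 1, in3_tensor_addr
// 1 -> 2, dram_bank_id 2 -> 3, and so on), and the runtime-args update
// callback later in this diff writes the buffer addresses at indices 1 and 2
// instead of 0 and 1.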
vc = get_arg_val(3); - const uint32_t num_shard_to_write_back = get_arg_val(4); - const uint32_t reshard_tensor_start_offset = get_arg_val(5); - volatile tt_l1_ptr uint32_t * per_core_N_reshard_bytes = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(6)); - volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(7)); - volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(8)); + const uint32_t dram_bank_id = get_arg_val(3); + const uint32_t vc = get_arg_val(4); + const uint32_t num_shard_to_write_back = get_arg_val(5); + const uint32_t reshard_tensor_start_offset = get_arg_val(6); + volatile tt_l1_ptr uint32_t * per_core_N_reshard_bytes = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(7)); + volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_x = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(8)); + volatile tt_l1_ptr uint32_t * in0_mcast_sender_noc_y = (volatile tt_l1_ptr uint32_t*)(get_arg_addr(9)); // COMPILE TIME ARGS diff --git a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp index 38efb0589eff..9b8b3200eaa3 100644 --- a/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp +++ b/tt_eager/tt_dnn/op_library/bmm/multi_core_reuse_mcast_dram_sharded_optimized/bmm_op_multi_core_reuse_dram_sharded_optimized.cpp @@ -503,6 +503,13 @@ operation::ProgramWithCallbacks create_program_dram_sharded( log_debug("all_cores: {}", core); } + // grid bounding box + CoreRange bounding_box = all_cores.bounding_box(); + std::set bounding_box_set; bounding_box_set.insert(bounding_box); + CoreRangeSet all_cores_in_rect_grid(bounding_box_set); + std::vector all_cores_in_rect_grid_vec = corerange_to_cores(all_cores_in_rect_grid); + log_debug("bounding_box: {}", bounding_box); + // Mcast args auto in0_mcast_sender_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID); auto in0_mcast_receiver_semaphore = tt_metal::CreateSemaphore(program, all_cores, INVALID); @@ -581,16 +588,6 @@ operation::ProgramWithCallbacks create_program_dram_sharded( in1_sender_writer_compile_time_args.push_back(bias_buffer_num_pages); in1_sender_writer_compile_time_args.push_back((std::uint32_t)1); } - std::vector in0_receiver_compile_time_args = { - // in0 block args - (std::uint32_t)in0_block_w * per_core_M, // in0_block_num_tiles - // in0/in1 common args - (std::uint32_t)num_blocks, // num_blocks - // in0 mcast args - (std::uint32_t)in0_mcast_sender_semaphore, - (std::uint32_t)in0_mcast_receiver_semaphore, - // - (std::uint32_t)num_blocks_per_shard}; std::map mm_kernel_defines; std::map mm_kernel_in0_sender_define; @@ -625,11 +622,12 @@ operation::ProgramWithCallbacks create_program_dram_sharded( if (skip_write_back) { mm_kernel_in1_sender_writer_defines["SKIP_WRITE_BACK"] = "1"; } + mm_kernel_defines["MATMUL_DRAM_SHARDED"] = "1"; auto mm_kernel_in0_sender_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_sender_dram_sharded.cpp", - mcast_senders, + all_cores_in_rect_grid, tt_metal::DataMovementConfig{ .processor = tt_metal::DataMovementProcessor::RISCV_1, .noc = in0_noc, @@ -639,22 +637,13 @@ operation::ProgramWithCallbacks create_program_dram_sharded( auto mm_kernel_in1_sender_writer_id = 
tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in1_sender_dram_sharded.cpp", - all_worker_cores, + all_cores_in_rect_grid, tt_metal::DataMovementConfig{ .processor = tt_metal::DataMovementProcessor::RISCV_0, .noc = in1_noc, .compile_args = in1_sender_writer_compile_time_args, .defines = mm_kernel_in1_sender_writer_defines}); - KernelHandle mm_kernel_in0_receiver_id = tt_metal::CreateKernel( - program, - "tt_eager/tt_dnn/op_library/bmm/kernels/dataflow/reader_bmm_tile_layout_in0_receiver_dram_sharded.cpp", - mcast_receivers, - tt_metal::DataMovementConfig{ - .processor = tt_metal::DataMovementProcessor::RISCV_1, - .noc = in0_noc, - .compile_args = in0_receiver_compile_time_args}); - // Compute kernel compile time args uint32_t in0_subblock_num_tiles = out_subblock_h * in0_block_w; @@ -687,7 +676,8 @@ operation::ProgramWithCallbacks create_program_dram_sharded( auto mm_kernel = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/bmm/kernels/compute/bmm_large_block_zm_fused_bias_activation.cpp", - all_worker_cores, + // all_worker_cores, + all_cores_in_rect_grid, tt_metal::ComputeConfig{ .math_fidelity = math_fidelity, .fp32_dest_acc_en = fp32_dest_acc_en, @@ -850,14 +840,15 @@ operation::ProgramWithCallbacks create_program_dram_sharded( for (auto core : mcast_senders_coords) { std::vector mm_in0_sender_args; - bool is_worker_core; + // mcast sender - 1, mcast sender + compute core - 2 + uint32_t worker_core_type; if (find(storage_worker_common.begin(), storage_worker_common.end(), core) != storage_worker_common.end()) { - is_worker_core = true; + worker_core_type = 2; } else { - is_worker_core = false; + worker_core_type = 1; } - mm_in0_sender_args.push_back((std::uint32_t)is_worker_core); + mm_in0_sender_args.push_back((std::uint32_t)worker_core_type); mm_in0_sender_args.push_back((std::uint32_t)sender_id); mm_in0_sender_args.insert( mm_in0_sender_args.end(), in0_mcast_sender_noc_x.begin(), in0_mcast_sender_noc_x.end()); @@ -876,12 +867,30 @@ operation::ProgramWithCallbacks create_program_dram_sharded( // in0 receivers rt args std::vector mm_in0_receiver_args; + // mcast receiver - 3 + uint32_t worker_core_type = 3; + mm_in0_receiver_args.push_back((std::uint32_t)worker_core_type); + mm_in0_receiver_args.push_back((std::uint32_t) 0); mm_in0_receiver_args.insert( mm_in0_receiver_args.end(), in0_mcast_sender_noc_x.begin(), in0_mcast_sender_noc_x.end()); mm_in0_receiver_args.insert( mm_in0_receiver_args.end(), in0_mcast_sender_noc_y.begin(), in0_mcast_sender_noc_y.end()); - tt_metal::SetRuntimeArgs(program, mm_kernel_in0_receiver_id, core, mm_in0_receiver_args); - reader_kernel_ids.push_back(mm_kernel_in0_receiver_id); + + tt_metal::SetRuntimeArgs(program, mm_kernel_in0_sender_id, core, mm_in0_receiver_args); + reader_kernel_ids.push_back(mm_kernel_in0_sender_id); + } + + for (auto core : all_cores_in_rect_grid_vec) { + if (std::find(mcast_senders_coords.begin(), mcast_senders_coords.end(), core) == mcast_senders_coords.end() and + std::find(mcast_receiver_coords.begin(), mcast_receiver_coords.end(), core) == mcast_receiver_coords.end()) { + // in0 receivers rt args + std::vector mm_in0_idle_args; + // idle core - 0 + uint32_t worker_core_type = 0; + mm_in0_idle_args.push_back((std::uint32_t)worker_core_type); + + tt_metal::SetRuntimeArgs(program, mm_kernel_in0_sender_id, core, mm_in0_idle_args); + } } uint32_t bank_id = 0; @@ -894,11 +903,40 @@ operation::ProgramWithCallbacks create_program_dram_sharded( uint32_t 
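// worker_core_type encodes each core's role as runtime arg 0 of the in0
// reader kernel: 0 = idle filler core (present only because the kernel now
// launches on the rectangular bounding-box grid), 1 = mcast sender without
// compute, 2 = mcast sender with compute, 3 = mcast receiver with compute.
// The loops above assign 1/2 to senders, 3 to receivers, and 0 to every
// remaining core in the grid.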
curr_worker_core = 0; uint32_t curr_storage_core = 0; + // for all the cores in the rect grid, we send one rt arg to determine if they are worker core + for (uint32_t i = 0; i < all_cores_in_rect_grid_vec.size(); ++i) { + auto core = all_cores_in_rect_grid_vec[i]; + + if (all_worker_cores.ranges().find(core) == all_worker_cores.ranges().end()) { // not worker + // in1 reader rt args + bool is_worker_core = false; + std::vector mm_in1_sender_writer_args; + mm_in1_sender_writer_args.push_back((std::uint32_t) is_worker_core); + + tt_metal::SetRuntimeArgs(program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); + + // compute rt args + std::vector mm_compute_args; + mm_compute_args.push_back((std::uint32_t) is_worker_core); + + tt_metal::SetRuntimeArgs(program, mm_kernel, core, mm_compute_args); + } else { + // compute rt args + bool is_worker_core = true; + std::vector mm_compute_args; + mm_compute_args.push_back((std::uint32_t) is_worker_core); + + tt_metal::SetRuntimeArgs(program, mm_kernel, core, mm_compute_args); + } + } + for (uint32_t i = 0; i < all_worker_cores_ordered.size(); ++i) { auto core = all_worker_cores_ordered[i]; // in1 reader rt args + bool is_worker_core = true; std::vector mm_in1_sender_writer_args; + mm_in1_sender_writer_args.push_back((std::uint32_t) is_worker_core); mm_in1_sender_writer_args.push_back(in1_buffer->address()); if (bias_buffer != nullptr) { mm_in1_sender_writer_args.push_back(bias_buffer->address()); @@ -1014,7 +1052,7 @@ operation::ProgramWithCallbacks create_program_dram_sharded( } } - mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 4, num_iter); + mm_in1_sender_writer_args.insert(mm_in1_sender_writer_args.begin() + 5, num_iter); } tt_metal::SetRuntimeArgs(program, mm_kernel_in1_sender_writer_id, core, mm_in1_sender_writer_args); @@ -1044,11 +1082,11 @@ operation::ProgramWithCallbacks create_program_dram_sharded( auto core = all_worker_cores_ordered[i]; auto writer_kernel_id = writer_kernel_ids[i]; auto& writer_runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - writer_runtime_args[0] = src_buffer_b->address(); + writer_runtime_args[1] = src_buffer_b->address(); if (bias_tensor.has_value()) { - writer_runtime_args[1] = bias_tensor.value().buffer()->address(); + writer_runtime_args[2] = bias_tensor.value().buffer()->address(); } else { - writer_runtime_args[1] = 0; + writer_runtime_args[2] = 0; } } }; diff --git a/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp b/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp index 193046b8c549..89c237a6cb60 100644 --- a/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp +++ b/tt_eager/tt_dnn/op_library/ccl/ccl_host_datastructures.hpp @@ -7,6 +7,7 @@ #include "eth_l1_address_map.h" #include "tensor/tensor_impl.hpp" #include "tt_eager/tt_dnn/op_library/ccl/shared_with_host/hetergeneous_data_structs.hpp" +#include namespace tt { namespace tt_metal { @@ -130,19 +131,25 @@ class EriscDatamoverBuilder { worker_semaphore_address(worker_semaphore_address), num_eth_messages_to_forward(num_eth_messages_to_forward), channel(channel), + largest_message_size_bytes(0), is_sender(is_sender) {} std::vector const worker_coords; uint32_t worker_semaphore_address; uint32_t num_eth_messages_to_forward; uint32_t channel; + uint32_t largest_message_size_bytes; bool is_sender; }; void push_back_channel_args(std::vector& args, ChannelBufferSpec const& channel) const { args.push_back(this->local_buffer_addresses.at(channel.channel)); 
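// The buffer-size argument pushed next is what the EDM kernel sees as its
// channel buffer size. When a channel has recorded a largest message, the
// advertised size is clamped to
//   std::min(channel.largest_message_size_bytes, this->eth_buffer_size_bytes)
// so channels that only ever carry small packets do not move a full eth
// buffer's worth of data per message.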
args.push_back(channel.num_eth_messages_to_forward); - args.push_back(this->eth_buffer_size_bytes); + if (channel.largest_message_size_bytes > 0) { + args.push_back(std::min(channel.largest_message_size_bytes, this->eth_buffer_size_bytes)); + } else { + args.push_back(this->eth_buffer_size_bytes); + } args.push_back(this->local_semaphore_addresses.at(channel.channel)); args.push_back(channel.worker_semaphore_address); args.push_back(channel.worker_coords.size()); @@ -167,6 +174,7 @@ class EriscDatamoverBuilder { public: struct ChannelBufferInterface { + std::size_t channel; uint32_t eth_buffer_l1_address; uint32_t eth_semaphore_l1_address; }; @@ -224,8 +232,16 @@ class EriscDatamoverBuilder { log_trace(tt::LogOp, "\tbuffer_address: {}", local_buffer_addresses.at(channel)); log_trace(tt::LogOp, "\tsemaphore_address: {}", local_semaphore_addresses.at(channel)); - return ChannelBufferInterface{local_buffer_addresses.at(channel), local_semaphore_addresses.at(channel)}; + return ChannelBufferInterface{channel, local_buffer_addresses.at(channel), local_semaphore_addresses.at(channel)}; + } + + // This function is used to set the maximum message size for a given channel. If the maximum + // message size is < EDM channel buffer size, then the buffer size passed to the EDM for this channel + // will be trimmed be no larger than the largest message to save on unnecessary eth bandwidth. + void set_max_message_size_bytes(std::size_t channel, std::size_t max_message_size_bytes) { + active_channels.at(channel).largest_message_size_bytes = std::max(active_channels.at(channel).largest_message_size_bytes, max_message_size_bytes); } + [[nodiscard]] ChannelBufferInterface add_receiver_channel( uint32_t worker_semaphore_address, @@ -241,7 +257,7 @@ class EriscDatamoverBuilder { log_trace(tt::LogOp, "\tnum_eth_messages_to_forward: {}", active_channels.back().num_eth_messages_to_forward); log_trace(tt::LogOp, "\tchannel: {}", active_channels.back().channel); log_trace(tt::LogOp, "\tis_sender: {}", active_channels.back().is_sender ? 
1 : 0); - return ChannelBufferInterface{local_buffer_addresses.at(channel), local_semaphore_addresses.at(channel)}; + return ChannelBufferInterface{channel, local_buffer_addresses.at(channel), local_semaphore_addresses.at(channel)}; } [[nodiscard]] diff --git a/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp index 73fde5957023..4982a1e2110d 100644 --- a/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp +++ b/tt_eager/tt_dnn/op_library/ccl/reduce_scatter/host/reduce_scatter_full_worker_grid.cpp @@ -472,7 +472,7 @@ static std::tuple build_reduce_scatter_worker( vector compute_kernel_args = {}; constexpr bool fp32_dest_acc_en = false; constexpr bool math_approx_mode = false; - std::map eltwise_defines = eltwise_binary_op_utils::get_defines(binary_math_op, std::nullopt); + std::map eltwise_defines = eltwise_binary_op_utils::get_defines(binary_math_op); KernelHandle worker_reduce_kernel_id = tt_metal::CreateKernel( program, "tt_eager/tt_dnn/op_library/eltwise_binary/kernels/compute/eltwise_binary.cpp", diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp index 7db4638049f9..97bd34762386 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.cpp @@ -1228,48 +1228,84 @@ Tensor _where( const Tensor& predicate, const Tensor& value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config) { + const MemoryConfig& output_mem_config, + std::optional output_tensor) { + Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config); - Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + if(output_tensor.has_value()) + { + mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } + else + { + Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } Tensor _where_v1( - const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config) { + const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config); - Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + + if(output_tensor.has_value()){ + mul(lez(predicate, output_mem_config), value_false, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + add(t2, output_tensor.value(), std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } + else + { + Tensor t1 = mul(lez(predicate, output_mem_config), value_false, std::nullopt, output_mem_config); + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } Tensor _where_v2( - 
const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config) { - Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config); + const Tensor& predicate, const Tensor& value_true, float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + + if(output_tensor.has_value()){ + mul(gtz(predicate, output_mem_config), value_true, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + add(output_tensor.value(), t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } + else + { + Tensor t2 = mul(gtz(predicate, output_mem_config), value_true, std::nullopt, output_mem_config); + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } Tensor _where_v3( - const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config) { + const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { Tensor t2 = mul_unary(gtz(predicate, output_mem_config), value_true, output_mem_config); Tensor t1 = mul_unary(lez(predicate, output_mem_config), value_false, output_mem_config); - return add(t2, t1, std::nullopt, output_mem_config); + if(output_tensor.has_value()){ + add(t2, t1, std::nullopt, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, std::nullopt, output_tensor.value()); + } else { + output_tensor = add(t2, t1, std::nullopt, output_mem_config); + } + return output_tensor.value(); } - Tensor where( const Tensor& predicate, const Tensor& value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where)(predicate, value_true, value_false, output_mem_config); + const MemoryConfig& output_mem_config, + std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where)(predicate, value_true, value_false, output_mem_config, output_tensor); } Tensor where( - const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where_v1)(predicate, value_true, value_false, output_mem_config); + const Tensor& predicate, const float value_true, const Tensor& value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where_v1)(predicate, value_true, value_false, output_mem_config, output_tensor); } Tensor where( - const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where_v2)(predicate, value_true, value_false, output_mem_config); + const Tensor& predicate, const Tensor& value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where_v2)(predicate, value_true, value_false, output_mem_config, output_tensor); } Tensor where( - const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config) { - return operation::decorate_as_composite(__func__, _where_v3)(predicate, value_true, value_false, 
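// All four where() overloads now thread an optional preallocated output
// through to the _where* helpers. When output_tensor is provided, the
// helpers reuse it as the out-tensor of the intermediate mul and the final
// add, writing the result in place; when it is absent, they follow the
// original allocate-and-return path, so existing callers are unaffected.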
output_mem_config); + const Tensor& predicate, const float value_true, const float value_false, const MemoryConfig& output_mem_config, std::optional output_tensor) { + return operation::decorate_as_composite(__func__, _where_v3)(predicate, value_true, value_false, output_mem_config, output_tensor); } // on-device tensor creation 0s like @reference_tensor diff --git a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp index 0d79d22a44ea..45edd04a6aca 100644 --- a/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp +++ b/tt_eager/tt_dnn/op_library/composite/composite_ops.hpp @@ -316,22 +316,26 @@ Tensor where( const Tensor& predicate, const Tensor& value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); Tensor where( const Tensor& predicate, const float value_true, const Tensor& value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); Tensor where( const Tensor& predicate, const Tensor& value_true, const float value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); Tensor where( const Tensor& predicate, const float value_true, const float value_false, - const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG); + const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, + std::optional output_tensor = std::nullopt); // on-device tensor creation 0s like @reference_tensor Tensor zeros_like( diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp index ea091ce92695..bdd11b215dae 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.cpp @@ -13,11 +13,13 @@ using namespace tt::constants; +namespace tt { +namespace tt_metal { namespace eltwise_binary_op_utils { using namespace tt::tt_metal; std::map get_defines( - BinaryOpType op_type, const std::optional> fused_activations) { + BinaryOpType op_type, const std::optional output_dtype, const std::optional> fused_activations) { std::map defines; string op_name = "sub_tiles"; string op_binary_type = "EltwiseBinaryType::ELWSUB"; @@ -104,6 +106,15 @@ std::map get_defines( default: TT_ASSERT(false && "Undefined op type"); } + if(output_dtype.has_value() && output_dtype.value() == DataType::UINT32){ + TT_ASSERT(defines.count("SFPU_OP_CHAIN_0") == 0 && "SFPU_OP_CHAIN_0 already defined"); + + auto dataformat = std::to_string((uint32_t)datatype_to_dataformat_converter(output_dtype.value())); + defines.insert({"SFPU_OP_CHAIN_0", + fmt::format("typecast_tile_init(); typecast_tile<{0}u>(i);", dataformat)}); + defines.insert({"SFPU_OP_TYPECAST_INCLUDE", "1"}); + } + defines["ELTWISE_OP"] = op_name.c_str(); defines["ELTWISE_OP_TYPE"] = op_binary_type.c_str(); if (fused_activations.has_value()) { @@ -120,11 +131,6 @@ std::map get_defines( } // namespace eltwise_binary_op_utils -namespace tt { - -namespace tt_metal { - - void 
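// Alongside the move into the tt::tt_metal namespace, get_defines() now
// takes the optional output dtype: for UINT32 outputs it injects an SFPU
// typecast into the kernel's op chain (SFPU_OP_CHAIN_0 expands to
// "typecast_tile_init(); typecast_tile<N>(i);" with N the dataformat id from
// datatype_to_dataformat_converter), so the compute result is converted to
// the requested integer format before packing.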
EltwiseBinary::validate_with_output_tensors(const std::vector& input_tensors, const std::vector>& output_tensors) const { const auto& input_tensor_a = input_tensors.at(0); const auto& input_tensor_b = input_tensors.at(1); diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp index a774520904f8..d69e84c3265c 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/eltwise_binary_op.hpp @@ -12,6 +12,7 @@ #include "tt_dnn/op_library/repeat/repeat_op.hpp" #include "tt_dnn/op_library/run_operation.hpp" #include "tt_metal/host_api.hpp" +#include "tt_metal/common/logger.hpp" namespace tt { @@ -38,6 +39,14 @@ enum class BinaryOpType { DIV_FAST }; +namespace eltwise_binary_op_utils { + +std::map get_defines(BinaryOpType op_type, const std::optional out_dtype = std::nullopt, + const std::optional> fused_activations = std::nullopt); + +} // namespace eltwise_binary_op_utils + + enum class BinaryOpParallelizationStrategy { MULTI_CORE }; operation::ProgramWithCallbacks eltwise_binary_multi_core( @@ -132,14 +141,16 @@ struct make_eltwise_binary { (in_a.get_legacy_shape() == in_b.get_legacy_shape()) or (in_a.get_legacy_shape().without_padding() == in_b.get_legacy_shape().without_padding()), "Input shapes must be the same!"); - return operation::run( + + auto output_tensors = operation::run( EltwiseBinary{ binary_op_type, fused_activations, output_mem_config, output_dtype.value_or(in_a.get_dtype()), - false}, + false /*in place*/}, {in_a, in_b}, {}, {output_tensor}); + return output_tensors; }, {input_tensor_a, input_tensor_b}, output_tensors, {}, {output_tensor}); return output_tensors.at(0); @@ -231,11 +242,3 @@ inline Tensor add( } // namespace operations } // namespace tt - -namespace eltwise_binary_op_utils { -using namespace tt::tt_metal; - -std::map get_defines( - BinaryOpType op_typee, const std::optional> fused_activations); - -} // namespace eltwise_binary_op_utils diff --git a/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp b/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp index 37a772afb208..f9bf11ef33e3 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_binary/multi_core/eltwise_binary_op_multi_core.cpp @@ -312,7 +312,7 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const Tensor &a, const } auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src1_config); - std::map eltwise_defines = eltwise_binary_op_utils::get_defines(op_type, fused_activations); + std::map eltwise_defines = eltwise_binary_op_utils::get_defines(op_type, output.get_dtype(), fused_activations); if (eltwise_defines.find("SFPU_OP_INIT_PRE_IN0_0") != eltwise_defines.end()) { tt_metal::CircularBufferConfig cb_interm_config = tt_metal::CircularBufferConfig(1 * src0_single_tile_size, {{CB::c_intermed0, src0_cb_data_format}}) @@ -371,12 +371,12 @@ operation::ProgramWithCallbacks eltwise_binary_multi_core(const Tensor &a, const all_device_cores, tt_metal::WriterDataMovementConfig(writer_compile_time_args, writer_defines)); + bool fp32_dest_acc_en = dst_cb_data_format == tt::DataFormat::UInt32 || dst_cb_data_format == tt::DataFormat::Int32 || dst_cb_data_format == tt::DataFormat::Float32; auto eltwise_binary_kernel_id = tt_metal::CreateKernel( program, 
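// fp32_dest_acc_en is derived from the destination CB data format computed
// just above: UInt32, Int32, and Float32 outputs enable fp32 dest
// accumulation, presumably because 32-bit results need the full-width dest
// registers, while 16-bit formats keep the default configuration.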
"tt_eager/tt_dnn/op_library/eltwise_binary/kernels/compute/eltwise_binary.cpp", all_device_cores, - tt_metal::ComputeConfig{.defines = eltwise_defines} - ); + tt_metal::ComputeConfig{.fp32_dest_acc_en=fp32_dest_acc_en, .defines = eltwise_defines}); set_eltwise_binary_runtime_args( diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp index d958fc0c1f0b..73b14e2b1127 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp @@ -177,6 +177,11 @@ std::pair get_op_init_and_func_parameterized( Converter::to_hex(param1))}; break; } + case UnaryOpType::TYPECAST: + op_init_and_name = { + "typecast_tile_init();", + fmt::format("typecast_tile<{1}u>({0});", idst, std::to_string((uint32_t)datatype_to_dataformat_converter((DataType)param0)))}; + break; default: TT_ASSERT(false && "unexpected parameterized type"); }; return op_init_and_name; @@ -258,9 +263,6 @@ std::pair get_op_init_and_func_default(UnaryOpType op_type, stri case UnaryOpType::NEG: op_init_and_name = {"negative_tile_init();", fmt::format("negative_tile({});", idst)}; break; - case UnaryOpType::TYPECAST: - op_init_and_name = {"typecast_tile_init();", fmt::format("typecast_tile({});", idst)}; - break; default: TT_ASSERT(false && "Undefined non-parametrized op type"); } return op_init_and_name; diff --git a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp index f9f8a2521c07..6dece1630529 100644 --- a/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp +++ b/tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.hpp @@ -104,7 +104,8 @@ bool is_parametrized_type(T val) { case UnaryOpType::DIV_UNARY_SFPU: case UnaryOpType::UNARY_NE: case UnaryOpType::UNARY_GT: - case UnaryOpType::UNARY_LT: return true; + case UnaryOpType::UNARY_LT: + case UnaryOpType::TYPECAST: return true; default: return false; } return false; @@ -195,7 +196,7 @@ inline Tensor run_eltwise_unary( const std::vector& ops_chain, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified"); - DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? DataType::UINT32 : input_tensor.get_dtype(); + DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast(ops_chain[0].params[0]) : input_tensor.get_dtype(); bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_tensor.get_dtype() == DataType::UINT32 or @@ -241,7 +242,7 @@ inline Tensor run_eltwise_unary( const std::vector& ops_chain, const MemoryConfig& output_mem_config = operation::DEFAULT_OUTPUT_MEMORY_CONFIG) { TT_FATAL(ops_chain.size() > 0, "At least 1 unary op must be specified"); - DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? DataType::UINT32 : input_tensor.get_dtype(); + DataType output_dtype = (ops_chain[0].op_type == UnaryOpType::TYPECAST) ? 
static_cast(ops_chain[0].params[0]) : input_tensor.get_dtype(); bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or input_tensor.get_dtype() == DataType::UINT32 or @@ -369,7 +370,7 @@ constexpr auto rsub = make_eltwise_unary_with_param{}; constexpr auto silu = make_eltwise_unary{}; constexpr auto identity = make_eltwise_unary{}; constexpr auto identity_uint32 = make_eltwise_unary{}; -constexpr auto eltwise_typecast = make_eltwise_unary{}; +constexpr auto eltwise_typecast = make_eltwise_unary_with_param{}; constexpr auto add_unary_sfpu = make_eltwise_symmetric_binop_unary_with_param{}; constexpr auto mul_unary_sfpu = make_eltwise_symmetric_binop_unary_with_param{}; constexpr auto unary_gt = make_eltwise_unary_with_param{}; diff --git a/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp b/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp index 3f0d618224c7..8823e9aa5c68 100644 --- a/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw.cpp @@ -10,8 +10,8 @@ #include "tt_dnn/op_library/run_operation.hpp" #include "tt_eager/tensor/tensor.hpp" #include "tt_eager/tensor/tensor_impl.hpp" -#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp" #include "tt_eager/tt_dnn/op_library/moreh_adamw/moreh_adamw_op.hpp" +#include "tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp" #include "tt_eager/tt_dnn/op_library/work_split.hpp" #include "tt_metal/common/math.hpp" #include "tt_metal/detail/util.hpp" @@ -26,9 +26,14 @@ operation::ProgramWithCallbacks moreh_adamw_( const Tensor& grad, const Tensor& exp_avg, const Tensor& exp_avg_sq, - float lr, float beta1, float beta2, float eps, float weight_decay, uint32_t step, bool amsgrad, + float lr, + float beta1, + float beta2, + float eps, + float weight_decay, + uint32_t step, + bool amsgrad, const std::optional> max_exp_avg_sq) { - uint32_t num_tiles = param.volume() / TILE_HW; Program program{}; @@ -36,14 +41,15 @@ operation::ProgramWithCallbacks moreh_adamw_( //////////////////////////////////////////////////////////////////////////// // Device Setup //////////////////////////////////////////////////////////////////////////// - tt_metal::Device *device = param.device(); + tt_metal::Device* device = param.device(); auto grid = device->compute_with_storage_grid_size(); const auto num_cores_y = grid.y; // auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); // uint32_t num_cores_x = compute_with_storage_grid_size.x; // uint32_t num_cores_y = compute_with_storage_grid_size.y; - auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = tt_metal::split_work_to_cores(grid, num_tiles); + auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] = + tt_metal::split_work_to_cores(grid, num_tiles); //////////////////////////////////////////////////////////////////////////// // CircularBuffer Setup @@ -54,27 +60,27 @@ operation::ProgramWithCallbacks moreh_adamw_( all_cores, data_format, { - {CB::c_in0, 1}, // param - {CB::c_in1, 1}, // grad - {CB::c_in2, 1}, // exp_avg - {CB::c_in3, 1}, // exp_avg_sq - {CB::c_in4, 1}, // max_exp_avg_sq (optional) - {CB::c_in5, 5}, // lr, beta1, beta2, eps, weight_decay - {CB::c_in6, 1}, // 1.0f - - {CB::c_intermed0, 1}, // tmp_grad - {CB::c_intermed1, 1}, // tmp_exp_avg - {CB::c_intermed2, 1}, // tmp_exp_avg_sq - {CB::c_intermed3, 1}, // tmp_max_exp_avg_sq - {CB::c_intermed4, 1}, // - {CB::c_intermed5, 1}, // - 
{CB::c_intermed6, 1}, // tmp1 - {CB::c_intermed7, 1}, // tmp2 - - {CB::c_out0, 1}, // param - {CB::c_out1, 1}, // exp_avg - {CB::c_out2, 1}, // exp_avg_sq - {CB::c_out3, 1}, // max_exp_avg_sq (optional) + {CB::c_in0, 1}, // param + {CB::c_in1, 1}, // grad + {CB::c_in2, 1}, // exp_avg + {CB::c_in3, 1}, // exp_avg_sq + {CB::c_in4, 1}, // max_exp_avg_sq (optional) + {CB::c_in5, 5}, // lr, beta1, beta2, eps, weight_decay + {CB::c_in6, 1}, // 1.0f + + {CB::c_intermed0, 1}, // tmp_grad + {CB::c_intermed1, 1}, // tmp_exp_avg + {CB::c_intermed2, 1}, // tmp_exp_avg_sq + {CB::c_intermed3, 1}, // tmp_max_exp_avg_sq + {CB::c_intermed4, 1}, // + {CB::c_intermed5, 1}, // + {CB::c_intermed6, 1}, // tmp1 + {CB::c_intermed7, 1}, // tmp2 + + {CB::c_out0, 1}, // param + {CB::c_out1, 1}, // exp_avg + {CB::c_out2, 1}, // exp_avg_sq + {CB::c_out3, 1}, // max_exp_avg_sq (optional) }); //////////////////////////////////////////////////////////////////////////// @@ -117,19 +123,20 @@ operation::ProgramWithCallbacks moreh_adamw_( compute_defines["AMSGRAD"] = "1"; } - const std::vector compute_args_group_1{ - num_tiles_per_core_group_1}; + const std::vector compute_args_group_1{num_tiles_per_core_group_1}; const auto compute_kernel_file = "tt_eager/tt_dnn/op_library/moreh_adamw/kernels/" "moreh_adamw.cpp"; auto compute_kernel_1_id = CreateComputeKernel( - program, compute_kernel_file, {core_group_1, num_tiles_per_core_group_1, compute_args_group_1}, compute_defines); + program, + compute_kernel_file, + {core_group_1, num_tiles_per_core_group_1, compute_args_group_1}, + compute_defines); KernelHandle compute_kernel_2_id = -1; if (!core_group_2.ranges().empty()) { - const std::vector compute_args_group_2{ - num_tiles_per_core_group_2}; + const std::vector compute_args_group_2{num_tiles_per_core_group_2}; compute_kernel_2_id = CreateComputeKernel( program, @@ -170,14 +177,24 @@ operation::ProgramWithCallbacks moreh_adamw_( } const std::vector reader_runtime_args{ - param_addr, grad_addr, exp_avg_addr, exp_avg_sq_addr, max_exp_avg_sq_addr, - f2u_lr.u, f2u_beta1.u, f2u_beta2.u, f2u_eps.u, f2u_weight_decay.u, step, static_cast(amsgrad), - num_tiles_per_core, tile_offset}; + param_addr, + grad_addr, + exp_avg_addr, + exp_avg_sq_addr, + max_exp_avg_sq_addr, + f2u_lr.u, + f2u_beta1.u, + f2u_beta2.u, + f2u_eps.u, + f2u_weight_decay.u, + step, + static_cast(amsgrad), + num_tiles_per_core, + tile_offset}; tt_metal::SetRuntimeArgs(program, reader_kernel_id, core, reader_runtime_args); const std::vector writer_runtime_args{ - param_addr, exp_avg_addr, exp_avg_sq_addr, max_exp_avg_sq_addr, - num_tiles_per_core, tile_offset}; + param_addr, exp_avg_addr, exp_avg_sq_addr, max_exp_avg_sq_addr, num_tiles_per_core, tile_offset}; tt_metal::SetRuntimeArgs(program, writer_kernel_id, core, writer_runtime_args); if (core_group_1.core_coord_in_core_ranges(core)) { @@ -191,50 +208,9 @@ operation::ProgramWithCallbacks moreh_adamw_( tile_offset += num_tiles_per_core; } - //////////////////////////////////////////////////////////////////////////// - // Callback SetUp - //////////////////////////////////////////////////////////////////////////// - auto override_runtime_args_callback = [reader_kernel_id = reader_kernel_id, - writer_kernel_id = writer_kernel_id, - num_cores = num_cores, - num_cores_y = num_cores_y]( - const Program& program, - const std::vector& input_buffers, - const std::vector& output_buffers) { - auto param_buffer = input_buffers.at(0); - auto grad_buffer = input_buffers.at(1); - auto exp_avg_buffer = input_buffers.at(2); - auto 
exp_avg_sq_buffer = input_buffers.at(3); - auto max_exp_avg_sq_buffer = input_buffers.at(4); - - for (uint32_t i = 0; i < num_cores; ++i) { - CoreCoord core = {i / num_cores_y, i % num_cores_y}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = param_buffer->address(); - runtime_args[1] = grad_buffer->address(); - runtime_args[2] = exp_avg_buffer->address(); - runtime_args[3] = exp_avg_sq_buffer->address(); - if (max_exp_avg_sq_buffer != nullptr) { - runtime_args[4] = max_exp_avg_sq_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = param_buffer->address(); - runtime_args[1] = grad_buffer->address(); - runtime_args[2] = exp_avg_buffer->address(); - runtime_args[3] = exp_avg_sq_buffer->address(); - if (max_exp_avg_sq_buffer != nullptr) { - runtime_args[4] = max_exp_avg_sq_buffer->address(); - } - } - } - }; - - return { + std::move(program), + create_override_addresses_callback(reader_kernel_id, writer_kernel_id, num_cores, num_cores_y)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp index 4aedc2407ce9..aa203229ec66 100644 --- a/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_getitem/moreh_getitem_op.cpp @@ -211,7 +211,9 @@ Tensor moreh_getitem( optional_output_tensors); }, new_input_tensors, - output_tensors); + output_tensors, + {}, + {output_tensor}); return output_tensors.at(0); } diff --git a/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp b/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp index bcbe4abf5d5c..cc5f31427f9a 100644 --- a/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp +++ b/tt_eager/tt_dnn/op_library/moreh_helper_functions.hpp @@ -126,6 +126,138 @@ struct CircularBufferArg { tt::DataFormat data_format, CircularBufferArg arg); + +struct CallbackArgMap { + std::map<std::size_t, std::size_t> input; + std::map<std::size_t, std::size_t> optional_input; + std::map<std::size_t, std::size_t> output; +}; + +using Tensors = std::vector<Tensor>; +using OptionalConstTensors = std::vector<std::optional<const Tensor>>; + +// To use this function, the arguments in the reader kernel must always be sorted in the order of input followed by +// optional_input. Furthermore, input and output tensors must always start from the 0th argument. +template <typename OutputTensors = Tensors> +const std::function<void(const void*, Program&, const Tensors&, const OptionalConstTensors&, const OutputTensors&)> +create_override_runtime_arguments_callback( + KernelHandle reader_kernel_id, KernelHandle writer_kernel_id, uint32_t num_cores, uint32_t core_h) { + return [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( + const void *operation, + Program &program, + const Tensors &input_tensors, + const OptionalConstTensors &optional_input_tensors, + const OutputTensors &output_tensors) -> void { + for (uint32_t icore = 0; icore < num_cores; icore++) { + CoreCoord core = {icore / core_h, icore % core_h}; + + // readers + { + uint32_t rt_idx = 0; + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + for (uint32_t idx = 0; idx < input_tensors.size(); idx++) { + runtime_args[rt_idx++] = input_tensors.at(idx).buffer()->address(); + } + for (uint32_t idx = 0; idx < optional_input_tensors.size(); idx++) { + auto optional_input_tensor = optional_input_tensors.at(idx); + runtime_args[rt_idx++] = + optional_input_tensor.has_value() ?
optional_input_tensor.value().buffer()->address() : 0; + } + } + + // writer + { + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + for (uint32_t idx = 0; idx < output_tensors.size(); idx++) { + runtime_args[idx] = output_tensors.at(idx).buffer()->address(); + } + } + } + }; +} + +// Using this overload is not recommended: setting the callback argument map directly doesn't significantly +// reduce the amount of code. +template <typename OutputTensors = Tensors> +const std::function<void(const void*, Program&, const Tensors&, const OptionalConstTensors&, const OutputTensors&)> +create_override_runtime_arguments_callback( + KernelHandle reader_kernel_id, + KernelHandle writer_kernel_id, + uint32_t num_cores, + uint32_t core_h, + CallbackArgMap arg_map) { + return [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, arg_map, num_cores, core_h]( + const void *operation, + Program &program, + const Tensors &input_tensors, + const OptionalConstTensors &optional_input_tensors, + const OutputTensors &output_tensors) -> void { + for (uint32_t icore = 0; icore < num_cores; icore++) { + CoreCoord core = {icore / core_h, icore % core_h}; + + // readers + { + auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + for (const auto &pair : arg_map.input) { + runtime_args[pair.first] = input_tensors.at(pair.second).buffer()->address(); + } + for (const auto &pair : arg_map.optional_input) { + auto optional_input_tensor = optional_input_tensors.at(pair.second); + runtime_args[pair.first] = + optional_input_tensor.has_value() ? optional_input_tensor.value().buffer()->address() : 0; + } + } + + // writer + { + auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + for (const auto &pair : arg_map.output) { + runtime_args[pair.first] = output_tensors.at(pair.second).buffer()->address(); + } + } + } + }; +} + +// To use this function, the arguments in the reader kernel must always be sorted in the order of input followed by +// optional_input. Furthermore, input and output tensors must always start from the 0th argument.
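Before the buffer-address variant that follows, a minimal usage sketch of the map-based overload above; the kernel layout and runtime-arg indices here are hypothetical, not taken from this change:

```cpp
// Hypothetical reader layout: input address at arg 0, optional weight address at arg 2;
// writer layout: output address at arg 0. Each map entry is runtime-arg slot -> tensor index.
CallbackArgMap arg_map{
    .input = {{0, 0}},
    .optional_input = {{2, 0}},  // slot is written as 0 when the optional input is absent
    .output = {{0, 0}},
};
return {
    .program = std::move(program),
    .override_runtime_arguments_callback = create_override_runtime_arguments_callback(
        reader_kernel_id, writer_kernel_id, num_cores, core_h, arg_map)};
```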
+template +const std::function&, const std::vector&)> +create_override_addresses_callback( + KernelHandle reader_kernel_id, KernelHandle writer_kernel_id, uint32_t num_cores, uint32_t core_h) { + return [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( + const Program& program, + const std::vector& input_buffers, + const std::vector& output_buffers) -> void { + for (uint32_t icore = 0; icore < num_cores; icore++) { + CoreCoord core = {icore / core_h, icore % core_h}; + + // readers + { + auto& runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); + for (uint32_t idx = 0; idx < input_buffers.size(); idx++) { + auto buffer = input_buffers.at(idx); + if (buffer != nullptr) { + runtime_args[idx] = buffer->address(); + } + } + } + + // writer + { + auto& runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); + for (uint32_t idx = 0; idx < output_buffers.size(); idx++) { + auto buffer = output_buffers.at(idx); + if (buffer != nullptr) { + runtime_args[idx] = buffer->address(); + } + } + } + } + }; +} + + } // namespace primary } // namespace operations } // namespace tt diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp index e2a5b0752c39..f2646aa0b862 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step1/moreh_nll_loss_step1.cpp @@ -146,37 +146,10 @@ operation::ProgramWithCallbacks moreh_nll_loss_step1_impl( tile_offset += num_units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto target_dram_buffer = input_buffers.at(0); - auto weight_dram_buffer = input_buffers.at(1); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = target_dram_buffer->address(); - if (weight_dram_buffer != nullptr) { - runtime_args[1] = weight_dram_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp index fc2957d69c01..085416713c01 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss/moreh_nll_loss_step2/moreh_nll_loss_step2.cpp @@ -186,43 +186,10 @@ operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_2d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = 
writer_kernel_id, num_cores, core_h]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(input_buffers.size() == 4); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto target_dram_buffer = input_buffers.at(1); - auto weight_dram_buffer = input_buffers.at(2); - auto divisor_dram_buffer = input_buffers.at(3); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - runtime_args[1] = target_dram_buffer->address(); - if (weight_dram_buffer != nullptr) { - runtime_args[2] = weight_dram_buffer->address(); - } - if (divisor_dram_buffer != nullptr) { - runtime_args[3] = divisor_dram_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_3d( @@ -397,43 +364,10 @@ operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_3d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(input_buffers.size() == 4); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto target_dram_buffer = input_buffers.at(1); - auto weight_dram_buffer = input_buffers.at(2); - auto divisor_dram_buffer = input_buffers.at(3); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - runtime_args[1] = target_dram_buffer->address(); - if (weight_dram_buffer != nullptr) { - runtime_args[2] = weight_dram_buffer->address(); - } - if (divisor_dram_buffer != nullptr) { - runtime_args[3] = divisor_dram_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_4d( @@ -616,43 +550,10 @@ operation::ProgramWithCallbacks moreh_nll_loss_step2_impl_4d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const Program &program, - const std::vector &input_buffers, - const std::vector &output_buffers) { - TT_ASSERT(input_buffers.size() == 4); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto target_dram_buffer = input_buffers.at(1); - 
auto weight_dram_buffer = input_buffers.at(2); - auto divisor_dram_buffer = input_buffers.at(3); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - runtime_args[1] = target_dram_buffer->address(); - if (weight_dram_buffer != nullptr) { - runtime_args[2] = weight_dram_buffer->address(); - } - if (divisor_dram_buffer != nullptr) { - runtime_args[3] = divisor_dram_buffer->address(); - } - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } operation::ProgramWithCallbacks moreh_nll_loss_step2_impl( diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_2d.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_2d.cpp index 0fa899feb4a4..f83def0d88a5 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_2d.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_2d.cpp @@ -3,13 +3,14 @@ // SPDX-License-Identifier: Apache-2.0 #include "tt_eager/tt_dnn/kernels/dataflow/moreh_common.hpp" +#include "dprint.h" void kernel_main() { uint32_t i = 0; auto target_addr = get_arg_val(i++); + auto output_grad_addr = get_arg_val(i++); auto weight_addr = get_arg_val(i++); auto divisor_addr = get_arg_val(i++); - auto output_grad_addr = get_arg_val(i++); auto ignore_index = static_cast(get_arg_val(i++)); auto num_tiles_per_core = get_arg_val(i++); auto start_id = get_arg_val(i++); diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_3d.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_3d.cpp index 6c8697bc352d..e48c188d9c00 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_3d.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_3d.cpp @@ -7,9 +7,9 @@ void kernel_main() { uint32_t i = 0; auto target_addr = get_arg_val(i++); + auto output_grad_addr = get_arg_val(i++); auto weight_addr = get_arg_val(i++); auto divisor_addr = get_arg_val(i++); - auto output_grad_addr = get_arg_val(i++); auto ignore_index = static_cast(get_arg_val(i++)); auto num_tiles_per_core = get_arg_val(i++); auto start_id = get_arg_val(i++); diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_4d.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_4d.cpp index 073298d147ab..3ca374cf0e8d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_4d.cpp +++ 
b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/kernels/reader_moreh_nll_loss_backward_4d.cpp @@ -8,9 +8,9 @@ void kernel_main() { uint32_t i = 0; auto target_addr = get_arg_val(i++); + auto output_grad_addr = get_arg_val(i++); auto weight_addr = get_arg_val(i++); auto divisor_addr = get_arg_val(i++); - auto output_grad_addr = get_arg_val(i++); auto ignore_index = static_cast(get_arg_val(i++)); auto num_tiles_per_core = get_arg_val(i++); auto start_id = get_arg_val(i++); diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp index 78570bb4f0d9..fa389814e759 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward/moreh_nll_loss_backward.cpp @@ -156,9 +156,9 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_4d( std::vector reader_args = { target_addr, + output_grad_addr, weight_addr, divisor_addr, - output_grad_addr, static_cast(ignore_index), units_per_core, tile_offset, @@ -187,47 +187,12 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_4d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const void *operation, - Program &program, - const std::vector &input_tensors, - const std::vector> &optional_input_tensors, - const std::vector &output_tensors) { - TT_ASSERT(input_tensors.size() == 2); - TT_ASSERT(optional_input_tensors.size() == 2); - TT_ASSERT(output_tensors.size() == 1); - - auto target_addr = input_tensors.at(0).buffer()->address(); - auto output_grad_addr = input_tensors.at(1).buffer()->address(); - auto weight_addr = - optional_input_tensors.at(0).has_value() ? optional_input_tensors.at(0).value().buffer()->address() : 0; - auto divisor_addr = - optional_input_tensors.at(1).has_value() ? 
optional_input_tensors.at(1).value().buffer()->address() : 0; - auto input_grad_addr = output_tensors.at(0).buffer()->address(); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = target_addr; - runtime_args[1] = weight_addr; - runtime_args[2] = divisor_addr; - runtime_args[3] = output_grad_addr; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_addr; - } - } - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } - operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_3d( const Tensor &target, const std::optional weight, @@ -238,7 +203,6 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_3d( const bool reduction_mean, const CoreRange core_range, const DeviceComputeKernelConfig compute_kernel_config) { - // split work // input_grad: (N, C, W) @@ -370,9 +334,9 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_3d( std::vector reader_args = { target_addr, + output_grad_addr, weight_addr, divisor_addr, - output_grad_addr, static_cast(ignore_index), units_per_core, tile_offset, @@ -401,47 +365,12 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_3d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const void *operation, - Program &program, - const std::vector &input_tensors, - const std::vector> &optional_input_tensors, - const std::vector &output_tensors) { - TT_ASSERT(input_tensors.size() == 2); - TT_ASSERT(optional_input_tensors.size() == 2); - TT_ASSERT(output_tensors.size() == 1); - - auto target_addr = input_tensors.at(0).buffer()->address(); - auto output_grad_addr = input_tensors.at(1).buffer()->address(); - auto weight_addr = - optional_input_tensors.at(0).has_value() ? optional_input_tensors.at(0).value().buffer()->address() : 0; - auto divisor_addr = - optional_input_tensors.at(1).has_value() ? 
optional_input_tensors.at(1).value().buffer()->address() : 0; - auto input_grad_addr = output_tensors.at(0).buffer()->address(); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = target_addr; - runtime_args[1] = weight_addr; - runtime_args[2] = divisor_addr; - runtime_args[3] = output_grad_addr; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_addr; - } - } - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } - operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_2d( const Tensor &target, const std::optional weight, @@ -579,9 +508,9 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_2d( std::vector reader_args = { target_addr, + output_grad_addr, weight_addr, divisor_addr, - output_grad_addr, static_cast(ignore_index), units_per_core, tile_offset, @@ -609,48 +538,12 @@ operation::ProgramWithCallbacks moreh_nll_loss_backward_impl_2d( tile_offset += units_per_core; } - auto override_runtime_args_callback = - [reader_kernel_id = reader_kernel_id, writer_kernel_id = writer_kernel_id, num_cores, core_h]( - const void *operation, - Program &program, - const std::vector &input_tensors, - const std::vector> &optional_input_tensors, - const std::vector &output_tensors) { - TT_ASSERT(input_tensors.size() == 2); - TT_ASSERT(optional_input_tensors.size() == 2); - TT_ASSERT(output_tensors.size() == 1); - - auto target_addr = input_tensors.at(0).buffer()->address(); - auto output_grad_addr = input_tensors.at(1).buffer()->address(); - auto weight_addr = - optional_input_tensors.at(0).has_value() ? optional_input_tensors.at(0).value().buffer()->address() : 0; - auto divisor_addr = - optional_input_tensors.at(1).has_value() ? 
optional_input_tensors.at(1).value().buffer()->address() : 0; - auto input_grad_addr = output_tensors.at(0).buffer()->address(); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = target_addr; - runtime_args[1] = weight_addr; - runtime_args[2] = divisor_addr; - runtime_args[3] = output_grad_addr; - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_addr; - } - } - }; - - return {.program = std::move(program), .override_runtime_arguments_callback = override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } - - } // namespace operation::ProgramWithCallbacks moreh_nll_loss_backward_impl( diff --git a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp index 9fffeb7de044..32c79199bac8 100644 --- a/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_nll_loss_backward/moreh_nll_loss_backward_op.cpp @@ -24,25 +24,19 @@ void MorehNllLossBackward::validate_with_output_tensors( const std::vector& input_tensors, const std::vector>& optional_input_tensors, const std::vector>& output_tensors) const { - TT_ASSERT(input_tensors.size() == 3, "Must have 3 input tensors"); + TT_ASSERT(input_tensors.size() == 2, "Must have 2 input tensors"); TT_ASSERT(optional_input_tensors.size() == 2, "Must have 2 optional input tensors"); - auto& input_tensor = input_tensors.at(0); - auto& target_tensor = input_tensors.at(1); - auto& output_grad_tensor = input_tensors.at(2); + auto& target_tensor = input_tensors.at(0); + auto& output_grad_tensor = input_tensors.at(1); auto& weight_tensor = optional_input_tensors.at(0); auto& divisor_tensor = optional_input_tensors.at(1); auto& input_grad_tensor = output_tensors.at(0); - TT_ASSERT(input_tensor.storage_type() == StorageType::DEVICE, "Operands to nll_loss need to be on device!"); - TT_ASSERT(input_tensor.buffer() != nullptr, "Operands to nll_loss need to be allocated in buffers on device!"); - TT_ASSERT((input_tensor.get_layout() == Layout::TILE), "intput_tensor to nll_loss must be tilized"); - TT_ASSERT(input_tensor.get_dtype() == DataType::BFLOAT16); - TT_ASSERT(target_tensor.storage_type() == StorageType::DEVICE, "Operands to nll_loss need to be on device!"); TT_ASSERT(target_tensor.buffer() != nullptr, "Operands to nll_loss need to be allocated in buffers on device!"); TT_ASSERT((target_tensor.get_layout() == Layout::TILE), "target_tensor to nll_loss must be tilized"); - TT_ASSERT(target_tensor.get_dtype() == DataType::UINT32); + TT_ASSERT(target_tensor.get_dtype() == DataType::INT32); TT_ASSERT(output_grad_tensor.storage_type() == StorageType::DEVICE, "Operands to nll_loss need to be on device!"); TT_ASSERT( diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp index 79e0b62a5e08..4ee07c10ee58 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_c_large/softmax_c_large.cpp @@ -129,39 +129,10 @@ 
operation::ProgramWithCallbacks moreh_softmax_c_large(const Tensor &input, const tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp index abcb6b194099..ea2ab9945448 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_large/softmax_h_large.cpp @@ -123,39 +123,10 @@ operation::ProgramWithCallbacks moreh_softmax_h_large(const Tensor &input, const tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp index 77523098d4c2..cbf825a1b5e3 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_h_small/softmax_h_small.cpp @@ -145,39 +145,10 @@ operation::ProgramWithCallbacks moreh_softmax_h_small(const Tensor &input, const tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& 
output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp index f1ae31c7dd69..7018590c32ac 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_large/softmax_w_large.cpp @@ -124,39 +124,10 @@ operation::ProgramWithCallbacks moreh_softmax_w_large(const Tensor &input, const tile_offset += num_tiles_per_core * Wt; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp index bf90b8d47b0e..1dcf9f818dcb 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax/softmax_w_small/softmax_w_small.cpp @@ -145,39 +145,10 @@ operation::ProgramWithCallbacks moreh_softmax_w_small(const Tensor &input, const tile_offset += num_tiles_per_core * Wt; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 1); - TT_ASSERT(output_buffers.size() == 1); - - auto src_dram_buffer = input_buffers.at(0); - auto dst_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = 
GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = src_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = dst_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp index 5752781a8934..2447581d0f4a 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_c_large/softmax_backward_c_large.cpp @@ -135,41 +135,10 @@ operation::ProgramWithCallbacks moreh_softmax_backward_c_large(const Tensor &out tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp index 859867d17f00..638ed5dc7e7c 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_large/softmax_backward_h_large.cpp @@ -130,41 +130,10 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_large(const Tensor &out tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / 
core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp index 44df27586980..b17cff78ce4b 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_h_small/softmax_backward_h_small.cpp @@ -152,41 +152,10 @@ operation::ProgramWithCallbacks moreh_softmax_backward_h_small(const Tensor &out tile_offset += num_tiles_per_core; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp index 78ce4ceecfae..a46b647d51f4 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_large/softmax_backward_w_large.cpp @@ -130,41 +130,10 @@ operation::ProgramWithCallbacks moreh_softmax_backward_w_large(const Tensor &out tile_offset += num_tiles_per_core * Wt; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto 
input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp index a834f5e4acd3..8488ca725468 100644 --- a/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_softmax_backward/softmax_backward_w_small/softmax_backward_w_small.cpp @@ -153,41 +153,10 @@ operation::ProgramWithCallbacks moreh_softmax_backward_w_small(const Tensor &out tile_offset += num_tiles_per_core * Wt; } - auto override_runtime_args_callback = [ - reader_kernel_id=reader_kernel_id, - writer_kernel_id=writer_kernel_id, - num_cores, - core_h - ] - ( - const Program &program, - const std::vector& input_buffers, - const std::vector& output_buffers - ) { - TT_ASSERT(input_buffers.size() == 2); - TT_ASSERT(output_buffers.size() == 1); - - auto output_dram_buffer = input_buffers.at(0); - auto output_grad_dram_buffer = input_buffers.at(1); - auto input_grad_dram_buffer = output_buffers.at(0); - - for (uint32_t icore = 0; icore < num_cores; icore++) { - CoreCoord core = {icore / core_h, icore % core_h}; - - { - auto &runtime_args = GetRuntimeArgs(program, reader_kernel_id, core); - runtime_args[0] = output_dram_buffer->address(); - runtime_args[1] = output_grad_dram_buffer->address(); - } - - { - auto &runtime_args = GetRuntimeArgs(program, writer_kernel_id, core); - runtime_args[0] = input_grad_dram_buffer->address(); - } - } - }; - - return {std::move(program), override_runtime_args_callback}; + return { + .program = std::move(program), + .override_runtime_arguments_callback = + create_override_runtime_arguments_callback(reader_kernel_id, writer_kernel_id, num_cores, core_h)}; } } // namespace primary diff --git a/tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp b/tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp index f8c787a970e3..aa350191a30d 100644 --- a/tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp +++ b/tt_eager/tt_dnn/op_library/moreh_sum/moreh_sum_op.cpp @@ -53,7 +53,7 @@ Tensor _moreh_sum( std::optional compute_kernel_config) { std::vector output_tensors = {Tensor(operation::get_workers_for_op_output({input}))}; - TT_FATAL(input.storage_type() == StorageType::DEVICE); + TT_FATAL(input.storage_type() == StorageType::DEVICE || input.storage_type() == StorageType::MULTI_DEVICE); auto kernel_config_val = init_device_compute_kernel_config(input.device()->arch(), compute_kernel_config, MathFidelity::HiFi4); operation::launch_op( diff --git a/tt_eager/tt_dnn/op_library/operation.hpp b/tt_eager/tt_dnn/op_library/operation.hpp index 6ef4b8fc33dc..26285d0b5e80 100644 --- 
a/tt_eager/tt_dnn/op_library/operation.hpp +++ b/tt_eager/tt_dnn/op_library/operation.hpp @@ -528,9 +528,6 @@ struct DeviceOperation final { const Tensors& input_tensors, const OptionalConstTensors& optional_input_tensors, const OptionalTensors& optional_output_tensors) -> void { - if (ttnn::CONFIG.enable_fast_runtime_mode) { - return; - } const auto& operation = *reinterpret_cast*>(&storage); if constexpr ( (detail::implements_validate() or diff --git a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp index 0af4c11bf4b2..a742be885e29 100644 --- a/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp +++ b/tt_eager/tt_dnn/op_library/transformer_tms/transformer_tms.cpp @@ -544,11 +544,8 @@ void SSMEltwiseMul::validate(const std::vector<Tensor>& input_tensors) const { "Unsupported data format for input a!"); TT_FATAL( input_tensor_b.get_dtype() == tt::tt_metal::DataType::BFLOAT16 || - input_tensor_a.get_dtype() == tt::tt_metal::DataType::BFLOAT8_B, + input_tensor_b.get_dtype() == tt::tt_metal::DataType::BFLOAT8_B, "Unsupported data format for input b!"); - TT_FATAL( - input_tensor_a.get_dtype() == input_tensor_b.get_dtype(), - "Input a and input b must have the same data format!"); TT_FATAL( this->output_mem_config.memory_layout == TensorMemoryLayout::INTERLEAVED, diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp index fc3710689212..2d15d354531f 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings.cpp @@ -251,6 +251,30 @@ void DeviceModule(py::module &m_device) { Release captured Trace on Device handle )doc"); + auto pyEvent = py::class_<Event, std::shared_ptr<Event>>(m_device, "Event", "Event class"); + m_device.def("CreateEvent", + [] () { + return std::make_shared<Event>(); + }, R"doc( + Create new event + )doc"); + m_device.def("RecordEvent", + [] (Device* device, const uint8_t cq_id, std::shared_ptr<Event> event) { + device->push_work([device, cq_id, event] { + EnqueueRecordEvent(device->command_queue(cq_id), event); + }); + }, R"doc( + Record an event + )doc"); + m_device.def("WaitForEvent", + [] (Device* device, const uint8_t cq_id, std::shared_ptr<Event> event) { + device->push_work([device, cq_id, event] { + EnqueueWaitForEvent(device->command_queue(cq_id), event); + }); + }, R"doc( + Wait for an event + )doc"); + m_device.attr("DEFAULT_L1_SMALL_SIZE") = py::int_(DEFAULT_L1_SMALL_SIZE); } diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp index b3750d8cdd80..5ea5a87f8ecf 100644 --- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp +++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp @@ -72,8 +72,8 @@ namespace tt::tt_metal::detail{ "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No" )doc"); - m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&>(&where), - py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc( + m_tensor.def("where", py::overload_cast<const Tensor&, const Tensor&, const Tensor&, const MemoryConfig&, std::optional<Tensor>>(&where), + py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc( Perform a ternary where operation on two tensors based on the third @predicate tensor.
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
index b3750d8cdd80..5ea5a87f8ecf 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_composite_ops.cpp
@@ -72,8 +72,8 @@ namespace tt::tt_metal::detail{
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
         )doc");

-    m_tensor.def("where", py::overload_cast(&where),
-        py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+    m_tensor.def("where", py::overload_cast >(&where),
+        py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
         Perform a ternary where operation on two tensors based on third @predicate.

         where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
@@ -89,9 +89,10 @@ namespace tt::tt_metal::detail{
                 "true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
                 "false_value", "False Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+                "output_tensor", "optional output tensor", "Tensor", "default is None", "No"
         )doc");
-    m_tensor.def("where", py::overload_cast(&where),
-        py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+    m_tensor.def("where", py::overload_cast >(&where),
+        py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
         Perform a ternary where operation on two tensors based on third @predicate.

         where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
@@ -107,9 +108,10 @@ namespace tt::tt_metal::detail{
                 "true_value", "float", "float", "float scalar", "Yes"
                 "false_value", "Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+                "output_tensor", "optional output tensor", "Tensor", "default is None", "No"
         )doc");
-    m_tensor.def("where", py::overload_cast(&where),
-        py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+    m_tensor.def("where", py::overload_cast >(&where),
+        py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
         Perform a ternary where operation on two tensors based on third @predicate.

         where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
@@ -125,9 +127,10 @@ namespace tt::tt_metal::detail{
                 "true_value", "True Tensor", "Tensor", "Tensor of shape [W, Z, Y, X]", "Yes"
                 "false_value", "float", "float", "float scalar", "Yes"
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+                "output_tensor", "optional output tensor", "Tensor", "default is None", "No"
         )doc");
-    m_tensor.def("where", py::overload_cast(&where),
-        py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, R"doc(
+    m_tensor.def("where", py::overload_cast >(&where),
+        py::arg("predicate"), py::arg("true_value"), py::arg("false_value"), py::arg("output_mem_config").noconvert() = operation::DEFAULT_OUTPUT_MEMORY_CONFIG, py::arg("output_tensor").noconvert() = std::nullopt, R"doc(
         Perform a ternary where operation on two tensors based on third @predicate.

         where(predicate, true_value, false_value) implements (predicate) ? true_value : false_value.
@@ -143,6 +146,7 @@ namespace tt::tt_metal::detail{
                 "true_value", "float", "float", "Tensor of shape [W, Z, Y, X]", "Yes"
                 "false_value", "float", "float", "float scalar", "Yes"
                 "output_mem_config", "Layout of tensor in TT Accelerator device memory banks", "MemoryConfig", "Default is interleaved in DRAM", "No"
+                "output_tensor", "optional output tensor", "Tensor", "default is None", "No"
         )doc");

     // *** composite unary ops ***
     detail::bind_unary_op(m_tensor, "normalize_hw", tt::tt_metal::normalize_hw, R"doc(Returns a new tensor with the Gaussian normalize of the elements of the input tensor ``{0}`` on H,W axes.)doc");
diff --git a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
index 6b9b80896470..c4693d83a559 100644
--- a/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
+++ b/tt_eager/tt_lib/csrc/tt_lib_bindings_tensor_xary_ops.cpp
@@ -85,7 +85,13 @@ namespace tt::tt_metal::detail {
     detail::bind_unary_op(m_tensor, "i0", i0, R"doc(Computes the zeroth order modified Bessel function of the first kind applied on the elements of the input tensor ``{0}``, for the input range -10 to 10.)doc");
     detail::bind_unary_op(m_tensor, "silu", silu, R"doc(Returns tensor with the silu all of elements of the input tensor ``{0}``.)doc");
     detail::bind_unary_op(m_tensor, "neg", neg, R"doc(Returns tensor with the negate all of elements of the input tensor ``{0}``.)doc");
-    detail::bind_unary_op(m_tensor, "eltwise_typecast", eltwise_typecast, R"doc(Returns tensor with all of the elements of the input tensor ``{0}`` typecasted from fp32 to uint32.)doc");
+
+    detail::bind_unary_op_with_param(
+        m_tensor, "eltwise_typecast", eltwise_typecast,
+        py::arg("tt_output_dtype"),
+        R"doc(Returns tensor with all of the elements of the input tensor ``{0}`` typecasted from fp32 to uint32 or uint16.)doc",
+        R"doc("Indicates output dtype of typecast", "ttl.tensor.DataType", "")doc"
+    );

     detail::bind_unary_op_with_param(
         m_tensor, "exp", py::overload_cast(&exp),
diff --git a/tt_metal/hw/ckernels/grayskull/metal/llk_io/llk_outputs.h b/tt_metal/hw/ckernels/grayskull/metal/llk_io/llk_outputs.h
index 7558f53219ae..a9c8bf6258f6 100644
--- a/tt_metal/hw/ckernels/grayskull/metal/llk_io/llk_outputs.h
+++ b/tt_metal/hw/ckernels/grayskull/metal/llk_io/llk_outputs.h
@@ -18,12 +18,12 @@ inline const uint32_t get_output_base_id()
     return (OUTPUT_BASE_ID);
 }

-inline const uint32_t get_output_src_format(const std::uint32_t output_id)
+inline const unsigned char get_output_src_format(const std::uint32_t output_id)
 {
     return pack_src_format[output_id];
 }

-inline const uint32_t get_output_dst_format(const std::uint32_t output_id)
+inline const unsigned char get_output_dst_format(const std::uint32_t output_id)
 {
     return pack_dst_format[output_id];
 }
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h
index b3fdd91a568b..0d2a43ff7ace 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_typecast.h
@@ -51,5 +51,21 @@ inline void calculate_typecast_fp16b_to_uint32()
     }
 }

+template <bool APPROXIMATION_MODE, int ITERATIONS>
+inline void calculate_typecast_fp16b_to_uint16()
+{
+    #pragma GCC unroll 0
+    for (int d = 0; d < ITERATIONS; d++) {
+        TTI_SFPENCC(0,0,0,0);
+        TTI_SFPLOAD(0,0,3,0);
+        TTI_SFPSETCC(0,0,0,0);
+        TTI_SFPLOADI(0,0,0);
+        TTI_SFPENCC(0,0,0,0);
+        TTI_SFP_STOCH_RND(0,0,2,0,1,14);
+        TTI_SFPSTORE(1,6,3,0);
+        dst_reg++;
+    }
+}
+
 }  // namespace sfpu
 }  // namespace ckernel
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h
index b4ac44225b69..8a7f9d95a531 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_typecast.h
@@ -5,19 +5,27 @@
 #pragma once

 #include "llk_math_eltwise_unary_sfpu_init.h"
-#include "llk_math_eltwise_unary_sfpu_0_param.h"
+#include "llk_math_eltwise_unary_sfpu_params.h"
 #include "ckernel_sfpu_typecast.h"

 namespace ckernel {

 // New LLK SFPU APIs

-template <bool APPROXIMATE>
+template <bool APPROXIMATE, uint32_t OUT_DTYPE>
 inline void llk_math_eltwise_unary_sfpu_typecast(uint dst_index, int vector_mode = (int)VectorMode::RC) {
-    llk_math_eltwise_unary_sfpu_0_param
-        (ckernel::sfpu::calculate_typecast_fp16b_to_uint32,
-         ckernel::sfpu::calculate_typecast_fp16b_to_uint32,
-         dst_index, vector_mode);
+    if constexpr (OUT_DTYPE == (uint32_t)DataFormat::UInt16) {
+        llk_math_eltwise_unary_sfpu_params(
+            ckernel::sfpu::calculate_typecast_fp16b_to_uint16,
+            dst_index,
+            vector_mode);
+    }
+    else if constexpr (OUT_DTYPE == (uint32_t)DataFormat::UInt32) {
+        llk_math_eltwise_unary_sfpu_params(
+            ckernel::sfpu::calculate_typecast_fp16b_to_uint32,
+            dst_index,
+            vector_mode);
+    }
 }

 template
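Editor's note: with `OUT_DTYPE` now a template parameter, the output format is selected at compile time, so the dead branch is never emitted into the kernel binary. A hypothetical call site (the kernel-facing wrapper is outside this patch, and the `APPROXIMATE` value is illustrative):

```cpp
// Sketch: instantiate the uint16 typecast path at compile time.
llk_math_eltwise_unary_sfpu_typecast<true /*APPROXIMATE*/, (uint32_t)DataFormat::UInt16>(0 /*dst_index*/);
```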
diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h
index b92af5b8ddc3..74c71eb97519 100644
--- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h
+++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_io/llk_outputs.h
@@ -18,12 +18,12 @@ inline const uint32_t get_output_base_id()
     return (OUTPUT_BASE_ID);
 }

-inline const uint32_t get_output_src_format(const std::uint32_t output_id)
+inline const unsigned char get_output_src_format(const std::uint32_t output_id)
 {
     return pack_src_format[output_id];
 }

-inline const uint32_t get_output_dst_format(const std::uint32_t output_id)
+inline const unsigned char get_output_dst_format(const std::uint32_t output_id)
 {
     return pack_dst_format[output_id];
 }
diff --git a/tt_metal/hw/inc/blackhole/noc/noc_parameters.h b/tt_metal/hw/inc/blackhole/noc/noc_parameters.h
index 8b8e9ad14150..7f6529f9915e 100644
--- a/tt_metal/hw/inc/blackhole/noc/noc_parameters.h
+++ b/tt_metal/hw/inc/blackhole/noc/noc_parameters.h
@@ -14,6 +14,9 @@
 #define NOC_XY_ENCODING(x, y) \
     ((((uint64_t)(y)) << (NOC_ADDR_LOCAL_BITS + NOC_ADDR_NODE_ID_BITS)) | (((uint64_t)(x)) << NOC_ADDR_LOCAL_BITS))

+#define NOC_XY_PCIE_ENCODING(x, y, noc_index) \
+    NOC_XY_ENCODING(x, y)
+
 #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \
     ((((uint64_t)(x_start)) << (NOC_ADDR_LOCAL_BITS + 2 * NOC_ADDR_NODE_ID_BITS)) | \
      (((uint64_t)(y_start)) << (NOC_ADDR_LOCAL_BITS + 3 * NOC_ADDR_NODE_ID_BITS)) | \
diff --git a/tt_metal/hw/inc/dataflow_api.h b/tt_metal/hw/inc/dataflow_api.h
index 91b1a26f8f39..12df89b03dec 100644
--- a/tt_metal/hw/inc/dataflow_api.h
+++ b/tt_metal/hw/inc/dataflow_api.h
@@ -476,7 +476,7 @@ uint64_t get_l1_noc_addr(const uint32_t id, const uint32_t page_size, const uint
 }

 uint64_t get_system_memory_noc_addr(const uint32_t id, const uint32_t page_size, const uint32_t base_addr, const uint32_t offset = 0) {
-    constexpr static uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y)) << 32;
+    uint64_t pcie_core_noc_encoding = uint64_t(NOC_XY_PCIE_ENCODING(NOC_X(PCIE_NOC_X), NOC_Y(PCIE_NOC_Y), noc_index)) << 32;
     uint32_t addr = base_addr + page_size * id + offset;
     uint64_t noc_addr = pcie_core_noc_encoding | addr;
     return noc_addr;
diff --git a/tt_metal/hw/inc/grayskull/noc/noc_parameters.h b/tt_metal/hw/inc/grayskull/noc/noc_parameters.h
index 3fa07c452942..ed13f98ea8fd 100644
--- a/tt_metal/hw/inc/grayskull/noc/noc_parameters.h
+++ b/tt_metal/hw/inc/grayskull/noc/noc_parameters.h
@@ -12,6 +12,8 @@
 // Address formats
 #define NOC_XY_ENCODING(x, y) ((((uint32_t)(y)) << (NOC_ADDR_NODE_ID_BITS)) | (((uint32_t)(x))))

+#define NOC_XY_PCIE_ENCODING(x, y, noc_index) NOC_XY_ENCODING(x, y)
+
 #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \
     ((x_start) << (2 * NOC_ADDR_NODE_ID_BITS)) | ((y_start) << (3 * NOC_ADDR_NODE_ID_BITS)) | (x_end) | \
     ((y_end) << (NOC_ADDR_NODE_ID_BITS))
diff --git a/tt_metal/hw/inc/wormhole/noc/noc_parameters.h b/tt_metal/hw/inc/wormhole/noc/noc_parameters.h
index 0a2256ffeebe..f6b361d3ff3f 100644
--- a/tt_metal/hw/inc/wormhole/noc/noc_parameters.h
+++ b/tt_metal/hw/inc/wormhole/noc/noc_parameters.h
@@ -9,13 +9,21 @@
 #define PCIE_NOC_X 0
 #define PCIE_NOC_Y 3

+#define PCIE_NOC1_X 9
+#define PCIE_NOC1_Y 8
+
 // Address formats
 #define NOC_XY_ENCODING(x, y) \
     (((uint32_t)(y)) << ((NOC_ADDR_LOCAL_BITS % 32)+NOC_ADDR_NODE_ID_BITS)) | \
-    (((uint32_t)(x)) << (NOC_ADDR_LOCAL_BITS % 32)) | ((x == PCIE_NOC_X and y == PCIE_NOC_Y) * 0x8) \
+    (((uint32_t)(x)) << (NOC_ADDR_LOCAL_BITS % 32)) \
+
+// Address formats
+#define NOC_XY_PCIE_ENCODING(x, y, noc_index) \
+    NOC_XY_ENCODING(x, y) | \
+    ((noc_index ? (x == PCIE_NOC1_X and y == PCIE_NOC1_Y) : (x == PCIE_NOC_X and y == PCIE_NOC_Y)) * 0x8) \

 #define NOC_MULTICAST_ENCODING(x_start, y_start, x_end, y_end) \
-    (((uint32_t)(x_start)) << ((NOC_ADDR_LOCAL_BITS % 32)+2*NOC_ADDR_NODE_ID_BITS)) | \
+    (((uint32_t)(x_start)) << ((NOC_ADDR_LOCAL_BITS % 32)+2*NOC_ADDR_NODE_ID_BITS)) | \
     (((uint32_t)(y_start)) << ((NOC_ADDR_LOCAL_BITS % 32)+3*NOC_ADDR_NODE_ID_BITS)) | \
     (((uint32_t)(x_end)) << (NOC_ADDR_LOCAL_BITS % 32)) | \
     (((uint32_t)(y_end)) << ((NOC_ADDR_LOCAL_BITS % 32)+NOC_ADDR_NODE_ID_BITS)) \
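Editor's note: on Wormhole the PCIe core sits at (0, 3) as seen from NOC 0 and at the mirrored (9, 8) as seen from NOC 1, so the extra 0x8 bit must key off whichever coordinate pair matches the NOC in use. A worked expansion under the constants defined above (not part of the patch):

```cpp
// Illustrative expansion of NOC_XY_PCIE_ENCODING on Wormhole:
//   noc_index == 0, (x, y) == (0, 3)  ->  NOC_XY_ENCODING(0, 3) | 0x8   // PCIe bit set
//   noc_index == 1, (x, y) == (9, 8)  ->  NOC_XY_ENCODING(9, 8) | 0x8   // mirrored PCIe core
//   any other (x, y)                  ->  NOC_XY_ENCODING(x, y)         // bit stays clear
```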
diff --git a/tt_metal/impl/debug/watcher_server.cpp b/tt_metal/impl/debug/watcher_server.cpp
index 82fd5f377da8..9b353e305947 100644
--- a/tt_metal/impl/debug/watcher_server.cpp
+++ b/tt_metal/impl/debug/watcher_server.cpp
@@ -785,17 +785,17 @@ static void watcher_loop(int sleep_usecs) {
     }
     log_info(LogLLRuntime, "Watcher server initialized, disabled features: {}", disabled_features);

-    double last_elapsed_time = watcher::get_elapsed_secs();
     while (true) {
-        // Delay an amount such that we wait a minimum of the set sleep_usecs between polls.
-        while ((watcher::get_elapsed_secs() - last_elapsed_time) < ((double)sleep_usecs) / 1000000.) {
+        // Delay the amount of time specified by the user. Don't include watcher polling time to avoid the case where
+        // watcher dominates the communication links due to heavy traffic.
+        double last_elapsed_time = watcher::get_elapsed_secs();
+        while ((watcher::get_elapsed_secs() - last_elapsed_time) < ((double) sleep_usecs) / 1000000.) {
             // Odds are this thread will be killed during the usleep, the kill signal is
             // watcher::enabled = false from the main thread.
             if (!watcher::enabled) break;
             usleep(1);
         }
-        last_elapsed_time = watcher::get_elapsed_secs();

         {
             const std::lock_guard<std::mutex> lock(watch_mutex);
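Editor's note: the restructured loop measures the delay from the top of each iteration instead of from the end of the previous poll, so the poll's own duration is excluded from the measured interval. The pattern in isolation (names illustrative, not the real watcher code):

```cpp
// Sketch of the new cadence: at least sleep_usecs elapse between poll starts,
// and a slow poll can no longer eat into the requested idle time.
while (enabled) {
    double start = get_elapsed_secs();  // reset every iteration
    while (get_elapsed_secs() - start < sleep_usecs / 1'000'000.0) {
        if (!enabled) break;  // allow prompt shutdown
        usleep(1);
    }
    poll_watch_data();  // duration not counted toward the delay
}
```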
diff --git a/tt_metal/impl/device/device.cpp b/tt_metal/impl/device/device.cpp
index e73c9efbdbba..6e9892c130c4 100644
--- a/tt_metal/impl/device/device.cpp
+++ b/tt_metal/impl/device/device.cpp
@@ -16,6 +16,7 @@
 #include "common/utils.hpp"
 #include "llrt/llrt.hpp"
 #include "dev_msgs.h"
+#include "noc/noc_parameters.h"

 namespace tt {

@@ -344,16 +345,19 @@ void Device::configure_kernel_variant(
     CoreCoord upstream_physical_core,
     CoreCoord downstream_physical_core,
     std::map<string, string> defines_in,
+    NOC noc_index,
     bool is_active_eth_core) {

+    const auto& grid_size = tt::Cluster::instance().get_soc_desc(this->id()).grid_size;
+
     std::map<string, string> defines = {
         {"DISPATCH_KERNEL", "1"},
-        {"MY_NOC_X", std::to_string(kernel_physical_core.x)},
-        {"MY_NOC_Y", std::to_string(kernel_physical_core.y)},
-        {"UPSTREAM_NOC_X", std::to_string(upstream_physical_core.x)},
-        {"UPSTREAM_NOC_Y", std::to_string(upstream_physical_core.y)},
-        {"DOWNSTREAM_NOC_X", std::to_string(downstream_physical_core.x)},
-        {"DOWNSTREAM_NOC_Y", std::to_string(downstream_physical_core.y)},
+        {"MY_NOC_X", std::to_string(NOC_0_X(noc_index, grid_size.x, kernel_physical_core.x))},
+        {"MY_NOC_Y", std::to_string(NOC_0_Y(noc_index, grid_size.y, kernel_physical_core.y))},
+        {"UPSTREAM_NOC_X", std::to_string(NOC_0_X(noc_index, grid_size.x, upstream_physical_core.x))},
+        {"UPSTREAM_NOC_Y", std::to_string(NOC_0_Y(noc_index, grid_size.y, upstream_physical_core.y))},
+        {"DOWNSTREAM_NOC_X", std::to_string(NOC_0_X(noc_index, grid_size.x, downstream_physical_core.x))},
+        {"DOWNSTREAM_NOC_Y", std::to_string(NOC_0_Y(noc_index, grid_size.y, downstream_physical_core.y))},
     };
     defines.insert(defines_in.begin(), defines_in.end());

@@ -364,7 +368,7 @@ void Device::configure_kernel_variant(
             kernel_core,
             tt::tt_metal::DataMovementConfig {
                 .processor = tt::tt_metal::DataMovementProcessor::RISCV_1,
-                .noc = NOC::NOC_0,
+                .noc = noc_index,
                 .compile_args = compile_args,
                 .defines = defines
             }
@@ -376,7 +380,7 @@ void Device::configure_kernel_variant(
             kernel_core,
             tt::tt_metal::EthernetConfig{
                 .eth_mode = is_active_eth_core ? Eth::SENDER : Eth::IDLE,
-                .noc = NOC::NOC_0,
+                .noc = noc_index,
                 .compile_args = compile_args,
                 .defines = defines
             }
@@ -420,6 +424,8 @@ void Device::compile_command_queue_programs() {
         CoreCoord prefetch_physical_core = get_physical_core_coordinate(prefetch_core, dispatch_core_type);
         CoreCoord dispatch_physical_core = get_physical_core_coordinate(dispatch_core, dispatch_core_type);

+        NOC noc_index = this->hw_command_queues_[cq_id]->noc_index;
+
         log_debug(LogDevice, "Dispatching out of {} cores", magic_enum::enum_name(dispatch_core_type));
         log_debug(LogDevice, "Prefetch HD logical location: {} physical core: {}", prefetch_core.str(), prefetch_physical_core.str());
         log_debug(LogDevice, "Dispatch HD logical location: {} physical core {}", dispatch_core.str(), dispatch_physical_core.str());
@@ -465,7 +471,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             CoreCoord{0, 0},
             dispatch_physical_core,
-            std::map<string, string> {}
+            std::map<string, string> {},
+            noc_index
         );

         tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, prefetch_core, 0, dispatch_core_type); // prefetch_sync_sem
@@ -501,7 +508,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             prefetch_physical_core,
             CoreCoord{0, 0},
-            std::map<string, string> {}
+            std::map<string, string> {},
+            noc_index
         );

         tt::tt_metal::CreateSemaphore(*command_queue_program_ptr, dispatch_core, 0, dispatch_core_type); // dispatch_sem
@@ -517,7 +525,7 @@ void Device::compile_command_queue_programs() {
         Device *mmio_device = tt::tt_metal::detail::GetDeviceHandle(mmio_device_id);
         uint16_t channel = tt::Cluster::instance().get_assigned_channel_for_device(device_id);
         uint32_t cq_size = mmio_device->sysmem_manager().get_cq_size();
-
+        NOC noc_index = this->hw_command_queues_[cq_id]->noc_index;
         CoreType dispatch_core_type = dispatch_core_manager::get(num_hw_cqs).get_dispatch_core_type(mmio_device_id);

         tt_cxy_pair prefetch_core = dispatch_core_manager::get(num_hw_cqs).prefetcher_core(device_id, channel, cq_id);
@@ -610,7 +618,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             CoreCoord{0, 0},
             mux_physical_core,
-            std::map<string, string> {}
+            std::map<string, string> {},
+            noc_index
         );

         log_debug(LogDevice, "run prefetch_h {}", prefetch_core.str());
@@ -671,7 +680,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             CoreCoord{0, 0},
             CoreCoord{0, 0},
-            std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}}
+            std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}},
+            noc_index
         );

         std::vector<uint32_t> tunneler_l_compile_args =
@@ -715,6 +725,7 @@ void Device::compile_command_queue_programs() {
             CoreCoord{0, 0},
             CoreCoord{0, 0},
             std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}},
+            noc_index,
             true
         );

@@ -782,7 +793,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             CoreCoord{0, 0},
             CoreCoord{0, 0},
-            std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}}
+            std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}},
+            noc_index
         );

         log_debug(LogDevice, "run dispatch demux at {}", demux_core.str());
@@ -816,7 +828,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             demux_physical_core,
             CoreCoord{0xffffffff, 0xffffffff},
-            std::map<string, string> {}
+            std::map<string, string> {},
+            noc_index
         );

         log_debug(LogDevice, "run dispatch_h at {}", dispatch_core.str());
@@ -895,6 +908,7 @@ void Device::compile_command_queue_programs() {
             CoreCoord{0, 0},
             CoreCoord{0, 0},
             std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}},
+            noc_index,
             true
         );

@@ -959,7 +973,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             CoreCoord{0, 0},
             CoreCoord{0, 0},
-            std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}}
+            std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}},
+            noc_index
         );

         log_debug(LogDevice, "run demux at {}", demux_d_core.str());
@@ -1007,7 +1022,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             demux_d_physical_core,
             dispatch_physical_core,
-            std::map<string, string> {}
+            std::map<string, string> {},
+            noc_index
         );

         log_debug(LogDevice, "run prefetch_d at {}", prefetch_d_core.str());
@@ -1041,7 +1057,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             prefetch_d_physical_core,
             mux_d_physical_core,
-            std::map<string, string> {}
+            std::map<string, string> {},
+            noc_index
         );

         log_debug(LogDevice, "run dispatch at {}", dispatch_core.str());
@@ -1100,7 +1117,8 @@ void Device::compile_command_queue_programs() {
             dispatch_core_type,
             CoreCoord{0, 0},
             CoreCoord{0, 0},
-            std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}}
+            std::map<string, string> {{"SKIP_NOC_LOGGING", "1"}},
+            noc_index
         );

         log_debug(LogDevice, "run mux at {}", mux_d_core.str());
@@ -1194,7 +1212,7 @@ void Device::initialize_command_queue() {
     this->sysmem_manager_ = std::make_unique<SystemMemoryManager>(this->id_, this->num_hw_cqs());
     hw_command_queues_.resize(num_hw_cqs());
     for (size_t cq_id = 0; cq_id < num_hw_cqs(); cq_id++) {
-        hw_command_queues_[cq_id] = std::make_unique<HWCommandQueue>(this, cq_id);
+        hw_command_queues_[cq_id] = std::make_unique<HWCommandQueue>(this, cq_id, static_cast<NOC>(cq_id));
         // Need to do this since CommandQueue constructor is private
         sw_command_queues_.push_back(std::unique_ptr<CommandQueue>(new CommandQueue(this, cq_id)));
     }
@@ -1530,6 +1548,24 @@ std::vector Device::ethernet_cores_from_logical_cores(const std::vect
     return ethernet_cores;
 }

+uint32_t Device::get_noc_unicast_encoding(uint8_t noc_index, const CoreCoord& physical_core) const {
+    const auto& grid_size = tt::Cluster::instance().get_soc_desc(this->id()).grid_size;
+    return NOC_XY_ENCODING(
+        NOC_0_X(noc_index, grid_size.x, physical_core.x),
+        NOC_0_Y(noc_index, grid_size.y, physical_core.y)
+    );
+}
+
+uint32_t Device::get_noc_multicast_encoding(uint8_t noc_index, const CoreRange& physical_cores) const {
+    const auto& grid_size = tt::Cluster::instance().get_soc_desc(this->id()).grid_size;
+    return NOC_MULTICAST_ENCODING(
+        NOC_0_X(noc_index, grid_size.x, physical_cores.start.x),
+        NOC_0_Y(noc_index, grid_size.y, physical_cores.start.y),
+        NOC_0_X(noc_index, grid_size.x, physical_cores.end.x),
+        NOC_0_Y(noc_index, grid_size.y, physical_cores.end.y)
+    );
+}
+
 void Device::check_allocator_is_initialized() const {
     if (this->allocator_ == nullptr) {
         TT_THROW("No memory allocator! Device has not been initialized, did you forget to call InitializeDevice?");
diff --git a/tt_metal/impl/device/device.hpp b/tt_metal/impl/device/device.hpp
index 7b054f030688..12df80a6bee1 100644
--- a/tt_metal/impl/device/device.hpp
+++ b/tt_metal/impl/device/device.hpp
@@ -11,6 +11,7 @@
 #include "impl/dispatch/work_executor.hpp"
 #include "tt_metal/impl/allocator/basic_allocator.hpp"
 #include "tt_metal/impl/allocator/l1_banking_allocator.hpp"
+#include "tt_metal/impl/kernels/data_types.hpp"
 #include "tt_metal/impl/trace/trace_buffer.hpp"
 #include "tt_metal/jit_build/build.hpp"
 #include "llrt/tt_cluster.hpp"
@@ -192,6 +193,9 @@ class Device {
     // core.y represents different channels along one
     const std::set<CoreCoord> &ethernet_cores() const { return this->ethernet_cores_; }

+    uint32_t get_noc_unicast_encoding(uint8_t noc_index, const CoreCoord& physical_core) const;
+    uint32_t get_noc_multicast_encoding(uint8_t noc_index, const CoreRange& physical_cores) const;
+
     void deallocate_buffers();

     // machine epsilon
@@ -229,7 +233,7 @@ class Device {
     void initialize_command_queue();
     void initialize_synchronous_sw_cmd_queue();
     void configure_kernel_variant(Program& program, string path, std::vector<uint32_t> compile_args, CoreCoord kernel_core, CoreCoord Kernel_physical_core,
-        CoreType dispatch_core_type, CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, std::map<string, string> defines_in , bool is_active_eth_core = false);
+        CoreType dispatch_core_type, CoreCoord upstream_physical_core, CoreCoord downstream_physical_core, std::map<string, string> defines_in, NOC noc_index, bool is_active_eth_core = false);
     void compile_command_queue_programs();
     void configure_command_queue_programs();
     void clear_l1_state();
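Editor's note: the two new `Device` helpers fold the NOC-1 coordinate flip into one place. `NOC_0_X`/`NOC_0_Y` are not defined in this patch; assuming the usual mirrored-grid convention, they behave like the sketch below:

```cpp
// Assumption: NOC 1 addresses the grid mirrored in both axes; NOC 0 is identity.
inline uint32_t noc_0_x(uint8_t noc_index, uint32_t grid_size_x, uint32_t x) {
    return noc_index == 0 ? x : grid_size_x - 1 - x;
}
inline uint32_t noc_0_y(uint8_t noc_index, uint32_t grid_size_y, uint32_t y) {
    return noc_index == 0 ? y : grid_size_y - 1 - y;
}
// get_noc_unicast_encoding / get_noc_multicast_encoding then always emit
// coordinates as seen by the NOC the command queue actually uses.
```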
diff --git a/tt_metal/impl/dispatch/command_queue.cpp b/tt_metal/impl/dispatch/command_queue.cpp
index 5df863d7b3bc..8c061bd40eeb 100644
--- a/tt_metal/impl/dispatch/command_queue.cpp
+++ b/tt_metal/impl/dispatch/command_queue.cpp
@@ -43,16 +43,12 @@ namespace tt::tt_metal {

 thread_local std::unordered_map EnqueueProgramCommand::cached_program_command_sequences = {};

-uint32_t get_noc_unicast_encoding(const CoreCoord& coord) { return NOC_XY_ENCODING(NOC_X(coord.x), NOC_Y(coord.y)); }
-uint32_t get_noc_multicast_encoding(const CoreCoord& start, const CoreCoord& end) {
-    return NOC_MULTICAST_ENCODING(start.x, start.y, end.x, end.y);
-}
-
 // EnqueueReadBufferCommandSection
 EnqueueReadBufferCommand::EnqueueReadBufferCommand(
     uint32_t command_queue_id,
     Device* device,
+    NOC noc_index,
     Buffer& buffer,
     void* dst,
     SystemMemoryManager& manager,
@@ -60,6 +56,7 @@ EnqueueReadBufferCommand::EnqueueReadBufferCommand(
     uint32_t src_page_index,
     std::optional pages_to_read) :
     command_queue_id(command_queue_id),
+    noc_index(noc_index),
     dst(dst),
     manager(manager),
     buffer(buffer),
@@ -89,7 +86,7 @@ void EnqueueReadShardedBufferCommand::add_prefetch_relay(HugepageDeviceCommand&
     const CoreCoord physical_core =
         this->buffer.device()->physical_core_from_logical_core(this->core, this->buffer.core_type());
     command.add_prefetch_relay_linear(
-        get_noc_unicast_encoding(physical_core), padded_page_size * this->pages_to_read, this->bank_base_address);
+        this->device->get_noc_unicast_encoding(this->noc_index, physical_core), padded_page_size * this->pages_to_read, this->bank_base_address);
 }

 void EnqueueReadBufferCommand::process() {
@@ -125,6 +122,7 @@ void EnqueueReadBufferCommand::process() {
 EnqueueWriteBufferCommand::EnqueueWriteBufferCommand(
     uint32_t command_queue_id,
     Device* device,
+    NOC noc_index,
     const Buffer& buffer,
     const void* src,
     SystemMemoryManager& manager,
@@ -135,6 +133,7 @@ EnqueueWriteBufferCommand::EnqueueWriteBufferCommand(
     uint32_t dst_page_index,
     std::optional pages_to_write) :
     command_queue_id(command_queue_id),
+    noc_index(noc_index),
     manager(manager),
     issue_wait(issue_wait),
     src(src),
@@ -211,7 +210,7 @@ void EnqueueWriteShardedBufferCommand::add_dispatch_write(HugepageDeviceCommand&
         this->buffer.device()->physical_core_from_logical_core(this->core, this->buffer.core_type());
     bool flush_prefetch = true;
     command_sequence.add_dispatch_write_linear(
-        flush_prefetch, 0, get_noc_unicast_encoding(physical_core), this->bank_base_address, data_size_bytes);
+        flush_prefetch, 0, this->device->get_noc_unicast_encoding(this->noc_index, physical_core), this->bank_base_address, data_size_bytes);
 }

 void EnqueueWriteShardedBufferCommand::add_buffer_data(HugepageDeviceCommand& command_sequence) {
@@ -287,10 +286,12 @@ void EnqueueWriteBufferCommand::process() {
 EnqueueProgramCommand::EnqueueProgramCommand(
     uint32_t command_queue_id,
     Device* device,
+    NOC noc_index,
     Program& program,
     SystemMemoryManager& manager,
     uint32_t expected_num_workers_completed) :
     command_queue_id(command_queue_id),
+    noc_index(noc_index),
     manager(manager),
     expected_num_workers_completed(expected_num_workers_completed),
     program(program) {
@@ -462,13 +463,12 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
                     // can make a vector of unicast encodings here
                     CoreCoord physical_core =
                         device->physical_core_from_logical_core(core_coord, kernel->get_kernel_core_type());
-                    uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core);
                     const auto& runtime_args_data = kernel->runtime_args(core_coord);
                     unique_rt_args_data[processor_idx].emplace_back(kernel->runtime_args_data(core_coord));
                     // 2, 17, could be different len here
                     unique_sub_cmds[processor_idx].emplace_back(
-                        CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding});
+                        CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, physical_core)});
                     unique_rt_data_and_sizes[processor_idx].emplace_back(
                         runtime_args_data.data(), runtime_args_data.size() * sizeof(uint32_t));
                     unique_max_runtime_args_len[processor_idx] =
@@ -496,12 +496,11 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
                     for (auto& core_coord : kernel->logical_cores()) {
                         // can make a vector of unicast encodings here
                         CoreCoord physical_core = device->ethernet_core_from_logical_core(core_coord);
-                        uint32_t unicast_noc_encoding = get_noc_unicast_encoding(physical_core);
                         unicast_sub_cmd.emplace_back(
-                            CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = unicast_noc_encoding});
+                            CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, physical_core)});
                     }
                 } else {
-                    vector> dst_noc_multicast_info =
+                    vector> dst_noc_multicast_info =
                         extract_dst_noc_multicast_info>(
                             device, kernel->logical_coreranges(), kernel->get_kernel_core_type());
                     common_sub_cmds[kernel_id].emplace>(
@@ -511,7 +510,7 @@ void EnqueueProgramCommand::assemble_runtime_args_commands() {
                     multicast_sub_cmd.reserve(dst_noc_multicast_info.size());
                     for (const auto& mcast_dests : dst_noc_multicast_info) {
                         multicast_sub_cmd.emplace_back(CQDispatchWritePackedMulticastSubCmd{
-                            .noc_xy_addr = mcast_dests.first, .num_mcast_dests = mcast_dests.second});
+                            .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, std::get<CoreRange>(mcast_dests.first)), .num_mcast_dests = mcast_dests.second});
                     }
                 }
             }
@@ -634,7 +633,6 @@ void EnqueueProgramCommand::assemble_device_commands() {
         for (const CoreRange& core_range : circular_buffers_unique_coreranges) {
             const CoreCoord physical_start = device->worker_core_from_logical_core(core_range.start);
             const CoreCoord physical_end = device->worker_core_from_logical_core(core_range.end);
-            const uint32_t dst_noc_multicast_encoding = get_noc_multicast_encoding(physical_start, physical_end);

             const uint32_t num_receivers = core_range.size();
             auto& cb_config_payload = cb_config_payloads[i];
@@ -659,7 +657,7 @@ void EnqueueProgramCommand::assemble_device_commands() {
                 }
             }
             multicast_cb_config_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{
-                .noc_xy_addr = dst_noc_multicast_encoding, .num_mcast_dests = (uint32_t)core_range.size()});
+                .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, CoreRange(physical_start, physical_end)), .num_mcast_dests = (uint32_t)core_range.size()});
             multicast_cb_config_data.emplace_back(
                 cb_config_payload.data(),
                 (max_base_index + UINT32_WORDS_PER_CIRCULAR_BUFFER_CONFIG) * sizeof(uint32_t));
@@ -683,7 +681,7 @@ void EnqueueProgramCommand::assemble_device_commands() {
         for (int buffer_idx = 0; buffer_idx < program.program_transfer_info.kernel_bins.size(); buffer_idx++) {
             const auto& kg_transfer_info = program.program_transfer_info.kernel_bins[buffer_idx];
             for (int kernel_idx = 0; kernel_idx < kg_transfer_info.dst_base_addrs.size(); kernel_idx++) {
-                for (const pair& dst_noc_info : kg_transfer_info.dst_noc_info) {
+                for (const pair& dst_noc_info : kg_transfer_info.dst_noc_info) {
                     cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE;
                     cmd_sequence_sizeB += CQ_PREFETCH_CMD_BARE_MIN_SIZE;
                 }
@@ -709,9 +707,8 @@ void EnqueueProgramCommand::assemble_device_commands() {
                 CoreCoord physical_end =
                     device->physical_core_from_logical_core(core_range.end, kernel_group.get_core_type());
-                uint32_t dst_noc_multicast_encoding = get_noc_multicast_encoding(physical_start, physical_end);

                 multicast_go_signal_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{
-                    .noc_xy_addr = dst_noc_multicast_encoding, .num_mcast_dests = (uint32_t)core_range.size()});
+                    .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, CoreRange(physical_start, physical_end)), .num_mcast_dests = (uint32_t)core_range.size()});
                 multicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB);
             }
         }
@@ -733,9 +730,8 @@ void EnqueueProgramCommand::assemble_device_commands() {
                 for (auto y = core_range.start.y; y <= core_range.end.y; y++) {
                     CoreCoord physical_coord =
                         device->physical_core_from_logical_core(CoreCoord({x, y}), kernel_group.get_core_type());
-                    uint32_t dst_noc_unicast_encoding = get_noc_unicast_encoding(physical_coord);
                     unicast_go_signal_sub_cmds.emplace_back(
-                        CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = dst_noc_unicast_encoding});
+                        CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, physical_coord)});
                     unicast_go_signal_data.emplace_back(launch_message_data, go_signal_sizeB);
                 }
             }
@@ -768,7 +764,7 @@ void EnqueueProgramCommand::assemble_device_commands() {
             for (const auto& dst_noc_info : transfer_info.dst_noc_info) {
                 num_packed_cmds += 1;
                 multicast_sub_cmds.emplace_back(CQDispatchWritePackedMulticastSubCmd{
-                    .noc_xy_addr = dst_noc_info.first, .num_mcast_dests = dst_noc_info.second});
+                    .noc_xy_addr = this->device->get_noc_multicast_encoding(this->noc_index, std::get<CoreRange>(dst_noc_info.first)), .num_mcast_dests = dst_noc_info.second});
                 sem_data.emplace_back(transfer_info.data.data(), transfer_info.data.size() * sizeof(uint32_t));
             }
         }
@@ -796,7 +792,7 @@ void EnqueueProgramCommand::assemble_device_commands() {
             for (const auto& dst_noc_info : transfer_info.dst_noc_info) {
                 num_packed_cmds += 1;
                 unicast_sub_cmds.emplace_back(
-                    CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = dst_noc_info.first});
+                    CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, std::get<CoreCoord>(dst_noc_info.first))});
                 sem_data.emplace_back(transfer_info.data.data(), transfer_info.data.size() * sizeof(uint32_t));
             }
         }
@@ -828,11 +824,22 @@ void EnqueueProgramCommand::assemble_device_commands() {
         for (int buffer_idx = 0; buffer_idx < program.program_transfer_info.kernel_bins.size(); buffer_idx++) {
             const auto& kg_transfer_info = program.program_transfer_info.kernel_bins[buffer_idx];
             for (int kernel_idx = 0; kernel_idx < kg_transfer_info.dst_base_addrs.size(); kernel_idx++) {
-                for (const pair& dst_noc_info : kg_transfer_info.dst_noc_info) {
+                for (const pair& dst_noc_info : kg_transfer_info.dst_noc_info) {
+                    uint32_t noc_encoding;
+                    std::visit(
+                        [&](auto&& cores) {
+                            using T = std::decay_t<decltype(cores)>;
+                            if constexpr (std::is_same_v<T, CoreRange>) {
+                                noc_encoding = this->device->get_noc_multicast_encoding(this->noc_index, cores);
+                            } else {
+                                noc_encoding = this->device->get_noc_unicast_encoding(this->noc_index, cores);
+                            }
+                        },
+                        dst_noc_info.first);
                     program_command_sequence.add_dispatch_write_linear(
                         false,                // flush_prefetch
                         dst_noc_info.second,  // num_mcast_dests
-                        dst_noc_info.first,   // noc_xy_addr
+                        noc_encoding,         // noc_xy_addr
                         kg_transfer_info.dst_base_addrs[kernel_idx],
                         align(kg_transfer_info.lengths[kernel_idx], NOC_DRAM_ALIGNMENT_BYTES));
                     // Difference between prefetch total relayed pages and dispatch write linear
@@ -868,11 +875,9 @@ void EnqueueProgramCommand::assemble_device_commands() {
             }
         }

-        // Wait Noc Write Barrier, wait for binaries to be written to worker cores
+        // Wait Noc Write Barrier, wait for binaries/configs to be written to worker cores
         if (program.program_transfer_info.num_active_cores > 0) {
-            // Wait Noc Write Barrier, wait for binaries to be written to worker cores
-            // TODO: any way to not have dispatcher poll the addr here?
-            program_command_sequence.add_dispatch_wait(true, DISPATCH_MESSAGE_ADDR, 0);
+            program_command_sequence.add_dispatch_wait(true, DISPATCH_MESSAGE_ADDR, 0, 0, false, false);
         }

         // Go Signals
@@ -1026,12 +1031,14 @@ void EnqueueProgramCommand::process() {
 EnqueueRecordEventCommand::EnqueueRecordEventCommand(
     uint32_t command_queue_id,
     Device* device,
+    NOC noc_index,
     SystemMemoryManager& manager,
     uint32_t event_id,
     uint32_t expected_num_workers_completed,
     bool clear_count) :
     command_queue_id(command_queue_id),
     device(device),
+    noc_index(noc_index),
     manager(manager),
     event_id(event_id),
     expected_num_workers_completed(expected_num_workers_completed),
@@ -1080,7 +1087,7 @@ void EnqueueRecordEventCommand::process() {
         CoreCoord dispatch_physical_core = get_physical_core_coordinate(dispatch_location, core_type);

         unicast_sub_cmds[cq_id] =
-            CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = get_noc_unicast_encoding(dispatch_physical_core)};
+            CQDispatchWritePackedUnicastSubCmd{.noc_xy_addr = this->device->get_noc_unicast_encoding(this->noc_index, dispatch_physical_core)};
         event_payloads[cq_id] = {event_payload.data(), event_payload.size() * sizeof(uint32_t)};
     }

@@ -1209,11 +1216,12 @@ void EnqueueTerminateCommand::process() {
 }

 // HWCommandQueue section
-HWCommandQueue::HWCommandQueue(Device* device, uint32_t id) :
+HWCommandQueue::HWCommandQueue(Device* device, uint32_t id, NOC noc_index) :
     manager(device->sysmem_manager()), completion_queue_thread{} {
     ZoneScopedN("CommandQueue_constructor");
     this->device = device;
     this->id = id;
+    this->noc_index = noc_index;
     this->num_entries_in_completion_q = 0;
     this->num_completed_completion_q_reads = 0;
@@ -1340,6 +1348,7 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin
                 auto command = EnqueueReadShardedBufferCommand(
                     this->id,
                     this->device,
+                    this->noc_index,
                     buffer,
                     dst,
                     this->manager,
@@ -1376,6 +1385,7 @@ void HWCommandQueue::enqueue_read_buffer(Buffer& buffer, void* dst, bool blockin
         auto command = EnqueueReadInterleavedBufferCommand(
             this->id,
             this->device,
+            this->noc_index,
             buffer,
             dst,
             this->manager,
@@ -1514,6 +1524,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
                 auto command = EnqueueWriteShardedBufferCommand(
                     this->id,
                     this->device,
+                    this->noc_index,
                     buffer,
                     src,
                     this->manager,
@@ -1605,6 +1616,7 @@ void HWCommandQueue::enqueue_write_buffer(const Buffer& buffer, const void* src,
             auto command = EnqueueWriteInterleavedBufferCommand(
                 this->id,
                 this->device,
+                this->noc_index,
                 buffer,
                 src,
                 this->manager,
@@ -1646,7 +1658,7 @@ void HWCommandQueue::enqueue_program(Program& program, bool blocking) {
     // Snapshot of expected workers from previous programs, used for dispatch_wait cmd generation.
     uint32_t expected_workers_completed = this->manager.get_bypass_mode() ? this->trace_ctx->num_completion_worker_cores
                                                                           : this->expected_num_workers_completed;
-    auto command = EnqueueProgramCommand(this->id, this->device, program, this->manager, expected_workers_completed);
+    auto command = EnqueueProgramCommand(this->id, this->device, this->noc_index, program, this->manager, expected_workers_completed);
     this->enqueue_command(command, blocking);

     log_trace(
@@ -1677,7 +1689,7 @@ void HWCommandQueue::enqueue_record_event(std::shared_ptr event, bool cle
     event->ready = true;  // what does this mean???

     auto command = EnqueueRecordEventCommand(
-        this->id, this->device, this->manager, event->event_id, this->expected_num_workers_completed, clear_count);
+        this->id, this->device, this->noc_index, this->manager, event->event_id, this->expected_num_workers_completed, clear_count);
     this->enqueue_command(command, false);

     if (clear_count) {
@@ -2295,9 +2307,6 @@ void EnqueueProgramImpl(
 }

 void EnqueueRecordEvent(CommandQueue& cq, std::shared_ptr<Event> event) {
-    TT_ASSERT(event->device == nullptr, "EnqueueRecordEvent expected to be given an uninitialized event");
-    TT_ASSERT(event->event_id == -1, "EnqueueRecordEvent expected to be given an uninitialized event");
-    TT_ASSERT(event->cq_id == -1, "EnqueueRecordEvent expected to be given an uninitialized event");

     detail::DispatchStateCheck(true);
     cq.run_command(CommandInterface{
diff --git a/tt_metal/impl/dispatch/command_queue.hpp b/tt_metal/impl/dispatch/command_queue.hpp
index 578724880f00..9809824eab57 100644
--- a/tt_metal/impl/dispatch/command_queue.hpp
+++ b/tt_metal/impl/dispatch/command_queue.hpp
@@ -55,9 +55,6 @@ string EnqueueCommandTypeToString(EnqueueCommandType ctype);
 #define NOC_X(x) x
 #define NOC_Y(y) y

-uint32_t get_noc_unicast_encoding(const CoreCoord& coord);
-uint32_t get_noc_multicast_encoding(const CoreCoord& start, const CoreCoord& end);
-
 class CommandQueue;
 class CommandInterface;

@@ -74,13 +71,14 @@ class EnqueueReadBufferCommand : public Command {
    private:
     SystemMemoryManager& manager;
     void* dst;
-    uint32_t command_queue_id;
     CoreType dispatch_core_type;

     virtual void add_prefetch_relay(HugepageDeviceCommand& command) = 0;

    protected:
     Device* device;
+    uint32_t command_queue_id;
+    NOC noc_index;
     uint32_t expected_num_workers_completed;
     uint32_t src_page_index;
     uint32_t pages_to_read;
@@ -90,6 +88,7 @@ class EnqueueReadBufferCommand : public Command {
     EnqueueReadBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         Buffer& buffer,
         void* dst,
         SystemMemoryManager& manager,
@@ -112,6 +111,7 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand {
     EnqueueReadInterleavedBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         Buffer& buffer,
         void* dst,
         SystemMemoryManager& manager,
@@ -121,6 +121,7 @@ class EnqueueReadInterleavedBufferCommand : public EnqueueReadBufferCommand {
         EnqueueReadBufferCommand(
             command_queue_id,
             device,
+            noc_index,
             buffer,
             dst,
             manager,
@@ -139,6 +140,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
     EnqueueReadShardedBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         Buffer& buffer,
         void* dst,
         SystemMemoryManager& manager,
@@ -150,6 +152,7 @@ class EnqueueReadShardedBufferCommand : public EnqueueReadBufferCommand {
         EnqueueReadBufferCommand(
             command_queue_id,
             device,
+            noc_index,
             buffer,
             dst,
             manager,
@@ -165,7 +168,6 @@ class EnqueueWriteInterleavedBufferCommand;
 class EnqueueWriteBufferCommand : public Command {
    private:
     SystemMemoryManager& manager;
-    uint32_t command_queue_id;
     CoreType dispatch_core_type;

     virtual void add_dispatch_write(HugepageDeviceCommand& command) = 0;
@@ -173,6 +175,8 @@ class EnqueueWriteBufferCommand : public Command {

    protected:
     Device* device;
+    uint32_t command_queue_id;
+    NOC noc_index;
     const void* src;
     const Buffer& buffer;
     uint32_t expected_num_workers_completed;
@@ -186,6 +190,7 @@ class EnqueueWriteBufferCommand : public Command {
     EnqueueWriteBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         const Buffer& buffer,
         const void* src,
         SystemMemoryManager& manager,
@@ -212,6 +217,7 @@ class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand {
     EnqueueWriteInterleavedBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         const Buffer& buffer,
         const void* src,
         SystemMemoryManager& manager,
@@ -224,6 +230,7 @@ class EnqueueWriteInterleavedBufferCommand : public EnqueueWriteBufferCommand {
         EnqueueWriteBufferCommand(
             command_queue_id,
             device,
+            noc_index,
             buffer,
             src,
             manager,
@@ -249,6 +256,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
     EnqueueWriteShardedBufferCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         const Buffer& buffer,
         const void* src,
         SystemMemoryManager& manager,
@@ -263,6 +271,7 @@ class EnqueueWriteShardedBufferCommand : public EnqueueWriteBufferCommand {
         EnqueueWriteBufferCommand(
             command_queue_id,
             device,
+            noc_index,
             buffer,
             src,
             manager,
@@ -282,6 +291,7 @@ class EnqueueProgramCommand : public Command {
    private:
     uint32_t command_queue_id;
     Device* device;
+    NOC noc_index;
     Program& program;
     SystemMemoryManager& manager;
     CoreType dispatch_core_type;
@@ -302,6 +312,7 @@ class EnqueueProgramCommand : public Command {
     EnqueueProgramCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         Program& program,
         SystemMemoryManager& manager,
         uint32_t expected_num_workers_completed);
@@ -321,6 +332,7 @@ class EnqueueRecordEventCommand : public Command {
    private:
     uint32_t command_queue_id;
     Device* device;
+    NOC noc_index;
     SystemMemoryManager& manager;
     uint32_t event_id;
     uint32_t expected_num_workers_completed;
@@ -330,6 +342,7 @@ class EnqueueRecordEventCommand : public Command {
     EnqueueRecordEventCommand(
         uint32_t command_queue_id,
         Device* device,
+        NOC noc_index,
         SystemMemoryManager& manager,
         uint32_t event_id,
         uint32_t expected_num_workers_completed,
@@ -474,11 +487,12 @@ struct RuntimeArgsMetadata {

 class HWCommandQueue {
    public:
-    HWCommandQueue(Device* device, uint32_t id);
+    HWCommandQueue(Device* device, uint32_t id, NOC noc_index);

     ~HWCommandQueue();

     CoreCoord completion_queue_writer_core;
+    NOC noc_index;
     volatile bool is_dprint_server_hung();
     volatile bool is_noc_hung();

diff --git a/tt_metal/impl/dispatch/cq_commands.hpp b/tt_metal/impl/dispatch/cq_commands.hpp
index f4a4ddb0a446..db16fa618211 100644
--- a/tt_metal/impl/dispatch/cq_commands.hpp
+++ b/tt_metal/impl/dispatch/cq_commands.hpp
@@ -162,6 +162,9 @@ struct CQDispatchWaitCmd {
     uint8_t barrier;          // if true, issue write barrier
     uint8_t notify_prefetch;  // if true, inc prefetch sem
     uint8_t clear_count;      // if true, reset count to 0
+    uint8_t wait;             // if true, wait on count value below
+    uint8_t pad1;
+    uint16_t pad2;
     uint32_t addr;            // address to read
     uint32_t count;           // wait while address is < count
 } __attribute__((packed));
diff --git a/tt_metal/impl/dispatch/device_command.hpp b/tt_metal/impl/dispatch/device_command.hpp
index 67977c63797e..e8c1255a8b52 100644
--- a/tt_metal/impl/dispatch/device_command.hpp
+++ b/tt_metal/impl/dispatch/device_command.hpp
@@ -73,7 +73,7 @@ class DeviceCommand {
     vector_memcpy_aligned cmd_vector() const { return this->cmd_region_vector; }

     void add_dispatch_wait(
-        uint8_t barrier, uint32_t address, uint32_t count, uint8_t clear_count = 0, bool notify_prefetch = false) {
+        uint8_t barrier, uint32_t address, uint32_t count, uint8_t clear_count = 0, bool notify_prefetch = false, bool do_wait = true) {
         auto initialize_wait_cmds = [&](CQPrefetchCmd *relay_wait, CQDispatchCmd *wait_cmd) {
             relay_wait->base.cmd_id = CQ_PREFETCH_CMD_RELAY_INLINE;
             relay_wait->relay_inline.length = sizeof(CQDispatchCmd);
@@ -82,6 +82,7 @@ class DeviceCommand {
             wait_cmd->base.cmd_id = CQ_DISPATCH_CMD_WAIT;
             wait_cmd->wait.barrier = barrier;
             wait_cmd->wait.notify_prefetch = notify_prefetch;
+            wait_cmd->wait.wait = do_wait;
             wait_cmd->wait.addr = address;
             wait_cmd->wait.count = count;
             wait_cmd->wait.clear_count = clear_count;
@@ -101,8 +102,8 @@ class DeviceCommand {
     }

     void add_dispatch_wait_with_prefetch_stall(
-        uint8_t barrier, uint32_t address, uint32_t count, uint8_t clear_count = 0) {
-        this->add_dispatch_wait(barrier, address, count, clear_count, true);
+        uint8_t barrier, uint32_t address, uint32_t count, uint8_t clear_count = 0, bool do_wait = true) {
+        this->add_dispatch_wait(barrier, address, count, clear_count, true, do_wait);
         uint32_t increment_sizeB = align(sizeof(CQPrefetchCmd), PCIE_ALIGNMENT);
         auto initialize_stall_cmd = [&](CQPrefetchCmd *stall_cmd) {
             *stall_cmd = {};
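Editor's note: the new `wait` field lets one command request the write barrier without also blocking on the worker counter. The program path above already uses this for the binary-write barrier; the call in isolation:

```cpp
// Barrier-only dispatch wait: flush writes, but do not block on the count
// (clear_count, notify_prefetch, and do_wait all off/false).
program_command_sequence.add_dispatch_wait(
    true /*barrier*/, DISPATCH_MESSAGE_ADDR, 0 /*count*/,
    0 /*clear_count*/, false /*notify_prefetch*/, false /*do_wait*/);
```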
diff --git a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp
index a506c16df3ea..07bf38efdb20 100644
--- a/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp
+++ b/tt_metal/impl/dispatch/kernels/cq_dispatch.cpp
@@ -10,15 +10,15 @@
 // - # blocks must evenly divide the dispatch buffer size
 // - dispatch buffer base must be page size aligned

+#include "debug/assert.h"
+#include "debug/dprint.h"
 #include "tt_metal/impl/dispatch/cq_commands.hpp"
 #include "tt_metal/impl/dispatch/dispatch_address_map.hpp"
 #include "tt_metal/impl/dispatch/kernels/cq_common.hpp"
 #include "tt_metal/impl/dispatch/kernels/packet_queue_ctrl.hpp"
-#include "debug/dprint.h"
-#include "debug/assert.h"

-// The command queue write interface controls writes to the completion region, host owns the completion region read interface
-// Data requests from device and event states are written to the completion region
+// The command queue write interface controls writes to the completion region, host owns the completion region read
+// interface Data requests from device and event states are written to the completion region

 CQWriteInterface cq_write_interface;

@@ -43,7 +43,7 @@ constexpr uint32_t is_h_variant = get_compile_time_arg_val(16);
 constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y));
 constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y));
 constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y));
-constexpr uint32_t pcie_noc_xy_encoding = uint32_t(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y));
+constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_PCIE_ENCODING(NOC_0_X(static_cast<uint8_t>(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast<uint8_t>(NOC_INDEX), noc_size_y, PCIE_NOC_Y), NOC_INDEX));
 constexpr uint32_t dispatch_cb_page_size = 1 << dispatch_cb_log_page_size;

 constexpr uint32_t completion_queue_end_addr = completion_queue_base_addr + completion_queue_size;
@@ -57,7 +57,6 @@ constexpr uint32_t dispatch_cb_size = dispatch_cb_page_size * dispatch_cb_pages;
 constexpr uint32_t dispatch_cb_end = dispatch_cb_base + dispatch_cb_size;
 constexpr uint32_t downstream_cb_end = downstream_cb_base + downstream_cb_size;

-
 // Break buffer into blocks, 1/n of the total (dividing equally)
 // Do bookkeeping (release, etc) based on blocks
 // Note: due to the current method of release pages, up to 1 block of pages
@@ -69,14 +68,17 @@ static uint32_t block_noc_writes_to_clear[dispatch_cb_blocks];
 static uint32_t rd_block_idx;
 static uint32_t wr_block_idx;

-static uint32_t cb_fence; // walks through cb page by page
-static uint32_t cmd_ptr;  // walks through pages in cb cmd by cmd
+static uint32_t cb_fence;  // walks through cb page by page
+static uint32_t cmd_ptr;   // walks through pages in cb cmd by cmd

 static uint32_t downstream_cb_data_ptr = downstream_cb_base;

 constexpr uint32_t l1_to_local_cache_copy_chunk = 6;
-constexpr uint32_t max_write_packed_cores = 108; // GS 120 - 1 row TODO: this should be a compile time arg passed in from host
-constexpr uint32_t l1_cache_size = ((max_write_packed_cores + l1_to_local_cache_copy_chunk - 1) / l1_to_local_cache_copy_chunk) * l1_to_local_cache_copy_chunk;
+constexpr uint32_t max_write_packed_cores =
+    108;  // GS 120 - 1 row TODO: this should be a compile time arg passed in from host
+constexpr uint32_t l1_cache_size =
+    ((max_write_packed_cores + l1_to_local_cache_copy_chunk - 1) / l1_to_local_cache_copy_chunk) *
+    l1_to_local_cache_copy_chunk;

 static uint32_t l1_cache[l1_cache_size];

@@ -105,12 +107,12 @@ void careful_copy_from_l1_to_local_cache(volatile uint32_t tt_l1_ptr *l1_ptr, ui
     }
 }

-FORCE_INLINE volatile uint32_t* get_cq_completion_read_ptr() {
-    return reinterpret_cast<volatile uint32_t*>(CQ_COMPLETION_READ_PTR);
+FORCE_INLINE volatile uint32_t *get_cq_completion_read_ptr() {
+    return reinterpret_cast<volatile uint32_t *>(CQ_COMPLETION_READ_PTR);
 }

-FORCE_INLINE volatile uint32_t* get_cq_completion_write_ptr() {
-    return reinterpret_cast<volatile uint32_t*>(CQ_COMPLETION_WRITE_PTR);
+FORCE_INLINE volatile uint32_t *get_cq_completion_write_ptr() {
+    return reinterpret_cast<volatile uint32_t *>(CQ_COMPLETION_WRITE_PTR);
 }

 FORCE_INLINE
@@ -130,9 +132,10 @@ void completion_queue_reserve_back(uint32_t num_pages) {
         // so available space is distance from write ptr to read ptr
         // Toggles are equal means write ptr is ahead of read ptr
         // so available space is total space minus the distance from read to write ptr
-        available_space = completion_rd_toggle != cq_write_interface.completion_fifo_wr_toggle ?
-                          completion_rd_ptr - cq_write_interface.completion_fifo_wr_ptr :
-                          (completion_queue_size_16B - (cq_write_interface.completion_fifo_wr_ptr - completion_rd_ptr));
+        available_space =
+            completion_rd_toggle != cq_write_interface.completion_fifo_wr_toggle
+                ? completion_rd_ptr - cq_write_interface.completion_fifo_wr_ptr
+                : (completion_queue_size_16B - (cq_write_interface.completion_fifo_wr_ptr - completion_rd_ptr));
     } while (data_size_16B > available_space);

     DEBUG_STATUS("QRBD");
@@ -141,7 +144,7 @@ void completion_queue_reserve_back(uint32_t num_pages) {
 FORCE_INLINE
 void notify_host_of_completion_queue_write_pointer() {
     uint64_t completion_queue_write_ptr_addr = command_queue_base_addr + HOST_CQ_COMPLETION_WRITE_PTR;
-    uint64_t pcie_address = get_noc_addr_helper(pcie_noc_xy_encoding, completion_queue_write_ptr_addr);  // For now, we are writing to host hugepages at offset
+    uint64_t pcie_address = get_noc_addr_helper(pcie_noc_xy, completion_queue_write_ptr_addr);  // For now, we are writing to host hugepages at offset
     uint32_t completion_wr_ptr_and_toggle = cq_write_interface.completion_fifo_wr_ptr | (cq_write_interface.completion_fifo_wr_toggle << 31);
     volatile tt_l1_ptr uint32_t* completion_wr_ptr_addr = get_cq_completion_write_ptr();
     completion_wr_ptr_addr[0] = completion_wr_ptr_and_toggle;
@@ -156,7 +159,8 @@ void completion_queue_push_back(uint32_t num_pages) {
     cq_write_interface.completion_fifo_wr_ptr += push_size_16B;

     if (cq_write_interface.completion_fifo_wr_ptr >= completion_queue_end_addr_16B) {
-        cq_write_interface.completion_fifo_wr_ptr = cq_write_interface.completion_fifo_wr_ptr - completion_queue_end_addr_16B + completion_queue_base_addr_16B;
+        cq_write_interface.completion_fifo_wr_ptr =
+            cq_write_interface.completion_fifo_wr_ptr - completion_queue_end_addr_16B + completion_queue_base_addr_16B;
         // Flip the toggle
         cq_write_interface.completion_fifo_wr_toggle = not cq_write_interface.completion_fifo_wr_toggle;
     }
@@ -184,31 +188,28 @@ void process_write_host_h() {
                     cb_fence = dispatch_cb_base;
                     data_ptr = dispatch_cb_base;
                 }
-                move_rd_to_next_block(block_noc_writes_to_clear,
-                                      rd_block_idx);
+                move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx);
             }

             // Wait for dispatcher to supply a page (this won't go beyond the buffer end)
-            uint32_t n_pages = cb_acquire_pages(cb_fence,
-                                                block_next_start_addr,
-                                                rd_block_idx);;
+            uint32_t n_pages = cb_acquire_pages(
+                cb_fence, block_next_start_addr, rd_block_idx);
             cb_fence += n_pages * dispatch_cb_page_size;

             // Release pages for prefetcher
             // Since we gate how much we acquire to < 1/4 the buffer, this should be called enough
-            cb_block_release_pages(block_noc_writes_to_clear,
-                                   wr_block_idx);
+            cb_block_release_pages<
+                upstream_noc_xy,
+                upstream_dispatch_cb_sem_id,
+                dispatch_cb_blocks,
+                dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx);
         }

         uint32_t available_data = cb_fence - data_ptr;
         uint32_t xfer_size = (length > available_data) ? available_data : length;
         uint32_t npages = (xfer_size + completion_queue_page_size - 1) / completion_queue_page_size;
         completion_queue_reserve_back(npages);
         uint32_t completion_queue_write_addr = cq_write_interface.completion_fifo_wr_ptr << 4;
-        uint64_t host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy_encoding, completion_queue_write_addr);
+        uint64_t host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy, completion_queue_write_addr);
         // completion_queue_write_addr will never be equal to completion_queue_end_addr due to completion_queue_push_back
         // wrap logic so we don't need to handle this case explicitly to avoid 0 sized transactions
         if (completion_queue_write_addr + xfer_size > completion_queue_end_addr) {
@@ -218,7 +219,7 @@ void process_write_host_h() {
             data_ptr += last_chunk_size;
             length -= last_chunk_size;
             xfer_size -= last_chunk_size;
-            host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy_encoding, completion_queue_write_addr);
+            host_completion_queue_write_addr = get_noc_addr_helper(pcie_noc_xy, completion_queue_write_addr);
             block_noc_writes_to_clear[rd_block_idx]+=(last_chunk_size + NOC_MAX_BURST_SIZE - 1) / NOC_MAX_BURST_SIZE; // XXXXX maybe just write the noc internal api counter
         }
         noc_async_write(data_ptr, host_completion_queue_write_addr, xfer_size);
@@ -226,7 +227,9 @@ void process_write_host_h() {
         // We flush to ensure the ptr has been read out of l1 before we update it again
         completion_queue_push_back(npages);
         noc_async_writes_flushed();
-        block_noc_writes_to_clear[rd_block_idx]+=(xfer_size + NOC_MAX_BURST_SIZE - 1) / NOC_MAX_BURST_SIZE; // XXXXX maybe just write the noc internal api counter
+        block_noc_writes_to_clear[rd_block_idx] +=
+            (xfer_size + NOC_MAX_BURST_SIZE - 1) /
+            NOC_MAX_BURST_SIZE;  // XXXXX maybe just write the noc internal api counter

         length -= xfer_size;
         data_ptr += xfer_size;
@@ -234,59 +237,58 @@ void process_write_host_h() {
     cmd_ptr = data_ptr;
 }

-template <uint32_t preamble_size>
-void relay_to_next_cb(uint32_t data_ptr,
-                      uint32_t length) {
-
+// Relay, potentially through the mux/dmux/tunneller path
+// Code below sends 1 page worth of data except at the end of a cmd
+// This means the downstream buffers are always page aligned, simplifies wrap handling
+template <uint32_t preamble_size>
+void relay_to_next_cb(uint32_t data_ptr, uint32_t length) {
     static_assert(
         preamble_size == 0 || preamble_size == sizeof(dispatch_packet_header_t),
         "Dispatcher preamble size must be 0 or sizeof(dispatch_packet_header_t)");

     DPRINT << "relay_to_next_cb: " << data_ptr << " " << cb_fence << " " << length << ENDL();

-    bool page_acquired = false;
-    // The downstream packetizing stage will initialize the other fields, but it needs info on
-    // the length of the transfer to be packetized.
-    if (preamble_size > 0) {
-        cb_acquire_pages(1); // XXXX optimize, take all availabl
-        page_acquired = true;
-        ASSERT(downstream_cb_data_ptr != downstream_cb_end);
-
-        uint64_t downstream_noc_addr = get_noc_addr_helper(downstream_noc_xy, downstream_cb_data_ptr);
-        noc_inline_dw_write(downstream_noc_addr, length + preamble_size);
-        block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter
-        downstream_cb_data_ptr += preamble_size;
-    }
-
     // First page should be valid since it has the command
     ASSERT(data_ptr <= dispatch_cb_end - dispatch_cb_page_size);
     ASSERT(data_ptr <= cb_fence - dispatch_cb_page_size);

-    uint32_t extra = preamble_size;
     while (length > 0) {
-        ASSERT (downstream_cb_end > downstream_cb_data_ptr);
+        ASSERT(downstream_cb_end > downstream_cb_data_ptr);
+
+        cb_acquire_pages(1);
+
+        uint32_t xfer_size;
+        bool not_end_of_cmd;
+        if (length > dispatch_cb_page_size - preamble_size) {
+            xfer_size = dispatch_cb_page_size - preamble_size;
+            not_end_of_cmd = true;
+        } else {
+            xfer_size = length;
+            not_end_of_cmd = false;
+        }

-        uint32_t xfer_size = (length > dispatch_cb_page_size - extra) ?
-                             dispatch_cb_page_size - extra :
-                             length;
         uint64_t dst = get_noc_addr_helper(downstream_noc_xy, downstream_cb_data_ptr);

+        if (preamble_size > 0) {
+            uint32_t flag;
+            noc_inline_dw_write(dst, xfer_size + preamble_size + not_end_of_cmd);
+            block_noc_writes_to_clear[rd_block_idx]++;
+            downstream_cb_data_ptr += preamble_size;
+            dst = get_noc_addr_helper(downstream_noc_xy, downstream_cb_data_ptr);
+            ASSERT(downstream_cb_data_ptr < downstream_cb_end);
+        }
+
         // Get a page if needed
         if (data_ptr + xfer_size > cb_fence) {
             // Check for block completion
             if (cb_fence == block_next_start_addr[rd_block_idx]) {
                 // Check for dispatch_cb wrap
                 if (rd_block_idx == dispatch_cb_blocks - 1) {
-                    // We can be misaligned when orphan_size is non-zero
-                    // Code could be structured to stay aligned after wrap,
-                    // but instead making this behave like other routines
-                    uint32_t orphan_size = preamble_size;
-                    ASSERT(dispatch_cb_end - data_ptr == preamble_size);
+                    ASSERT(cb_fence == dispatch_cb_end);
+                    uint32_t orphan_size = cb_fence - data_ptr;
                     if (orphan_size != 0) {
-                        cb_acquire_pages(1); // XXXX optimize, take all available
                         noc_async_write(data_ptr, dst, orphan_size);
                         block_noc_writes_to_clear[rd_block_idx]++;
-                        page_acquired = true;
                         length -= orphan_size;
                         xfer_size -= orphan_size;
                         downstream_cb_data_ptr += orphan_size;
@@ -299,34 +301,26 @@ void relay_to_next_cb(uint32_t data_ptr,
                     data_ptr = dispatch_cb_base;
                 }

-                move_rd_to_next_block(block_noc_writes_to_clear,
-                                      rd_block_idx);
+                move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx);
             }

             // Wait for dispatcher to supply a page (this won't go beyond the buffer end)
-            uint32_t n_pages = cb_acquire_pages(cb_fence,
-                                                block_next_start_addr,
-                                                rd_block_idx);
+            uint32_t n_pages = cb_acquire_pages(
+                cb_fence, block_next_start_addr, rd_block_idx);
             cb_fence += n_pages * dispatch_cb_page_size;

             // Release pages for prefetcher
             // Since we gate how much we acquire to < 1/4 the buffer, this should be called enough
-            cb_block_release_pages(block_noc_writes_to_clear,
-                                   wr_block_idx);
+            cb_block_release_pages<
+                upstream_noc_xy,
+                upstream_dispatch_cb_sem_id,
+                dispatch_cb_blocks,
+                dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx);
         }

-        // Get downstream page
-        if (page_acquired == false) {
-            cb_acquire_pages(1); // XXXX optimize, take all available
-        }
         noc_async_write(data_ptr, dst, xfer_size);
-        block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter
-        cb_release_pages(1); // XXXX optimize, take all available
+        block_noc_writes_to_clear[rd_block_idx]++;  // XXXXX maybe just write the noc internal api counter
+        cb_release_pages(1);  // XXXX optimize, take all available

         length -= xfer_size;
         data_ptr += xfer_size;
@@ -334,8 +328,6 @@ void relay_to_next_cb(uint32_t data_ptr,
         if (downstream_cb_data_ptr == downstream_cb_end) {
             downstream_cb_data_ptr = downstream_cb_base;
         }
-        page_acquired = false;
-        extra = 0;
     }

     // Move to next page
@@ -348,7 +340,6 @@ void relay_to_next_cb(uint32_t data_ptr,
 }

 void process_write_host_d() {
-
     volatile tt_l1_ptr CQDispatchCmd *cmd = (volatile tt_l1_ptr CQDispatchCmd *)cmd_ptr;
     // Remember: host transfer command includes the command in the payload, don't add it here
     uint32_t length = cmd->write_linear_host.length;
@@ -358,7 +349,6 @@ void process_write_host_d() {
 }

 void relay_write_h() {
-
     volatile tt_l1_ptr CQDispatchCmd *cmd = (volatile tt_l1_ptr CQDispatchCmd *)cmd_ptr;
     uint32_t length = sizeof(CQDispatchCmd) + cmd->write_linear.length;
     uint32_t data_ptr = cmd_ptr;
@@ -368,7 +358,7 @@ void relay_write_h() {

 // Note that for non-paged writes, the number of writes per page is always 1
 // This means each noc_write frees up a page
-template <bool multicast>
+template <bool multicast>
 void process_write_linear(uint32_t num_mcast_dests) {
     volatile tt_l1_ptr CQDispatchCmd *cmd = (volatile tt_l1_ptr CQDispatchCmd *)cmd_ptr;

@@ -376,7 +366,6 @@ void process_write_linear(uint32_t num_mcast_dests) {
     uint32_t dst_addr = cmd->write_linear.addr;
     uint32_t length = cmd->write_linear.length;
     uint32_t data_ptr = cmd_ptr + sizeof(CQDispatchCmd);
-
     DPRINT << "dispatch_write: " << length << " num_mcast_dests: " << num_mcast_dests << ENDL();

     while (length != 0) {
         uint32_t xfer_size = (length > dispatch_cb_page_size) ? dispatch_cb_page_size : length;
         uint64_t dst = get_noc_addr_helper(dst_noc, dst_addr);
@@ -389,8 +378,9 @@ void process_write_linear(uint32_t num_mcast_dests) {
                 if (rd_block_idx == dispatch_cb_blocks - 1) {
                     uint32_t orphan_size = dispatch_cb_end - data_ptr;
                     if (orphan_size != 0) {
-                        if constexpr (multicast){
-                            noc_async_write_multicast(data_ptr, dst, orphan_size, num_mcast_dests);
+                        if constexpr (multicast) {
+                            noc_async_write_multicast(
+                                data_ptr, dst, orphan_size, num_mcast_dests);
                         } else {
                             noc_async_write(data_ptr, dst, orphan_size);
                         }
@@ -404,33 +394,29 @@ void process_write_linear(uint32_t num_mcast_dests) {
                     dst = get_noc_addr_helper(dst_noc, dst_addr);
                 }

-                move_rd_to_next_block(block_noc_writes_to_clear,
-                                      rd_block_idx);
+                move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx);
             }

             // Wait for dispatcher to supply a page (this won't go beyond the buffer end)
-            uint32_t n_pages = cb_acquire_pages(cb_fence,
-                                                block_next_start_addr,
-                                                rd_block_idx);
+            uint32_t n_pages = cb_acquire_pages(
+                cb_fence, block_next_start_addr, rd_block_idx);
             cb_fence += n_pages * dispatch_cb_page_size;

             // Release pages for prefetcher
             // Since we gate how much we acquire to < 1/4 the buffer, this should be called enough
-            cb_block_release_pages(block_noc_writes_to_clear,
-                                   wr_block_idx);
+            cb_block_release_pages<
+                upstream_noc_xy,
+                upstream_dispatch_cb_sem_id,
+                dispatch_cb_blocks,
+                dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx);
         }

-        if constexpr (multicast){
+        if constexpr (multicast) {
             noc_async_write_multicast(data_ptr, dst, xfer_size, num_mcast_dests);
         } else {
             noc_async_write(data_ptr, dst, xfer_size);
         }
-        block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter
+        block_noc_writes_to_clear[rd_block_idx]++;  // XXXXX maybe just write the noc internal api counter

         length -= xfer_size;
         data_ptr += xfer_size;
@@ -449,7 +435,7 @@ void process_write() {
     }
 }

-template <bool is_dram>
+template <bool is_dram>
 void process_write_paged() {
     volatile tt_l1_ptr CQDispatchCmd *cmd = (volatile tt_l1_ptr CQDispatchCmd *)cmd_ptr;

@@ -462,15 +448,19 @@ void process_write_paged() {
     InterleavedAddrGen addr_gen;
     addr_gen.bank_base_address = base_addr;
     addr_gen.page_size = page_size;
-    uint64_t dst_addr_offset = 0; // Offset into page.
+    uint64_t dst_addr_offset = 0;  // Offset into page.

-    DPRINT << "process_write_paged - pages: " << pages << " page_size: " << page_size << " dispatch_cb_page_size: " << dispatch_cb_page_size;
+    DPRINT << "process_write_paged - pages: " << pages << " page_size: " << page_size
+           << " dispatch_cb_page_size: " << dispatch_cb_page_size;
     DPRINT << " start_page: " << page_id << " base_addr: " << HEX() << base_addr << DEC() << ENDL();

     while (write_length != 0) {
-        // TODO #7360: Have more performant handling when page_size > dispatch_cb_page_size by not doing multiple writes for one buffer page
-        uint32_t xfer_size = page_size > dispatch_cb_page_size ? min(dispatch_cb_page_size, page_size - dst_addr_offset) : page_size;
-        uint64_t dst = addr_gen.get_noc_addr(page_id, dst_addr_offset); // XXXX replace this w/ walking the banks to save mul on GS
+        // TODO #7360: Have more performant handling when page_size > dispatch_cb_page_size by not doing multiple writes
+        // for one buffer page
+        uint32_t xfer_size =
+            page_size > dispatch_cb_page_size ?
min(dispatch_cb_page_size, page_size - dst_addr_offset) : page_size; + uint64_t dst = addr_gen.get_noc_addr( + page_id, dst_addr_offset); // XXXX replace this w/ walking the banks to save mul on GS // Get a Dispatch page if needed if (data_ptr + xfer_size > cb_fence) { @@ -490,31 +480,28 @@ void process_write_paged() { data_ptr = dispatch_cb_base; dst = addr_gen.get_noc_addr(page_id, dst_addr_offset); } - move_rd_to_next_block(block_noc_writes_to_clear, - rd_block_idx); + move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx); } // Wait for dispatcher to supply a page (this won't go beyond the buffer end) - uint32_t n_pages = cb_acquire_pages(cb_fence, - block_next_start_addr, - rd_block_idx); + uint32_t n_pages = cb_acquire_pages( + cb_fence, block_next_start_addr, rd_block_idx); cb_fence += n_pages * dispatch_cb_page_size; // Release pages for prefetcher // Since we gate how much we acquire to < 1/4 the buffer, this should be called enough - cb_block_release_pages(block_noc_writes_to_clear, - wr_block_idx); + cb_block_release_pages< + upstream_noc_xy, + upstream_dispatch_cb_sem_id, + dispatch_cb_blocks, + dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx); } noc_async_write(data_ptr, dst, xfer_size); - block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter + block_noc_writes_to_clear[rd_block_idx]++; // XXXXX maybe just write the noc internal api counter - // If paged write is not completed for a page (dispatch_cb_page_size < page_size) then add offset, otherwise incr page_id. + // If paged write is not completed for a page (dispatch_cb_page_size < page_size) then add offset, otherwise + // incr page_id. if (dst_addr_offset + xfer_size < page_size) { dst_addr_offset += xfer_size; } else { @@ -542,7 +529,7 @@ void process_write_paged() { // // Since all subcmds all appear in the first page and given the size restrictions // this command can't be too many pages. All pages are released at the end -template +template void process_write_packed(uint32_t flags) { volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; @@ -550,8 +537,8 @@ void process_write_packed(uint32_t flags) { ASSERT(count <= (mcast ? max_write_packed_cores / 2 : max_write_packed_cores)); constexpr uint32_t sub_cmd_size = sizeof(WritePackedSubCmd); // Copying in a burst is about a 30% net gain vs reading one value per loop below - careful_copy_from_l1_to_local_cache((volatile uint32_t tt_l1_ptr*)(cmd_ptr + sizeof(CQDispatchCmd)), - count * sub_cmd_size / sizeof(uint32_t)); + careful_copy_from_l1_to_local_cache( + (volatile uint32_t tt_l1_ptr *)(cmd_ptr + sizeof(CQDispatchCmd)), count * sub_cmd_size / sizeof(uint32_t)); uint32_t xfer_size = cmd->write_packed.size; uint32_t dst_addr = cmd->write_packed.addr; @@ -560,7 +547,8 @@ void process_write_packed(uint32_t flags) { uint32_t data_ptr = cmd_ptr + sizeof(CQDispatchCmd) + count * sizeof(WritePackedSubCmd); data_ptr = round_up_pow2(data_ptr, L1_NOC_ALIGNMENT); - uint32_t stride = (flags & CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE) ? 0 : round_up_pow2(xfer_size, L1_NOC_ALIGNMENT); + uint32_t stride = + (flags & CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_NO_STRIDE) ? 
0 : round_up_pow2(xfer_size, L1_NOC_ALIGNMENT); DPRINT << data_ptr << " " << cmd_ptr << " " << xfer_size << " " << dispatch_cb_page_size << ENDL(); ASSERT(stride != 0 || data_ptr - cmd_ptr + xfer_size <= dispatch_cb_page_size); @@ -573,9 +561,7 @@ void process_write_packed(uint32_t flags) { WritePackedSubCmd *sub_cmd_ptr = (WritePackedSubCmd *)l1_cache; while (count != 0) { uint32_t dst_noc = sub_cmd_ptr->noc_xy_addr; - uint32_t num_dests = mcast ? - ((CQDispatchWritePackedMulticastSubCmd *)sub_cmd_ptr)->num_mcast_dests : - 1; + uint32_t num_dests = mcast ? ((CQDispatchWritePackedMulticastSubCmd *)sub_cmd_ptr)->num_mcast_dests : 1; sub_cmd_ptr++; uint64_t dst = get_noc_addr_helper(dst_noc, dst_addr); // Get a page if needed @@ -601,16 +587,12 @@ void process_write_packed(uint32_t flags) { noc_nonposted_writes_acked[noc_index] += mcasts; writes = 0; mcasts = 0; - move_rd_to_next_block(block_noc_writes_to_clear, - rd_block_idx); + move_rd_to_next_block(block_noc_writes_to_clear, rd_block_idx); } // Wait for dispatcher to supply a page (this won't go beyond the buffer end) - uint32_t n_pages = cb_acquire_pages(cb_fence, - block_next_start_addr, - rd_block_idx); + uint32_t n_pages = cb_acquire_pages( + cb_fence, block_next_start_addr, rd_block_idx); cb_fence += n_pages * dispatch_cb_page_size; // This is done here so the common case doesn't have to restore the pointers @@ -644,17 +626,16 @@ void process_write_packed(uint32_t flags) { noc_nonposted_writes_acked[noc_index] += mcasts; // Release pages for prefetcher // write_packed releases pages at the end so the first page (w/ the sub_cmds) remains valid - cb_block_release_pages(block_noc_writes_to_clear, - wr_block_idx); + cb_block_release_pages< + upstream_noc_xy, + upstream_dispatch_cb_sem_id, + dispatch_cb_blocks, + dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx); cmd_ptr = data_ptr; } static uint32_t process_debug_cmd(uint32_t cmd_ptr) { - volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; uint32_t checksum = 0; uint32_t *data = (uint32_t *)((uint32_t)cmd + (uint32_t)sizeof(CQDispatchCmd)); @@ -682,25 +663,27 @@ static void process_wait() { uint32_t barrier = cmd->wait.barrier; uint32_t notify_prefetch = cmd->wait.notify_prefetch; + uint32_t clear_count = cmd->wait.clear_count; + uint32_t wait = cmd->wait.wait; uint32_t addr = cmd->wait.addr; uint32_t count = cmd->wait.count; - uint32_t clear_count = cmd->wait.clear_count; if (barrier) { noc_async_write_barrier(); } DEBUG_STATUS("PWW"); - volatile tt_l1_ptr uint32_t* sem_addr = - reinterpret_cast(addr); + volatile tt_l1_ptr uint32_t *sem_addr = reinterpret_cast(addr); DPRINT << " DISPATCH WAIT " << HEX() << addr << DEC() << " count " << count << ENDL(); #if defined(COMPILE_FOR_IDLE_ERISC) uint32_t heartbeat = 0; #endif - while (!wrap_ge(*sem_addr, count)) { + if (wait) { + while (!wrap_ge(*sem_addr, count)) { #if defined(COMPILE_FOR_IDLE_ERISC) - RISC_POST_HEARTBEAT(heartbeat); + RISC_POST_HEARTBEAT(heartbeat); #endif + } } DEBUG_STATUS("PWD"); @@ -718,57 +701,54 @@ static void process_wait() { } static void process_delay_cmd() { - volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; uint32_t count = cmd->delay.delay; for (volatile uint32_t i = 0; i < count; i++); cmd_ptr += sizeof(CQDispatchCmd); } -static inline bool process_cmd_d(uint32_t& cmd_ptr) { - +static inline bool process_cmd_d(uint32_t &cmd_ptr) { bool done = false; - re_run_command: +re_run_command: volatile CQDispatchCmd tt_l1_ptr *cmd = 
(volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; switch (cmd->base.cmd_id) { - case CQ_DISPATCH_CMD_WRITE_LINEAR: - DEBUG_STATUS("DWB"); - DPRINT << "cmd_write\n"; - process_write(); - DEBUG_STATUS("DWD"); - break; - - case CQ_DISPATCH_CMD_WRITE_LINEAR_H: - DPRINT << "cmd_write_linear_h\n"; - if (is_h_variant) { + case CQ_DISPATCH_CMD_WRITE_LINEAR: + DEBUG_STATUS("DWB"); + DPRINT << "cmd_write\n"; process_write(); - } else { - relay_write_h(); - } - break; + DEBUG_STATUS("DWD"); + break; - case CQ_DISPATCH_CMD_WRITE_LINEAR_H_HOST: - DPRINT << "cmd_write_linear_h_host\n"; - if (is_h_variant) { - process_write_host_h(); - } else { - process_write_host_d(); - } - break; + case CQ_DISPATCH_CMD_WRITE_LINEAR_H: + DPRINT << "cmd_write_linear_h\n"; + if (is_h_variant) { + process_write(); + } else { + relay_write_h(); + } + break; - case CQ_DISPATCH_CMD_WRITE_PAGED: - DPRINT << "cmd_write_paged is_dram: " << (uint32_t) cmd->write_paged.is_dram << ENDL(); - if (cmd->write_paged.is_dram) { - process_write_paged(); - } else { - process_write_paged(); - } - break; + case CQ_DISPATCH_CMD_WRITE_LINEAR_H_HOST: + DPRINT << "cmd_write_linear_h_host\n"; + if (is_h_variant) { + process_write_host_h(); + } else { + process_write_host_d(); + } + break; - case CQ_DISPATCH_CMD_WRITE_PACKED: - { + case CQ_DISPATCH_CMD_WRITE_PAGED: + DPRINT << "cmd_write_paged is_dram: " << (uint32_t)cmd->write_paged.is_dram << ENDL(); + if (cmd->write_paged.is_dram) { + process_write_paged(); + } else { + process_write_paged(); + } + break; + + case CQ_DISPATCH_CMD_WRITE_PACKED: { DPRINT << "cmd_write_packed" << ENDL(); uint32_t flags = cmd->write_packed.flags; if (flags & CQ_DISPATCH_CMD_PACKED_WRITE_FLAG_MCAST) { @@ -776,93 +756,90 @@ static inline bool process_cmd_d(uint32_t& cmd_ptr) { } else { process_write_packed(flags); } - } - break; - - case CQ_DISPATCH_CMD_WAIT: - DPRINT << "cmd_wait" << ENDL(); - process_wait(); - break; - - case CQ_DISPATCH_CMD_GO: - DPRINT << "cmd_go" << ENDL(); - break; - - case CQ_DISPATCH_CMD_SINK: - DPRINT << "cmd_sink" << ENDL(); - break; - - case CQ_DISPATCH_CMD_DEBUG: - DPRINT << "cmd_debug" << ENDL(); - cmd_ptr = process_debug_cmd(cmd_ptr); - goto re_run_command; - break; - - case CQ_DISPATCH_CMD_DELAY: - DPRINT << "cmd_delay" << ENDL(); - process_delay_cmd(); - break; - - case CQ_DISPATCH_CMD_TERMINATE: - DPRINT << "dispatch terminate\n"; - if (is_d_variant && !is_h_variant) { - relay_to_next_cb(cmd_ptr, sizeof(CQDispatchCmd)); - } - cmd_ptr += sizeof(CQDispatchCmd); - done = true; - break; - - default: - DPRINT << "dispatcher_d invalid command:" << cmd_ptr << " " << cb_fence << " " << dispatch_cb_base << " " << dispatch_cb_end << " " << rd_block_idx << " " << "xx" << ENDL(); - DPRINT << HEX() << *(uint32_t*)cmd_ptr << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+1) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+2) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+3) << ENDL(); - DEBUG_STATUS("!CMD"); - ASSERT(0); + } break; + + case CQ_DISPATCH_CMD_WAIT: + DPRINT << "cmd_wait" << ENDL(); + process_wait(); + break; + + case CQ_DISPATCH_CMD_GO: DPRINT << "cmd_go" << ENDL(); break; + + case CQ_DISPATCH_CMD_SINK: DPRINT << "cmd_sink" << ENDL(); break; + + case CQ_DISPATCH_CMD_DEBUG: + DPRINT << "cmd_debug" << ENDL(); + cmd_ptr = process_debug_cmd(cmd_ptr); + goto re_run_command; + break; + + case CQ_DISPATCH_CMD_DELAY: + DPRINT << "cmd_delay" << ENDL(); + process_delay_cmd(); + break; + + case CQ_DISPATCH_CMD_TERMINATE: + DPRINT << "dispatch terminate\n"; + if (is_d_variant && 
!is_h_variant) { + relay_to_next_cb(cmd_ptr, sizeof(CQDispatchCmd)); + } + cmd_ptr += sizeof(CQDispatchCmd); + done = true; + break; + + default: + DPRINT << "dispatcher_d invalid command:" << cmd_ptr << " " << cb_fence << " " << dispatch_cb_base << " " + << dispatch_cb_end << " " << rd_block_idx << " " + << "xx" << ENDL(); + DPRINT << HEX() << *(uint32_t *)cmd_ptr << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 1) << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 2) << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 3) << ENDL(); + DEBUG_STATUS("!CMD"); + ASSERT(0); } return done; } -static inline bool process_cmd_h(uint32_t& cmd_ptr) { - +static inline bool process_cmd_h(uint32_t &cmd_ptr) { bool done = false; volatile CQDispatchCmd tt_l1_ptr *cmd = (volatile CQDispatchCmd tt_l1_ptr *)cmd_ptr; switch (cmd->base.cmd_id) { - case CQ_DISPATCH_CMD_WRITE_LINEAR_H: - DPRINT << "dispatch_h write_linear_h\n"; - process_write(); - break; - - case CQ_DISPATCH_CMD_WRITE_LINEAR_H_HOST: - DPRINT << "dispatch_h linear_h_host\n"; - process_write_host_h(); - break; - - case CQ_DISPATCH_CMD_TERMINATE: - DPRINT << "dispatch_h terminate\n"; - cmd_ptr += sizeof(CQDispatchCmd); - done = true; - break; - - default: - DPRINT << "dispatcher_h invalid command:" << cmd_ptr << " " << cb_fence << " " << " " << dispatch_cb_base << " " << dispatch_cb_end << " " << rd_block_idx << " " << "xx" << ENDL(); - DPRINT << HEX() << *(uint32_t*)cmd_ptr << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+1) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+2) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+3) << ENDL(); - DEBUG_STATUS("!CMD"); - ASSERT(0); + case CQ_DISPATCH_CMD_WRITE_LINEAR_H: + DPRINT << "dispatch_h write_linear_h\n"; + process_write(); + break; + + case CQ_DISPATCH_CMD_WRITE_LINEAR_H_HOST: + DPRINT << "dispatch_h linear_h_host\n"; + process_write_host_h(); + break; + + case CQ_DISPATCH_CMD_TERMINATE: + DPRINT << "dispatch_h terminate\n"; + cmd_ptr += sizeof(CQDispatchCmd); + done = true; + break; + + default: + DPRINT << "dispatcher_h invalid command:" << cmd_ptr << " " << cb_fence << " " + << " " << dispatch_cb_base << " " << dispatch_cb_end << " " << rd_block_idx << " " + << "xx" << ENDL(); + DPRINT << HEX() << *(uint32_t *)cmd_ptr << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 1) << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 2) << ENDL(); + DPRINT << HEX() << *((uint32_t *)cmd_ptr + 3) << ENDL(); + DEBUG_STATUS("!CMD"); + ASSERT(0); } return done; } void kernel_main() { - DPRINT << "dispatch_" << is_h_variant << is_d_variant << ": start" << ENDL(); static_assert(is_d_variant || split_dispatch_page_preamble_size == 0); @@ -892,27 +869,22 @@ void kernel_main() { dispatch_cb_blocks, dispatch_cb_log_page_size, my_noc_xy, - my_dispatch_cb_sem_id>(cmd_ptr, - cb_fence, - block_noc_writes_to_clear, - block_next_start_addr, - rd_block_idx); + my_dispatch_cb_sem_id>( + cmd_ptr, cb_fence, block_noc_writes_to_clear, block_next_start_addr, rd_block_idx); } - done = is_d_variant ? - process_cmd_d(cmd_ptr) : - process_cmd_h(cmd_ptr); + done = is_d_variant ? 
process_cmd_d(cmd_ptr) : process_cmd_h(cmd_ptr); // Move to next page cmd_ptr = round_up_pow2(cmd_ptr, dispatch_cb_page_size); // XXXXX move this inside while loop waiting for get_dispatch_cb_page above // XXXXX can potentially clear a partial block when stalled w/ some more bookkeeping - cb_block_release_pages(block_noc_writes_to_clear, - wr_block_idx); + cb_block_release_pages< + upstream_noc_xy, + upstream_dispatch_cb_sem_id, + dispatch_cb_blocks, + dispatch_cb_pages_per_block>(block_noc_writes_to_clear, wr_block_idx); } noc_async_write_barrier(); @@ -935,7 +907,8 @@ void kernel_main() { // We're 1 block behind cb_release_pages(dispatch_cb_pages_per_block); } - uint32_t npages = dispatch_cb_pages_per_block - ((block_next_start_addr[rd_block_idx] - cmd_ptr) >> dispatch_cb_log_page_size); + uint32_t npages = + dispatch_cb_pages_per_block - ((block_next_start_addr[rd_block_idx] - cmd_ptr) >> dispatch_cb_log_page_size); cb_release_pages(npages); // Confirm expected number of pages, spinning here is a leak diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp index f990132a60c8..6c6a6c5d8d6d 100644 --- a/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp +++ b/tt_metal/impl/dispatch/kernels/cq_prefetch.cpp @@ -52,6 +52,7 @@ constexpr uint32_t is_h_variant = get_compile_time_arg_val(22); constexpr uint32_t my_noc_xy = uint32_t(NOC_XY_ENCODING(MY_NOC_X, MY_NOC_Y)); constexpr uint32_t upstream_noc_xy = uint32_t(NOC_XY_ENCODING(UPSTREAM_NOC_X, UPSTREAM_NOC_Y)); constexpr uint32_t downstream_noc_xy = uint32_t(NOC_XY_ENCODING(DOWNSTREAM_NOC_X, DOWNSTREAM_NOC_Y)); +constexpr uint32_t pcie_noc_xy = uint32_t(NOC_XY_PCIE_ENCODING(NOC_0_X(static_cast(NOC_INDEX), noc_size_x, PCIE_NOC_X), NOC_0_Y(static_cast(NOC_INDEX), noc_size_y, PCIE_NOC_Y), NOC_INDEX)); constexpr uint32_t downstream_cb_page_size = 1 << downstream_cb_log_page_size; constexpr uint32_t downstream_cb_end = downstream_cb_base + (1 << downstream_cb_log_page_size) * downstream_cb_pages; constexpr uint32_t prefetch_q_end = prefetch_q_base + prefetch_q_size; @@ -146,7 +147,7 @@ void read_from_pcie(volatile tt_l1_ptr prefetch_q_entry_type *& prefetch_q_rd_pt pcie_read_ptr = pcie_base; } - uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y), pcie_read_ptr); + uint64_t host_src_addr = get_noc_addr_helper(pcie_noc_xy, pcie_read_ptr); DPRINT << "read_from_pcie: " << fence + preamble_size << " " << pcie_read_ptr << ENDL(); noc_async_read(host_src_addr, fence + preamble_size, size); pending_read_size = size + preamble_size; @@ -341,16 +342,21 @@ static uint32_t process_relay_inline_noflush_cmd(uint32_t cmd_ptr, return CQ_PREFETCH_CMD_BARE_MIN_SIZE; } -template static uint32_t write_pages_to_dispatcher(uint32_t& downstream_data_ptr, uint32_t& scratch_write_addr, uint32_t& amt_to_write) { uint32_t page_residual_space = downstream_cb_page_size - (downstream_data_ptr & (downstream_cb_page_size - 1)); - uint32_t npages = (amt_to_write - page_residual_space + downstream_cb_page_size + extra_space - 1) / downstream_cb_page_size; + uint32_t npages = (amt_to_write - page_residual_space + downstream_cb_page_size - round) / downstream_cb_page_size; // Grabbing all pages at once is ok if scratch_size < 3 * downstream_cb_block_size + // test_for_nonzero is an optimization: inner loops moving lots of pages don't bother if (!test_for_nonzero || npages != 0) { cb_acquire_pages(npages); } @@ -464,7 +470,7 @@ uint32_t process_relay_paged_cmd_large(uint32_t cmd_ptr, uint32_t 
amt_to_write = write_length; ASSERT((amt_to_write & 0x1f) == 0); - uint32_t npages = write_pages_to_dispatcher + uint32_t npages = write_pages_to_dispatcher<1, true> (downstream_data_ptr, scratch_write_addr, amt_to_write); // One page was acquired w/ the cmd in CMD_RELAY_INLINE_NOFLUSH with 16 bytes written @@ -577,7 +583,7 @@ uint32_t process_relay_paged_cmd(uint32_t cmd_ptr, scratch_write_addr = scratch_db_top[db_toggle]; uint32_t amt_to_write = amt_read - cmd->relay_paged.length_adjust; ASSERT((amt_to_write & 0x1f) == 0); - uint32_t npages = write_pages_to_dispatcher + uint32_t npages = write_pages_to_dispatcher<1, true> (downstream_data_ptr, scratch_write_addr, amt_to_write); downstream_data_ptr = round_up_pow2(downstream_data_ptr, downstream_cb_page_size); @@ -643,7 +649,7 @@ uint32_t process_relay_linear_cmd(uint32_t cmd_ptr, // Third step - write from DB scratch_write_addr = scratch_db_top[db_toggle]; uint32_t amt_to_write = amt_to_read; - uint32_t npages = write_pages_to_dispatcher + uint32_t npages = write_pages_to_dispatcher<1, true> (downstream_data_ptr, scratch_write_addr, amt_to_write); downstream_data_ptr = round_up_pow2(downstream_data_ptr, downstream_cb_page_size); diff --git a/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp b/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp deleted file mode 100644 index f77c26d9f330..000000000000 --- a/tt_metal/impl/dispatch/kernels/cq_prefetch.hpp +++ /dev/null @@ -1,674 +0,0 @@ -// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -// Common prefetch code for use by _hd, _h, _d prefetch variants - -#include "dataflow_api.h" -#include "debug/dprint.h" -#include "tt_metal/impl/dispatch/kernels/cq_common.hpp" - -extern const uint32_t scratch_db_top[2]; - - -template -FORCE_INLINE -void write_downstream(uint32_t& data_ptr, - uint32_t& downstream_data_ptr, - uint32_t length) { - - uint32_t remaining = cb_end - downstream_data_ptr; - if (length > remaining) { - if (remaining > 0) { - noc_async_write(data_ptr, get_noc_addr_helper(downstream_noc_xy, downstream_data_ptr), remaining); - data_ptr += remaining; - length -= remaining; - } - downstream_data_ptr = cb_base; - } - - noc_async_write(data_ptr, get_noc_addr_helper(downstream_noc_xy, downstream_data_ptr), length); - downstream_data_ptr += length; -} - -template -FORCE_INLINE -void read_from_pcie(volatile tt_l1_ptr uint16_t *& prefetch_q_rd_ptr, - uint32_t& pending_read_size, - uint32_t& fence, - uint32_t& pcie_read_ptr, - uint32_t cmd_ptr, - uint32_t size) { - - // Wrap cmddat_q - if (fence + size + preamble_size > cmddat_q_base + cmddat_q_size) { - // only wrap if there are no commands ready, otherwise we'll leave some on the floor - // TODO: does this matter for perf? 
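The `write_pages_to_dispatcher` change above collapses the old `extra_space ... - 1` term into a single `round` template parameter; with `round = 1`, as at the `<1, true>` call sites that follow, the expression is a ceiling division over the bytes that spill past the space left in the current downstream page. A small self-checking model of that count, assuming 4 KiB pages:

```cpp
// Model of the page-acquire count above (names are stand-ins). With
// round = 1 this is a ceiling division over the bytes that do not fit in
// the residual space of the current downstream page.
#include <cassert>
#include <cstdint>

static uint32_t npages_to_acquire(uint32_t amt_to_write, uint32_t data_ptr,
                                  uint32_t page_size, uint32_t round = 1) {
    uint32_t residual = page_size - (data_ptr & (page_size - 1));
    return (amt_to_write - residual + page_size - round) / page_size;
}

int main() {
    // 4 KiB pages; the write pointer sits 1 KiB into a page, leaving 3 KiB.
    assert(npages_to_acquire(3 * 1024, 1024, 4096) == 0);  // fits in residual
    assert(npages_to_acquire(4 * 1024, 1024, 4096) == 1);  // spills 1 KiB
    assert(npages_to_acquire(8 * 1024, 1024, 4096) == 2);  // spills 5 KiB
    return 0;
}
```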
- if (cmd_ptr != fence) { - return; - } - fence = cmddat_q_base; - } - - // Wrap pcie/hugepage - if (pcie_read_ptr + size > pcie_base + pcie_size) { - pcie_read_ptr = pcie_base; - } - - uint64_t host_src_addr = get_noc_addr_helper(NOC_XY_ENCODING(PCIE_NOC_X, PCIE_NOC_Y), pcie_read_ptr); - noc_async_read(host_src_addr, fence + preamble_size, size); - pending_read_size = size + preamble_size; - pcie_read_ptr += size; - - *prefetch_q_rd_ptr = 0; - - // Tell host we read - *(volatile tt_l1_ptr uint32_t *) prefetch_q_rd_ptr_addr = (uint32_t)prefetch_q_rd_ptr; - - prefetch_q_rd_ptr++; - - // Wrap prefetch_q - if ((uint32_t)prefetch_q_rd_ptr == prefetch_q_end) { - prefetch_q_rd_ptr = (volatile tt_l1_ptr uint16_t*)prefetch_q_base; - } -} - -// This routine can be called in 8 states based on the boolean values cmd_ready, prefetch_q_ready, read_pending: -// - !cmd_ready, !prefetch_q_ready, !read_pending: stall on prefetch_q, issue read, read barrier -// - !cmd_ready, !prefetch_q_ready, read pending: read barrier (and re-evaluate prefetch_q_ready) -// - !cmd_ready, prefetch_q_ready, !read_pending: issue read, read barrier (XXXX +issue read after?) -// - !cmd_ready, prefetch_q_ready, read_pending: read barrier, issue read -// - cmd_ready, !prefetch_q_ready, !read_pending: exit -// - cmd_ready, !prefetch_q_ready, read_pending: exit (no barrier yet) -// - cmd_ready, prefetch_q_ready, !read_pending: issue read -// - cmd_ready, prefetch_q_ready, read_pending: exit (don't add latency to the in flight request) -// -// With WH tagging of reads: -// open question: should fetcher loop on prefetch_q_ready issuing reads until !prefetch_q_ready -// - !cmd_ready, !prefetch_q_ready, !read_pending: stall on prefetch_q, issue read, read barrier -// - !cmd_ready, !prefetch_q_ready, read pending: read barrier on oldest tag -// - !cmd_ready, prefetch_q_ready, !read_pending: issue read, read barrier (XXXX +retry after?) 
-// - !cmd_ready, prefetch_q_ready, read_pending: issue read, read barrier on oldest tag -// - cmd_ready, !prefetch_q_ready, !read_pending: exit -// - cmd_ready, !prefetch_q_ready, read_pending: exit (no barrier yet) -// - cmd_ready, prefetch_q_ready, !read_pending: issue and tag read -// - cmd_ready, prefetch_q_ready, read_pending: issue and tag read -template -void fetch_q_get_cmds(uint32_t& fence, uint32_t& cmd_ptr, uint32_t& pcie_read_ptr) { - - static uint32_t pending_read_size = 0; - static volatile tt_l1_ptr uint16_t* prefetch_q_rd_ptr = (volatile tt_l1_ptr uint16_t*)prefetch_q_base; - - if (fence < cmd_ptr) { - DPRINT << "wrap cmd ptr1 " << fence << " " << cmd_ptr << ENDL(); - cmd_ptr = fence; - } - - bool cmd_ready = (cmd_ptr != fence); - uint32_t fetch_size = (uint32_t)*prefetch_q_rd_ptr << prefetch_q_log_minsize; - - if (fetch_size != 0 && pending_read_size == 0) { - DPRINT << "read1: " << (uint32_t)prefetch_q_rd_ptr << " " << " " << fence << " " << fetch_size << ENDL(); - read_from_pcie - (prefetch_q_rd_ptr, pending_read_size, fence, pcie_read_ptr, cmd_ptr, fetch_size); - } - if (!cmd_ready) { - if (pending_read_size != 0) { - DPRINT << "barrier" << ENDL(); - noc_async_read_barrier(); - - // wrap the cmddat_q - if (fence < cmd_ptr) { - cmd_ptr = fence; - } - - fence += pending_read_size; - pending_read_size = 0; - // After the stall, re-check the host - fetch_size = (uint32_t)*prefetch_q_rd_ptr << prefetch_q_log_minsize; - if (fetch_size != 0) { - read_from_pcie - (prefetch_q_rd_ptr, pending_read_size, fence, pcie_read_ptr, cmd_ptr, fetch_size); - } - } else { - // By here, prefetch_q_ready must be false - // Nothing to fetch, nothing pending, nothing available, stall on host - DEBUG_STATUS("HQW"); - DPRINT << "prefetcher stall" << ENDL(); - while ((fetch_size = *prefetch_q_rd_ptr) == 0); - DPRINT << "recurse" << ENDL(); - fetch_q_get_cmds(fence, cmd_ptr, pcie_read_ptr); - DEBUG_STATUS("HQD"); - } - } -} - -template -uint32_t process_debug_cmd(uint32_t cmd_ptr) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - uint32_t checksum = 0; - uint32_t data_start = (uint32_t)cmd + sizeof(CQPrefetchCmd); - uint32_t *data = (uint32_t *)data_start; - uint32_t size = cmd->debug.size; - - uint32_t front_size = (size <= cmddat_end - data_start) ? 
size : cmddat_end - data_start; - for (uint32_t i = 0; i < front_size / sizeof(uint32_t); i++) { - checksum += *data++; - } - uint32_t back_size = size - front_size; - if (back_size > 0) { - data = (uint32_t *)cmddat_base; - for (uint32_t i = 0; i < back_size / sizeof(uint32_t); i++) { - checksum += *data++; - } - } - - if (checksum != cmd->debug.checksum) { - DEBUG_STATUS("!CHK"); - ASSERT(0); - } - - return cmd->debug.stride; -} - -template -static uint32_t process_relay_inline_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - - uint32_t length = cmd->relay_inline.length; - uint32_t data_ptr = cmd_ptr + sizeof(CQPrefetchCmd); - - uint32_t npages = (length + cb_page_size - 1) >> cb_log_page_size; - - // Assume the downstream buffer is big relative to cmddat command size that we can - // grab what we need in one chunk - cb_acquire_pages(npages); - - uint32_t remaining = cmddat_end - data_ptr; - if (cmddat_wrap_enable && length > remaining) { - // wrap cmddat - write_downstream(data_ptr, dispatch_data_ptr, remaining); - length -= remaining; - data_ptr = cmddat_base; - } - - DPRINT << my_noc_xy << " " << dispatch_noc_xy << " " << cb_base << ENDL(); - write_downstream(data_ptr, dispatch_data_ptr, length); - - // Round to nearest page - dispatch_data_ptr += (cb_page_size - (dispatch_data_ptr & (cb_page_size - 1))) & (cb_page_size - 1); - - // XXXXX - painful syncing right now? move this into get_cmds - noc_async_writes_flushed(); - cb_release_pages(npages); - - return cmd->relay_inline.stride; -} - -// This version of inline sends inline data to the dispatcher but doesn't flush the page to the dispatcher -// This is used to assemble dispatcher commands when data comes out of band, eg, reading from DRAM -// That means this command is stateful, incorrect use will be...bad -// NOTE: this routine assumes we're sending a command header and that is LESS THAN A PAGE -template -static uint32_t process_relay_inline_noflush_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - - uint32_t length = sizeof(CQDispatchCmd); - uint32_t data_ptr = cmd_ptr + sizeof(CQPrefetchCmd); - - cb_acquire_pages(1); - if (dispatch_data_ptr == cb_end) { - dispatch_data_ptr = cb_base; - } - noc_async_write(data_ptr, get_noc_addr_helper(dispatch_noc_xy, dispatch_data_ptr), length); - dispatch_data_ptr += length; - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -static uint32_t write_pages_to_dispatcher(uint32_t& dispatch_data_ptr, - uint32_t& scratch_write_addr, - uint32_t& amt_to_write) { - - uint32_t page_residual_space = dispatch_cb_page_size - (dispatch_data_ptr & (dispatch_cb_page_size - 1)); - uint32_t npages = (amt_to_write - page_residual_space + dispatch_cb_page_size + extra_space - 1) / dispatch_cb_page_size; - - // Grabbing all pages at once is ok if scratch_size < 3 * dispatch_cb_block_size - if (!test_for_nonzero || npages != 0) { - cb_acquire_pages(npages); - } - - uint64_t noc_addr = get_noc_addr_helper(dispatch_noc_xy, dispatch_data_ptr); - if (dispatch_data_ptr == dispatch_cb_end) { - dispatch_data_ptr = dispatch_cb_base; - } else if (dispatch_data_ptr + amt_to_write > dispatch_cb_end) { // wrap - uint32_t last_chunk_size = dispatch_cb_end - dispatch_data_ptr; - noc_async_write(scratch_write_addr, noc_addr, last_chunk_size); - dispatch_data_ptr = dispatch_cb_base; - scratch_write_addr += last_chunk_size; - 
amt_to_write -= last_chunk_size; - noc_addr = get_noc_addr_helper(dispatch_noc_xy, dispatch_data_ptr); - } - - noc_async_write(scratch_write_addr, noc_addr, amt_to_write); - dispatch_data_ptr += amt_to_write; - - return npages; -} - -// This fn prefetches data from DRAM memory and writes data to the dispatch core. -// Reading from DRAM has the following characteristics: -// - latency is moderately high ~400 cycles on WH -// - DRAM bw is ~maximized when page size reaches 2K -// - for kernel dispatch, it is expected that page sizes will often be <2K -// - for buffer writing, page sizes will vary -// - writing to dispatcher works best with 4K pages (2K pages cover overhead, 4K gives perf cushion) -// - writing a 4K page takes ~32*4=128 cycles -// - writing 4 4K pages is 512 cycles, close to parity w/ the latency of DRAM -// - to hide the latency (~12% overhead), assume we need to read ~32 pages=128K, double buffered -// - in other words, we'll never achieve high efficiency and always be (somewhat) latency bound -// Algorithm does: -// - read a batch from DRAM -// - loop: read a batch from DRAM while sending to dispatcher -// - send a batch to dispatcher -// The size of the first read should be based on latency. With small page sizes -// bandwidth will be low and we'll be DRAM bound (send to dispatcher is ~free). -// With larger pages we'll get closer to a bandwidth match -// The dispatch buffer is a ring buffer. -template -uint32_t process_relay_paged_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - // This ensures that a previous cmd using the scratch buf has finished - noc_async_writes_flushed(); - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - uint32_t page_id = cmd->relay_paged.start_page; - uint32_t base_addr = cmd->relay_paged.base_addr; - uint32_t page_size = cmd->relay_paged.page_size; - uint32_t pages = cmd->relay_paged.pages; - uint32_t read_length = pages * page_size; - - InterleavedAddrGen addr_gen; - addr_gen.bank_base_address = base_addr; - addr_gen.page_size = page_size; - - // First step - read into DB0 - uint32_t scratch_read_addr = scratch_db_top[0]; - uint32_t amt_to_read = (scratch_db_half_size > read_length) ? read_length : scratch_db_half_size; - uint32_t amt_read = 0; - while (amt_to_read >= page_size) { - uint64_t noc_addr = addr_gen.get_noc_addr(page_id); // XXXX replace this w/ walking the banks to save mul on GS - noc_async_read(noc_addr, scratch_read_addr, page_size); - scratch_read_addr += page_size; - page_id++; - amt_to_read -= page_size; - amt_read += page_size; - } - noc_async_read_barrier(); - - // Second step - read into DB[x], write from DB[x], toggle x, iterate - // Writes are fast, reads are slow - uint32_t db_toggle = 0; - uint32_t scratch_write_addr; - read_length -= amt_read; - while (read_length != 0) { - // This ensures that writes from prior iteration are done - // TODO(pgk); we can do better on WH w/ tagging - noc_async_writes_flushed(); - - db_toggle ^= 1; - scratch_read_addr = scratch_db_top[db_toggle]; - scratch_write_addr = scratch_db_top[db_toggle ^ 1]; - - uint32_t amt_to_write = amt_read; - amt_to_read = (scratch_db_half_size > read_length) ? 
read_length : scratch_db_half_size; - amt_read = 0; - while (amt_to_read >= page_size) { - uint64_t noc_addr = addr_gen.get_noc_addr(page_id); // XXXX replace this w/ walking the banks to save mul on GS - noc_async_read(noc_addr, scratch_read_addr, page_size); - scratch_read_addr += page_size; - page_id++; - amt_to_read -= page_size; - amt_read += page_size; - } - - // Third step - write from DB - uint32_t npages = write_pages_to_dispatcher< - 0, - false, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - cb_release_pages(npages); - - read_length -= amt_read; - - // TODO(pgk); we can do better on WH w/ tagging - noc_async_read_barrier(); - } - - // Third step - write from DB - scratch_write_addr = scratch_db_top[db_toggle]; - uint32_t amt_to_write = amt_read; - uint32_t npages = write_pages_to_dispatcher< - CQ_DISPATCH_CMD_SIZE, - true, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - - uint32_t pad_to_page = dispatch_cb_page_size - (dispatch_data_ptr & (dispatch_cb_page_size - 1)); - dispatch_data_ptr += pad_to_page; - - // One page was acquired w/ the cmd in CMD_RELAY_INLINE_NOFLUSH - cb_release_pages(npages + 1); - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -uint32_t process_relay_linear_cmd(uint32_t cmd_ptr, - uint32_t& dispatch_data_ptr) { - - // This ensures that a previous cmd using the scratch buf has finished - noc_async_writes_flushed(); - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - uint32_t noc_xy_addr = cmd->relay_linear.noc_xy_addr; - uint32_t read_addr = cmd->relay_linear.addr; - uint32_t length = cmd->relay_linear.length; - uint32_t read_length = length; - - // First step - read into DB0 - uint32_t scratch_read_addr = scratch_db_top[0]; - uint32_t amt_to_read = (scratch_db_half_size > read_length) ? read_length : scratch_db_half_size; - uint64_t noc_addr = get_noc_addr_helper(noc_xy_addr, read_addr); - noc_async_read(noc_addr, scratch_read_addr, amt_to_read); - read_addr += amt_to_read; - noc_async_read_barrier(); - - // Second step - read into DB[x], write from DB[x], toggle x, iterate - // Writes are fast, reads are slow - uint32_t db_toggle = 0; - uint32_t scratch_write_addr; - read_length -= amt_to_read; - while (read_length != 0) { - // This ensures that writes from prior iteration are done - // TODO(pgk); we can do better on WH w/ tagging - noc_async_writes_flushed(); - - db_toggle ^= 1; - scratch_read_addr = scratch_db_top[db_toggle]; - scratch_write_addr = scratch_db_top[db_toggle ^ 1]; - - uint32_t amt_to_write = amt_to_read; - amt_to_read = (scratch_db_half_size > read_length) ? 
read_length : scratch_db_half_size; - noc_addr = get_noc_addr_helper(noc_xy_addr, read_addr); - noc_async_read(noc_addr, scratch_read_addr, amt_to_read); - read_addr += amt_to_read; - - // Third step - write from DB - uint32_t npages = write_pages_to_dispatcher< - 0, - false, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - - cb_release_pages(npages); - - read_length -= amt_to_read; - - // TODO(pgk); we can do better on WH w/ tagging - noc_async_read_barrier(); - } - - // Third step - write from DB - scratch_write_addr = scratch_db_top[db_toggle]; - uint32_t amt_to_write = amt_to_read; - uint32_t npages = write_pages_to_dispatcher< - CQ_DISPATCH_CMD_SIZE, - true, - my_noc_xy, - my_dispatch_cb_sem_id, - dispatch_noc_xy, - dispatch_cb_base, - dispatch_cb_end, - dispatch_cb_page_size>(dispatch_data_ptr, scratch_write_addr, amt_to_write); - - uint32_t pad_to_page = dispatch_cb_page_size - (dispatch_data_ptr & (dispatch_cb_page_size - 1)); - dispatch_data_ptr += pad_to_page; - - // One page was acquired w/ the cmd in CMD_RELAY_INLINE_NOFLUSH - cb_release_pages(npages + 1); - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -uint32_t process_stall(uint32_t cmd_ptr) { - - static uint32_t count = 0; - - count++; - - DEBUG_STATUS("PSW"); - volatile tt_l1_ptr uint32_t* sem_addr = - reinterpret_cast(get_semaphore(dispatch_sync_sem_id)); - while (*sem_addr != count); - DEBUG_STATUS("PSD"); - - return CQ_PREFETCH_CMD_BARE_MIN_SIZE; -} - -template -bool process_cmd(uint32_t cmd_ptr, - uint32_t& downstream_data_ptr, - uint32_t& stride) { - - volatile CQPrefetchCmd tt_l1_ptr *cmd = (volatile CQPrefetchCmd tt_l1_ptr *)cmd_ptr; - bool done = false; - - switch (cmd->base.cmd_id) { - case CQ_PREFETCH_CMD_RELAY_LINEAR: - DPRINT << "relay linear: " << cmd_ptr << ENDL(); - stride = process_relay_linear_cmd< - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - downstream_cb_base, - downstream_cb_end, - downstream_cb_page_size, - scratch_db_half_size>(cmd_ptr, downstream_data_ptr); - break; - - case CQ_PREFETCH_CMD_RELAY_PAGED: - DPRINT << "relay dram page: " << cmd_ptr << ENDL(); - if (cmd->relay_paged.is_dram) { - stride = process_relay_paged_cmd< - true, - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - downstream_cb_base, - downstream_cb_end, - downstream_cb_page_size, - scratch_db_half_size>(cmd_ptr, downstream_data_ptr); - } else { - stride = process_relay_paged_cmd< - false, - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - downstream_cb_base, - downstream_cb_end, - downstream_cb_page_size, - scratch_db_half_size>(cmd_ptr, downstream_data_ptr); - } - break; - - case CQ_PREFETCH_CMD_RELAY_INLINE: - DPRINT << "inline" << ENDL(); - stride = process_relay_inline_cmd< - cmddat_wrap_enable, - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_sem_id, - cmddat_base, - cmddat_end, - downstream_cb_base, - downstream_cb_end, - downstream_cb_log_page_size, - downstream_cb_page_size>(cmd_ptr, downstream_data_ptr); - break; - - case CQ_PREFETCH_CMD_RELAY_INLINE_NOFLUSH: - DPRINT << "inline no flush" << ENDL(); - stride = process_relay_inline_noflush_cmd< - my_noc_xy, - my_downstream_cb_sem_id, - downstream_noc_xy, - downstream_cb_base, - downstream_cb_end>(cmd_ptr, downstream_data_ptr); - break; - - case CQ_PREFETCH_CMD_STALL: - DPRINT << "stall" << ENDL(); - 
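`process_stall` above counts stall commands in a function-local static and spins until the dispatcher's sync semaphore catches up to that count. A host-side model of the handshake, with a `std::atomic` standing in for the L1 semaphore and an assumed value for the returned stride:

```cpp
// Model of the stall handshake above. The kernel spins on an L1 semaphore
// that the dispatcher increments; an atomic plays that role here, and the
// stride constant is an assumed placeholder.
#include <atomic>
#include <cstdint>
#include <thread>

std::atomic<uint32_t> dispatch_sync_sem{0};  // dispatcher bumps this

uint32_t process_stall_model() {
    static uint32_t count = 0;  // stalls issued so far by this prefetcher
    ++count;
    while (dispatch_sync_sem.load(std::memory_order_acquire) != count) {
        // spin; the real kernel posts heartbeats / context-switches here
    }
    return 32;  // CQ_PREFETCH_CMD_BARE_MIN_SIZE stand-in, value assumed
}

int main() {
    std::thread dispatcher(
        [] { dispatch_sync_sem.fetch_add(1, std::memory_order_release); });
    uint32_t stride = process_stall_model();  // returns once the ack lands
    dispatcher.join();
    return stride == 32 ? 0 : 1;
}
```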
stride = process_stall(cmd_ptr); - break; - - case CQ_PREFETCH_CMD_DEBUG: - DPRINT << "debug" << ENDL(); - stride = process_debug_cmd(cmd_ptr); - break; - - case CQ_PREFETCH_CMD_TERMINATE: - DPRINT << "terminating\n"; - done = true; - break; - - default: - DPRINT << "prefetch invalid command:" << (uint32_t)cmd->base.cmd_id << " " << cmd_ptr << " " << ENDL(); - DPRINT << HEX() << *(uint32_t*)cmd_ptr << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+1) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+2) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+3) << ENDL(); - DPRINT << HEX() << *((uint32_t*)cmd_ptr+4) << ENDL(); - DEBUG_STATUS("!CMD"); - ASSERT(0); - } - - return done; -} diff --git a/tt_metal/impl/dispatch/kernels/eth_tunneler.cpp b/tt_metal/impl/dispatch/kernels/eth_tunneler.cpp index 971afc15f8d6..8453cca33c48 100644 --- a/tt_metal/impl/dispatch/kernels/eth_tunneler.cpp +++ b/tt_metal/impl/dispatch/kernels/eth_tunneler.cpp @@ -2,10 +2,12 @@ // // SPDX-License-Identifier: Apache-2.0 +// clang-format off #include "dataflow_api.h" #include "debug/dprint.h" #include "tt_metal/impl/dispatch/kernels/packet_queue.hpp" #include "tests/tt_metal/tt_metal/perf_microbenchmark/routing/kernels/traffic_gen.hpp" +// clang-format on #define NUM_BIDIR_TUNNELS 1 #define NUM_TUNNEL_QUEUES (NUM_BIDIR_TUNNELS * 2) @@ -17,103 +19,88 @@ constexpr uint32_t endpoint_id_start_index = get_compile_time_arg_val(0); constexpr uint32_t tunnel_lanes = get_compile_time_arg_val(1); constexpr uint32_t in_queue_start_addr_words = get_compile_time_arg_val(2); constexpr uint32_t in_queue_size_words = get_compile_time_arg_val(3); -constexpr uint32_t in_queue_size_bytes = in_queue_size_words*PACKET_WORD_SIZE_BYTES; +constexpr uint32_t in_queue_size_bytes = in_queue_size_words * PACKET_WORD_SIZE_BYTES; static_assert(is_power_of_2(in_queue_size_words), "in_queue_size_words must be a power of 2"); static_assert(tunnel_lanes <= NUM_TUNNEL_QUEUES, "cannot have more than 2 tunnel directions."); static_assert(tunnel_lanes, "tunnel directions cannot be 0. 1 => Unidirectional. 
2 => Bidirectional"); -constexpr uint32_t remote_receiver_x[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(4) & 0xFF), - (get_compile_time_arg_val(5) & 0xFF) - }; - -constexpr uint32_t remote_receiver_y[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(4) >> 8) & 0xFF, - (get_compile_time_arg_val(5) >> 8) & 0xFF - }; - -constexpr uint32_t remote_receiver_queue_id[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(4) >> 16) & 0xFF, - (get_compile_time_arg_val(5) >> 16) & 0xFF - }; - -constexpr DispatchRemoteNetworkType remote_receiver_network_type[NUM_TUNNEL_QUEUES] = - { - static_cast((get_compile_time_arg_val(4) >> 24) & 0xFF), - static_cast((get_compile_time_arg_val(5) >> 24) & 0xFF) - }; - -constexpr uint32_t remote_receiver_queue_start_addr_words[NUM_TUNNEL_QUEUES] = - { - get_compile_time_arg_val(6), - get_compile_time_arg_val(8) - }; - -constexpr uint32_t remote_receiver_queue_size_words[NUM_TUNNEL_QUEUES] = - { - get_compile_time_arg_val(7), - get_compile_time_arg_val(9) - }; - -static_assert(is_power_of_2(remote_receiver_queue_size_words[0]), "remote_receiver_queue_size_words must be a power of 2"); -static_assert(is_power_of_2(remote_receiver_queue_size_words[1]), "remote_receiver_queue_size_words must be a power of 2"); - -constexpr uint32_t remote_sender_x[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(10) & 0xFF), - (get_compile_time_arg_val(11) & 0xFF) - }; - -constexpr uint32_t remote_sender_y[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(10) >> 8) & 0xFF, - (get_compile_time_arg_val(11) >> 8) & 0xFF - }; - -constexpr uint32_t remote_sender_queue_id[NUM_TUNNEL_QUEUES] = - { - (get_compile_time_arg_val(10) >> 16) & 0xFF, - (get_compile_time_arg_val(11) >> 16) & 0xFF - }; - -constexpr DispatchRemoteNetworkType remote_sender_network_type[NUM_TUNNEL_QUEUES] = - { - static_cast((get_compile_time_arg_val(10) >> 24) & 0xFF), - static_cast((get_compile_time_arg_val(11) >> 24) & 0xFF) - }; +constexpr uint32_t remote_receiver_x[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(4) & 0xFF), (get_compile_time_arg_val(5) & 0xFF)}; + +constexpr uint32_t remote_receiver_y[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(4) >> 8) & 0xFF, (get_compile_time_arg_val(5) >> 8) & 0xFF}; + +constexpr uint32_t remote_receiver_queue_id[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(4) >> 16) & 0xFF, (get_compile_time_arg_val(5) >> 16) & 0xFF}; + +constexpr DispatchRemoteNetworkType remote_receiver_network_type[NUM_TUNNEL_QUEUES] = { + static_cast((get_compile_time_arg_val(4) >> 24) & 0xFF), + static_cast((get_compile_time_arg_val(5) >> 24) & 0xFF)}; + +constexpr uint32_t remote_receiver_queue_start_addr_words[NUM_TUNNEL_QUEUES] = { + get_compile_time_arg_val(6), get_compile_time_arg_val(8)}; + +constexpr uint32_t remote_receiver_queue_size_words[NUM_TUNNEL_QUEUES] = { + get_compile_time_arg_val(7), get_compile_time_arg_val(9)}; + +static_assert( + is_power_of_2(remote_receiver_queue_size_words[0]), "remote_receiver_queue_size_words must be a power of 2"); +static_assert( + is_power_of_2(remote_receiver_queue_size_words[1]), "remote_receiver_queue_size_words must be a power of 2"); + +constexpr uint32_t remote_sender_x[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(10) & 0xFF), (get_compile_time_arg_val(11) & 0xFF)}; + +constexpr uint32_t remote_sender_y[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(10) >> 8) & 0xFF, (get_compile_time_arg_val(11) >> 8) & 0xFF}; + +constexpr uint32_t remote_sender_queue_id[NUM_TUNNEL_QUEUES] = { + (get_compile_time_arg_val(10) >> 
16) & 0xFF, (get_compile_time_arg_val(11) >> 16) & 0xFF}; + +constexpr DispatchRemoteNetworkType remote_sender_network_type[NUM_TUNNEL_QUEUES] = { + static_cast((get_compile_time_arg_val(10) >> 24) & 0xFF), + static_cast((get_compile_time_arg_val(11) >> 24) & 0xFF)}; constexpr uint32_t test_results_buf_addr_arg = get_compile_time_arg_val(12); constexpr uint32_t test_results_buf_size_bytes = get_compile_time_arg_val(13); -tt_l1_ptr uint32_t* const test_results = - reinterpret_cast(test_results_buf_addr_arg); +tt_l1_ptr uint32_t* const test_results = reinterpret_cast(test_results_buf_addr_arg); constexpr uint32_t timeout_cycles = get_compile_time_arg_val(14); void kernel_main() { - rtos_context_switch_ptr = (void (*)())RtosTable[0]; noc_init(); test_results[PQ_TEST_STATUS_INDEX] = PACKET_QUEUE_TEST_STARTED; test_results[PQ_TEST_MISC_INDEX] = 0xff000000; - test_results[PQ_TEST_MISC_INDEX+1] = 0xbb000000; - test_results[PQ_TEST_MISC_INDEX+2] = 0xAABBCCDD; - test_results[PQ_TEST_MISC_INDEX+3] = 0xDDCCBBAA; - test_results[PQ_TEST_MISC_INDEX+4] = endpoint_id_start_index; + test_results[PQ_TEST_MISC_INDEX + 1] = 0xbb000000; + test_results[PQ_TEST_MISC_INDEX + 2] = 0xAABBCCDD; + test_results[PQ_TEST_MISC_INDEX + 3] = 0xDDCCBBAA; + test_results[PQ_TEST_MISC_INDEX + 4] = endpoint_id_start_index; for (uint32_t i = 0; i < tunnel_lanes; i++) { - input_queues[i].init(i, in_queue_start_addr_words + i*in_queue_size_words, in_queue_size_words, - remote_sender_x[i], remote_sender_y[i], remote_sender_queue_id[i], remote_sender_network_type[i]); + input_queues[i].init( + i, + in_queue_start_addr_words + i * in_queue_size_words, + in_queue_size_words, + remote_sender_x[i], + remote_sender_y[i], + remote_sender_queue_id[i], + remote_sender_network_type[i]); } for (uint32_t i = 0; i < tunnel_lanes; i++) { - output_queues[i].init(i + NUM_TUNNEL_QUEUES, remote_receiver_queue_start_addr_words[i], remote_receiver_queue_size_words[i], - remote_receiver_x[i], remote_receiver_y[i], remote_receiver_queue_id[i], remote_receiver_network_type[i], - &input_queues[i], 1); + output_queues[i].init( + i + NUM_TUNNEL_QUEUES, + remote_receiver_queue_start_addr_words[i], + remote_receiver_queue_size_words[i], + remote_receiver_x[i], + remote_receiver_y[i], + remote_receiver_queue_id[i], + remote_receiver_network_type[i], + &input_queues[i], + 1); } if (!wait_all_src_dest_ready(input_queues, tunnel_lanes, output_queues, tunnel_lanes, timeout_cycles)) { @@ -142,10 +129,11 @@ void kernel_main() { for (uint32_t i = 0; i < tunnel_lanes; i++) { if (input_queues[i].get_curr_packet_valid()) { bool full_packet_sent; - uint32_t words_sent = output_queues[i].forward_data_from_input(0, full_packet_sent); - //data_words_sent += words_sent; - //if ((words_sent > 0) && (timeout_cycles > 0)) { - progress_timestamp = get_timestamp_32b(); + uint32_t words_sent = + output_queues[i].forward_data_from_input(0, full_packet_sent, input_queues[i].get_end_of_cmd()); + // data_words_sent += words_sent; + // if ((words_sent > 0) && (timeout_cycles > 0)) { + progress_timestamp = get_timestamp_32b(); //} } output_queues[i].prev_words_in_flight_check_flush(); @@ -156,8 +144,8 @@ void kernel_main() { all_outputs_finished &= output_finished; } - //need to optimize this. - //context switch to base fw is very costly. + // need to optimize this. + // context switch to base fw is very costly. 
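The mux/demux/packet_queue hunks below thread an `end_of_cmd` flag through `forward_data_from_input`. The flag travels in bit 0 of `packet_size_bytes`, which is otherwise always even for word-aligned payloads, and it mirrors the `not_end_of_cmd` bit the dispatcher folds into the inline length word in `relay_to_next_cb` above. A self-checking model of that packing, with simplified names:

```cpp
// Model of the size/flag packing used by the packet_queue change below:
// bit 0 of the size field carries "more command data follows", so
// end_of_cmd is the inverse of that bit.
#include <cassert>
#include <cstdint>

struct PacketSizeField {
    uint32_t size_bytes;  // always even: payload sizes are word aligned
    bool end_of_cmd;
};

static uint32_t pack(uint32_t size_bytes, bool end_of_cmd) {
    return size_bytes | (end_of_cmd ? 0u : 1u);  // bit 0 is free to reuse
}

static PacketSizeField unpack(uint32_t size_and_flags) {
    return {size_and_flags & 0xFFFFFFFEu, !(size_and_flags & 1u)};
}

int main() {
    PacketSizeField f = unpack(pack(64, false));
    assert(f.size_bytes == 64 && !f.end_of_cmd);
    f = unpack(pack(128, true));
    assert(f.size_bytes == 128 && f.end_of_cmd);
    return 0;
}
```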
internal_::risc_context_switch(); } diff --git a/tt_metal/impl/dispatch/kernels/packet_demux.cpp b/tt_metal/impl/dispatch/kernels/packet_demux.cpp index 9fa19a887649..7c915f73766b 100644 --- a/tt_metal/impl/dispatch/kernels/packet_demux.cpp +++ b/tt_metal/impl/dispatch/kernels/packet_demux.cpp @@ -235,7 +235,7 @@ void kernel_main() { uint32_t dest = input_queue.get_curr_packet_dest(); uint8_t output_queue_id = dest_output_queue_id(dest); bool full_packet_sent; - uint32_t words_sent = output_queues[output_queue_id].forward_data_from_input(0, full_packet_sent); + uint32_t words_sent = output_queues[output_queue_id].forward_data_from_input(0, full_packet_sent, input_queue.get_end_of_cmd()); data_words_sent += words_sent; if ((words_sent > 0) && (timeout_cycles > 0)) { progress_timestamp = get_timestamp_32b(); diff --git a/tt_metal/impl/dispatch/kernels/packet_mux.cpp b/tt_metal/impl/dispatch/kernels/packet_mux.cpp index 515951018eb3..a97984306372 100644 --- a/tt_metal/impl/dispatch/kernels/packet_mux.cpp +++ b/tt_metal/impl/dispatch/kernels/packet_mux.cpp @@ -185,7 +185,7 @@ void kernel_main() { } if (input_queues[curr_input].get_curr_packet_valid()) { bool full_packet_sent; - uint32_t words_sent = output_queue.forward_data_from_input(curr_input, full_packet_sent); + uint32_t words_sent = output_queue.forward_data_from_input(curr_input, full_packet_sent, input_queues[curr_input].get_end_of_cmd()); data_words_sent += words_sent; if ((words_sent > 0) && (timeout_cycles > 0)) { progress_timestamp = get_timestamp_32b(); diff --git a/tt_metal/impl/dispatch/kernels/packet_queue.hpp b/tt_metal/impl/dispatch/kernels/packet_queue.hpp index 0be258377269..bf4e9a294fb3 100644 --- a/tt_metal/impl/dispatch/kernels/packet_queue.hpp +++ b/tt_metal/impl/dispatch/kernels/packet_queue.hpp @@ -410,6 +410,7 @@ class packet_input_queue_state_t : public packet_queue_state_t { uint16_t curr_packet_src; uint16_t curr_packet_dest; uint32_t curr_packet_size_words; + uint32_t end_of_cmd; uint32_t curr_packet_words_sent; uint32_t curr_packet_tag; uint16_t curr_packet_flags; @@ -423,7 +424,9 @@ class packet_input_queue_state_t : public packet_queue_state_t { (this->queue_start_addr_words + this->get_queue_rptr_sent_offset_words())*PACKET_WORD_SIZE_BYTES ); this->curr_packet_header_ptr = next_packet_header_ptr; - uint32_t packet_size_bytes = next_packet_header_ptr->packet_size_bytes; + uint32_t packet_size_and_flags = next_packet_header_ptr->packet_size_bytes; + uint32_t packet_size_bytes = packet_size_and_flags & 0xFFFFFFFE; + this->end_of_cmd = !(packet_size_and_flags & 1); this->curr_packet_size_words = packet_size_bytes/PACKET_WORD_SIZE_BYTES; if (packet_size_bytes % PACKET_WORD_SIZE_BYTES) { this->curr_packet_size_words++; @@ -489,6 +492,10 @@ class packet_input_queue_state_t : public packet_queue_state_t { this->reset_ready_flag(); } + inline uint32_t get_end_of_cmd() const { + return this->end_of_cmd; + } + inline bool is_packetizer_input() const { return this->cb_mode; } @@ -863,7 +870,7 @@ class packet_output_queue_state_t : public packet_queue_state_t { return num_words_to_forward; } - inline uint32_t forward_data_from_input(uint32_t input_queue_index, bool& full_packet_sent) { + inline uint32_t forward_data_from_input(uint32_t input_queue_index, bool& full_packet_sent, uint32_t end_of_cmd) { packet_input_queue_state_t* input_queue_ptr = &(this->input_queue_status.input_queue_array[input_queue_index]); uint32_t num_words_to_forward = this->get_num_words_to_send(input_queue_index); @@ -894,7 +901,7 @@ class 
packet_output_queue_state_t : public packet_queue_state_t { this->remote_wptr_update(num_words_to_forward); } else { this->unpacketizer_page_words_sent += num_words_to_forward; - if (full_packet_sent) { + if (full_packet_sent && end_of_cmd) { uint32_t unpacketizer_page_words_sent_past_page_bound = this->unpacketizer_page_words_sent & (this->cb_mode_page_size_words - 1); if (unpacketizer_page_words_sent_past_page_bound > 0) { diff --git a/tt_metal/impl/program/program.cpp b/tt_metal/impl/program/program.cpp index 1edcca12168f..a507e2e2337f 100644 --- a/tt_metal/impl/program/program.cpp +++ b/tt_metal/impl/program/program.cpp @@ -590,16 +590,14 @@ void Program::populate_dispatch_data(Device *device) { {RISCV::ERISC, eth_l1_mem::address_map::FIRMWARE_BASE}}; auto extract_dst_noc_unicast_info = - [&device](const set &ranges, const CoreType core_type) -> vector> { + [&device](const set &ranges, const CoreType core_type) -> vector> { // This API extracts all the pairs of noc multicast encodings given a set of core ranges - vector> dst_noc_unicast_info; + vector> dst_noc_unicast_info; for (const CoreRange &core_range : ranges) { for (auto x = core_range.start.x; x <= core_range.end.x; x++) { for (auto y = core_range.start.y; y <= core_range.end.y; y++) { CoreCoord physical_coord = device->physical_core_from_logical_core(CoreCoord({x, y}), core_type); - uint32_t dst_noc_unicast_encoding = - NOC_XY_ENCODING(NOC_X(physical_coord.x), NOC_Y(physical_coord.y)); - dst_noc_unicast_info.push_back(std::make_pair(dst_noc_unicast_encoding, /*num_mcast_dests=*/0)); + dst_noc_unicast_info.push_back(std::make_pair(physical_coord, /*num_mcast_dests=*/0)); } } } @@ -613,7 +611,7 @@ void Program::populate_dispatch_data(Device *device) { // TODO: use semaphore.core_type from main if (semaphore.core_type() == CoreType::WORKER) { - vector> dst_noc_multicast_info = + vector> dst_noc_multicast_info = extract_dst_noc_multicast_info>( device, semaphore.core_range_set().ranges(), semaphore.core_type()); transfer_info_2 transfer_info = { @@ -623,7 +621,7 @@ void Program::populate_dispatch_data(Device *device) { .data = semaphore_data}; this->program_transfer_info.multicast_semaphores[semaphore.address()].push_back(transfer_info); } else if (semaphore.core_type() == CoreType::ETH) { - vector> dst_noc_unicast_info = + vector> dst_noc_unicast_info = extract_dst_noc_unicast_info(semaphore.core_range_set().ranges(), semaphore.core_type()); transfer_info_2 transfer_info = { .dst_base_addr = semaphore.address(), @@ -640,7 +638,7 @@ void Program::populate_dispatch_data(Device *device) { // Program Binaries and Go Signals // TODO: cleanup put the WORKERS and ETH logic together.. 
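The program.cpp, program.hpp, and program_device_map.hpp hunks here stop baking `NOC_XY_ENCODING` / `NOC_MULTICAST_ENCODING` values into transfer info at program-creation time and instead carry physical `CoreCoord` / `CoreRange` values in a `std::variant`, presumably so the encoding can be computed later, at dispatch time. A sketch of how a consumer could collapse that variant; the types are simplified stand-ins and the two encode functions are hypothetical placeholders, not the real bit layouts:

```cpp
// Sketch of resolving the new transfer_info_cores variant back into a NOC
// encoding. CoreCoord/CoreRange are stand-ins; encode_* are placeholders.
#include <cstdint>
#include <type_traits>
#include <variant>

struct CoreCoord { uint32_t x, y; };
struct CoreRange { CoreCoord start, end; };
using transfer_info_cores = std::variant<CoreCoord, CoreRange>;

static uint32_t encode_unicast(const CoreCoord& c) {
    return (c.y << 16) | c.x;  // placeholder packing
}
static uint32_t encode_multicast(const CoreRange& r) {
    return (r.end.y << 24) | (r.end.x << 16) | (r.start.y << 8) | r.start.x;  // placeholder
}

static uint32_t resolve(const transfer_info_cores& dst) {
    return std::visit(
        [](const auto& v) -> uint32_t {
            if constexpr (std::is_same_v<std::decay_t<decltype(v)>, CoreCoord>) {
                return encode_unicast(v);  // unicast destination
            } else {
                return encode_multicast(v);  // multicast rectangle
            }
        },
        dst);
}

int main() {
    transfer_info_cores unicast = CoreCoord{1, 2};
    transfer_info_cores mcast = CoreRange{{0, 0}, {3, 3}};
    return (resolve(unicast) != 0 && resolve(mcast) != 0) ? 0 : 1;
}
```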
diff --git a/tt_metal/impl/program/program.hpp b/tt_metal/impl/program/program.hpp
index 868a9c711e18..10e33f55591c 100644
--- a/tt_metal/impl/program/program.hpp
+++ b/tt_metal/impl/program/program.hpp
@@ -54,19 +54,16 @@ struct KernelGroup {
 };
 
 template <typename CoreRangeContainer>
-vector<pair<uint32_t, uint32_t>> extract_dst_noc_multicast_info(Device* device, const CoreRangeContainer& ranges, const CoreType core_type) {
+vector<pair<transfer_info_cores, uint32_t>> extract_dst_noc_multicast_info(Device* device, const CoreRangeContainer& ranges, const CoreType core_type) {
     // This API extracts all the pairs of noc multicast encodings given a set of core ranges
-    vector<pair<uint32_t, uint32_t>> dst_noc_multicast_info;
+    vector<pair<transfer_info_cores, uint32_t>> dst_noc_multicast_info;
     dst_noc_multicast_info.reserve(ranges.size());
     for (const CoreRange& core_range : ranges) {
         CoreCoord physical_start = device->physical_core_from_logical_core(core_range.start, core_type);
         CoreCoord physical_end = device->physical_core_from_logical_core(core_range.end, core_type);
-        uint32_t dst_noc_multicast_encoding =
-            NOC_MULTICAST_ENCODING(physical_start.x, physical_start.y, physical_end.x, physical_end.y);
-
         uint32_t num_receivers = core_range.size();
-        dst_noc_multicast_info.push_back(std::make_pair(dst_noc_multicast_encoding, num_receivers));
+        dst_noc_multicast_info.push_back(std::make_pair(CoreRange(physical_start, physical_end), num_receivers));
     }
     return dst_noc_multicast_info;
 }
diff --git a/tt_metal/impl/program/program_device_map.hpp b/tt_metal/impl/program/program_device_map.hpp
index e5c6d5cfd5a5..dc648887b133 100644
--- a/tt_metal/impl/program/program_device_map.hpp
+++ b/tt_metal/impl/program/program_device_map.hpp
@@ -16,9 +16,11 @@ struct transfer_info {
     bool linked;
 };
 
+using transfer_info_cores = std::variant<CoreCoord, CoreRange>;
+
 struct transfer_info_2 {
     std::uint32_t dst_base_addr;
-    vector<pair<uint32_t, uint32_t>> dst_noc_info;  // noc_encoding, num_mcast_dests
+    vector<pair<transfer_info_cores, uint32_t>> dst_noc_info;  // cores, num_mcast_dests
     bool linked;
     vector<uint32_t> data;
 };
@@ -26,7 +28,7 @@ struct kernel_bins_transfer_info {
     vector<uint32_t> dst_base_addrs;  // BRISC, NCRISC, TRISC etc..
     vector<uint32_t> page_offsets;    // offsets into paged buffer in DRAM
     vector<uint32_t> lengths;         // WriteLinear lengths
-    vector<pair<uint32_t, uint32_t>> dst_noc_info;  // noc_encoding, num_mcast_dests
+    vector<pair<transfer_info_cores, uint32_t>> dst_noc_info;  // cores, num_mcast_dests
     bool linked;
     vector<uint32_t> data;  // all binaries' data for kernel group
 };
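transfer_info_cores makes the unicast/multicast distinction explicit in the type: unicast entries carry a CoreCoord, multicast entries a CoreRange. A minimal sketch of how a consumer can branch on the active alternative with std::visit; the encode_* helpers and the stand-in structs are placeholders, not the real dispatch encoders:

    #include <cstdint>
    #include <type_traits>
    #include <variant>

    // Stand-ins for tt-metal's CoreCoord / CoreRange.
    struct CoreCoord { std::size_t x, y; };
    struct CoreRange { CoreCoord start, end; };

    using transfer_info_cores = std::variant<CoreCoord, CoreRange>;

    // Placeholder encoders standing in for the NOC encoding done at dispatch time.
    std::uint32_t encode_unicast(const CoreCoord& c) { return static_cast<std::uint32_t>((c.y << 16) | c.x); }
    std::uint32_t encode_multicast(const CoreRange& r) { return encode_unicast(r.start) ^ encode_unicast(r.end); }

    std::uint32_t encode(const transfer_info_cores& cores) {
        return std::visit([](const auto& v) -> std::uint32_t {
            if constexpr (std::is_same_v<std::decay_t<decltype(v)>, CoreCoord>) {
                return encode_unicast(v);    // unicast entries hold one coordinate
            } else {
                return encode_multicast(v);  // multicast entries hold a rectangle
            }
        }, cores);
    }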
diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h b/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h
index e29d02434594..22ebaba89e54 100644
--- a/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h
+++ b/tt_metal/include/compute_kernel_api/eltwise_unary/typecast.h
@@ -20,16 +20,21 @@ namespace ckernel {
 
 /**
  * Performs an elementwise typecast operation on the input.
- * Supports typecast from fp32 to uint32.
+ * Supports the following typecasts:
+ *   fp32/fp16b -> uint32
+ *   fp32/fp16b -> uint16
+ * For the output to be uint32, Dest must be in 32-bit mode.
  *
  * Return value: None
 *
 * | Argument       | Description                                                                 | Type     | Valid Range                                            | Required |
 * |----------------|-----------------------------------------------------------------------------|----------|--------------------------------------------------------|----------|
 * | tile_index     | The index of the tile in DST register buffer to perform typecast operation   | uint32_t | Must be less than the size of the DST register buffer | True     |
+ * | OUT_DTYPE      | Desired output data format                                                   | uint32_t | Must be a valid tt::DataFormat                         | True     |
 */
+template <uint32_t OUT_DTYPE>
 ALWI void typecast_tile(uint32_t idst) {
-    MATH(( llk_math_eltwise_unary_sfpu_typecast<APPROX>(idst) ));
+    MATH(( llk_math_eltwise_unary_sfpu_typecast<APPROX, OUT_DTYPE>(idst) ));
 }
 
 /**
diff --git a/tt_metal/jit_build/genfiles.cpp b/tt_metal/jit_build/genfiles.cpp
index 9c244ddd913f..b805e5ffa1e9 100644
--- a/tt_metal/jit_build/genfiles.cpp
+++ b/tt_metal/jit_build/genfiles.cpp
@@ -199,9 +199,8 @@ generate_pack_data_formats(tt_hlk_desc& desc, DataFormat unpack_conditional_dst_
 static void emit_pack_data_formats(std::string pack_data_format_descs, std::vector<DataFormat> src_formats_all_cbs, std::vector<DataFormat> dst_formats_all_cbs) {
     ofstream file_stream;
     file_stream.open(pack_data_format_descs);
-    // TODO: we should be emitting "unsigned char", no reason to use 4B per data format
-    file_stream << create_formats_array_string("constexpr std::int32_t", "pack_src_format", NUM_CIRCULAR_BUFFERS, data_format_vec_to_string(src_formats_all_cbs));
-    file_stream << create_formats_array_string("constexpr std::int32_t", "pack_dst_format", NUM_CIRCULAR_BUFFERS, data_format_vec_to_string(dst_formats_all_cbs));
+    file_stream << create_formats_array_string("constexpr unsigned char", "pack_src_format", NUM_CIRCULAR_BUFFERS, data_format_vec_to_string(src_formats_all_cbs));
+    file_stream << create_formats_array_string("constexpr unsigned char", "pack_dst_format", NUM_CIRCULAR_BUFFERS, data_format_vec_to_string(dst_formats_all_cbs));
 
     // budabackend-style format array
     // file_stream << create_formats_array_string("const std::int32_t", "pack_src_format", 16, data_format_vec_to_string(src_formats));
diff --git a/tt_metal/tt_metal.cpp b/tt_metal/tt_metal.cpp
index 4ce64b5b07a3..665de904b462 100644
--- a/tt_metal/tt_metal.cpp
+++ b/tt_metal/tt_metal.cpp
@@ -253,17 +253,12 @@ std::map<chip_id_t, Device *> CreateDevices(
     const std::vector<uint32_t> &l1_bank_remap) {
     ZoneScoped;
     std::map<chip_id_t, Device *> active_devices;  // TODO: pass this to CloseDevices
-    // Construct NUMA Node to CPU core map
-    std::unordered_set<uint32_t> free_cores = {};
-    auto cpu_cores_per_numa_node = device_cpu_allocator::get_cpu_cores_per_numa_node(free_cores);
     for (const auto &device_id : device_ids) {
         const auto &mmio_device_id = tt::Cluster::instance().get_associated_mmio_device(device_id);
         if (active_devices.find(mmio_device_id) == active_devices.end()) {
             for (const auto &mmio_controlled_device_id :
                  tt::Cluster::instance().get_devices_controlled_by_mmio_device(mmio_device_id)) {
-                int core_assigned_to_device = device_cpu_allocator::get_cpu_core_for_device_worker_thread(
-                    mmio_controlled_device_id, cpu_cores_per_numa_node, free_cores);
+                int core_assigned_to_device = mmio_controlled_device_id % sysconf(_SC_NPROCESSORS_ONLN);
                 Device *dev = new Device(
                     mmio_controlled_device_id,
                     num_hw_cqs,
@@ -276,9 +271,6 @@ std::map<chip_id_t, Device *> CreateDevices(
             }
         }
     }
-    // Bind main thread to cores not being used by workers.
-    device_cpu_allocator::bind_current_thread_to_free_cores(free_cores);
-
     // TODO: need to only enable routing for used mmio chips
     tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true);
     return active_devices;
@@ -802,14 +794,8 @@ Device *CreateDevice(
     const size_t l1_small_size,
     const std::vector<uint32_t> &l1_bank_remap) {
     ZoneScoped;
-    // Construct NUMA Node to CPU core map
-    std::unordered_set<uint32_t> free_cores = {};
-    auto cpu_cores_per_numa_node = device_cpu_allocator::get_cpu_cores_per_numa_node(free_cores);
-    int core_assigned_to_device =
-        device_cpu_allocator::get_cpu_core_for_device_worker_thread(device_id, cpu_cores_per_numa_node, free_cores);
+    int core_assigned_to_device = device_id % sysconf(_SC_NPROCESSORS_ONLN);
     Device *dev = new Device(device_id, num_hw_cqs, l1_small_size, l1_bank_remap, false, core_assigned_to_device);
-    // Bind main thread to cores not being used by workers.
-    device_cpu_allocator::bind_current_thread_to_free_cores(free_cores);
     tt::Cluster::instance().set_internal_routing_info_for_ethernet_cores(true);
     detail::InitDeviceProfiler(dev);
     return dev;
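The NUMA-aware worker-thread placement above is replaced with a simple round-robin over the online CPUs. A minimal standalone sketch of the new assignment rule, kept apart from the tt-metal API (the function name is illustrative):

    #include <unistd.h>  // sysconf

    // Each device's worker thread lands on core (device_id mod #online CPUs),
    // spreading devices across cores without any NUMA bookkeeping.
    int core_for_device(int device_id) {
        long n_cpus = sysconf(_SC_NPROCESSORS_ONLN);
        return static_cast<int>(device_id % n_cpus);
    }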
diff --git a/ttnn/cpp/pybind11/operations/binary.hpp b/ttnn/cpp/pybind11/operations/binary.hpp
index 4c9f2104b58a..7bbf43ff2a14 100644
--- a/ttnn/cpp/pybind11/operations/binary.hpp
+++ b/ttnn/cpp/pybind11/operations/binary.hpp
@@ -33,9 +33,11 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operation
             * :attr:`input_tensor_b` (ttnn.Tensor or Number): the tensor or number to add to :attr:`input_tensor_a`.
 
         Keyword args:
-            * :attr:`memory_config` (ttnn.MemoryConfig): memory config for the output tensor
-            * :attr:`dtype` (ttnn.DataType): data type for the output tensor
-            * :attr:`activations` (List[str]): list of activation functions to apply to the output tensor
+            * :attr:`memory_config` (Optional[ttnn.MemoryConfig]): memory config for the output tensor
+            * :attr:`dtype` (Optional[ttnn.DataType]): data type for the output tensor
+            * :attr:`output_tensor` (Optional[ttnn.Tensor]): preallocated output tensor
+            * :attr:`activations` (Optional[List[str]]): list of activation functions to apply to the output tensor
+            * :attr:`queue_id` (Optional[uint8]): command queue id
 
         Example::
 
@@ -51,34 +53,47 @@ void bind_binary_operation(py::module& module, const binary_operation_t& operation
         module,
         operation,
         doc,
+        // tensor and scalar
         ttnn::pybind_overload_t{
             [](const binary_operation_t& self,
                const ttnn::Tensor& input_tensor_a,
                const float scalar,
                const std::optional<ttnn::MemoryConfig>& memory_config,
                const std::optional<const DataType>& dtype,
-               const std::optional<std::vector<std::string>>& activations) -> ttnn::Tensor {
-                return self(input_tensor_a, scalar, memory_config, dtype, activations);
+               const std::optional<ttnn::Tensor>& output_tensor,
+               const std::optional<std::vector<std::string>>& activations,
+               const uint8_t& queue_id) -> ttnn::Tensor {
+                return self(queue_id, input_tensor_a, scalar, memory_config, dtype, output_tensor, activations);
             },
             py::arg("input_tensor_a"),
             py::arg("input_tensor_b"),
+            py::kw_only(),
             py::arg("memory_config") = std::nullopt,
             py::arg("dtype") = std::nullopt,
-            py::arg("activations") = std::nullopt},
+            py::arg("output_tensor") = std::nullopt,
+            py::arg("activations") = std::nullopt,
+            py::arg("queue_id") = 0},
+
+        // tensor and tensor
         ttnn::pybind_overload_t{
             [](const binary_operation_t& self,
                const ttnn::Tensor& input_tensor_a,
                const ttnn::Tensor& input_tensor_b,
                const std::optional<ttnn::MemoryConfig>& memory_config,
                const std::optional<const DataType>& dtype,
-               const std::optional<std::vector<std::string>>& activations) -> ttnn::Tensor {
-                return self(input_tensor_a, input_tensor_b, memory_config, dtype, activations);
+               const std::optional<ttnn::Tensor>& output_tensor,
+               const std::optional<std::vector<std::string>>& activations,
+               const uint8_t& queue_id) -> ttnn::Tensor {
+                return self(queue_id, input_tensor_a, input_tensor_b, memory_config, dtype, output_tensor, activations);
             },
             py::arg("input_tensor_a"),
             py::arg("input_tensor_b"),
+            py::kw_only(),
             py::arg("memory_config") = std::nullopt,
             py::arg("dtype") = std::nullopt,
-            py::arg("activations") = std::nullopt});
+            py::arg("output_tensor") = std::nullopt,
+            py::arg("activations") = std::nullopt,
+            py::arg("queue_id") = 0});
 }
 
 }  // namespace detail
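py::kw_only() in the rebinding above makes every argument after it keyword-only on the Python side, so the new output_tensor and queue_id arguments can never be passed positionally. A minimal pybind11 sketch of the same pattern on a toy function (module and names are illustrative, not the ttnn binding):

    #include <pybind11/pybind11.h>
    namespace py = pybind11;

    int scale(int value, int factor, int offset) { return value * factor + offset; }

    PYBIND11_MODULE(example, m) {
        // Resulting Python signature: scale(value, *, factor=2, offset=0)
        m.def("scale", &scale,
              py::arg("value"),
              py::kw_only(),             // everything after this binds keyword-only
              py::arg("factor") = 2,
              py::arg("offset") = 0);
    }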
diff --git a/ttnn/cpp/ttnn/multi_device.hpp b/ttnn/cpp/ttnn/multi_device.hpp
index 41943189363b..1a36bad30869 100644
--- a/ttnn/cpp/ttnn/multi_device.hpp
+++ b/ttnn/cpp/ttnn/multi_device.hpp
@@ -46,6 +46,8 @@ std::vector<ttnn::Tensor> get_device_tensors(const ttnn::Tensor& tensor) {
             tensors.push_back(shard);
         }
         return tensors;
+    } else {
+        return {tensor};
     }
     TT_THROW("Expected tensor to be on MultiDeviceHostStorage type!");
 }
diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
index 243b6ef4808a..b73c63d819f2 100644
--- a/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
+++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.cpp
@@ -80,7 +80,7 @@ inline BinaryProgramType get_program_type(const Binary& operation, const std::vector<Tensor>& input_tensors) {
     TT_THROW("ttnn::operations::binary::Binary: unsupported broadcast");
 }
 
-void Binary::validate(const std::vector<Tensor>& input_tensors) const {
+void Binary::validate_with_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
     auto program_type = get_program_type(*this, input_tensors);
     const auto& input_tensor_a = input_tensors.at(0);
@@ -170,6 +170,14 @@ void Binary::validate(const std::vector<Tensor>& input_tensors) const {
     if (program_type != BinaryProgramType::ElementWiseMultiCore) {
         TT_FATAL(not this->program_config.activations.has_value());
     }
+
+    if (!output_tensors.empty()) {
+        TT_FATAL(output_tensors.size() == 1, "Must have 1 output tensor");
+
+        if (output_tensors.at(0).has_value()) {
+            TT_FATAL(!this->program_config.in_place, "Operation is configured as in_place; the first input is used as the output, so a preallocated output tensor cannot be provided");
+        }
+    }
 }
 
 std::vector<tt::tt_metal::Shape> Binary::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
@@ -181,12 +189,16 @@ std::vector<tt::tt_metal::Shape> Binary::compute_output_shapes(const std::vector<Tensor>& input_tensors) const {
     return {input_tensor_b.get_legacy_shape()};
 }
 
-std::vector<Tensor> Binary::create_output_tensors(const std::vector<Tensor>& input_tensors) const {
+std::vector<Tensor> Binary::create_output_tensors(const std::vector<Tensor>& input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const {
     const auto& input_tensor_a = input_tensors.at(0);
     const auto& input_tensor_b = input_tensors.at(1);
     if (this->program_config.in_place) {
         return {input_tensor_a};
     } else {
+        if (!output_tensors.empty() && output_tensors.at(0).has_value()) {
+            return {output_tensors.at(0).value()};
+        }
+
         auto program_type = get_program_type(*this, input_tensors);
 
         if (program_type == BinaryProgramType::ElementWiseMultiCore) {
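get_device_tensors previously threw for anything that was not multi-device host storage; with the else branch it degrades gracefully by treating a single-device tensor as a one-shard list. A toy model of that caller-visible behaviour, using stand-in types rather than the real ttnn ones:

    #include <iostream>
    #include <vector>

    struct Tensor { int id; bool multi_device; };

    // Toy version of the new behaviour: multi-device storage yields its shards,
    // anything else yields a single-element list instead of throwing.
    std::vector<Tensor> get_device_tensors(const Tensor& t) {
        if (t.multi_device) {
            return {Tensor{t.id * 10 + 1, false}, Tensor{t.id * 10 + 2, false}};  // pretend shards
        } else {
            return {t};
        }
    }

    int main() {
        for (const Tensor& shard : get_device_tensors(Tensor{7, false})) {
            std::cout << "shard " << shard.id << "\n";  // single-device: exactly one entry
        }
    }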
diff --git a/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp b/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp
index 006097c36331..6ae72c5a9822 100644
--- a/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp
+++ b/ttnn/cpp/ttnn/op_library/binary/binary_op.hpp
@@ -29,6 +29,8 @@ namespace binary {
 
 using BinaryOpType = tt::tt_metal::BinaryOpType;
 
+constexpr uint8_t DefaultQueueId = 0;
+
 struct BinaryProgramConfig {
     BinaryOpType binary_op_type;
     bool in_place;
@@ -48,9 +50,9 @@ struct Binary {
     const BinaryProgramConfig program_config;
     std::optional<DeviceComputeKernelConfig> compute_kernel_config;
 
-    void validate(const std::vector<Tensor> &input_tensors) const;
+    void validate_with_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     std::vector<tt::tt_metal::Shape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
-    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
+    std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors, const std::vector<std::optional<Tensor>>& output_tensors) const;
     operation::ProgramWithCallbacks create_program(
         const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
 
@@ -92,16 +94,23 @@ struct ExecuteBinary {
     }
 
     template <typename... Args>
-    static auto input_tensors_to_validate(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) {
+    static auto input_tensors_to_validate(uint8_t queue_id, const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) {
         return std::forward_as_tuple(input_tensor_a, input_tensor_b);
     }
 
     static Tensor execute_on_worker_thread(
+        uint8_t queue_id,
         const Tensor &input_tensor_a_arg,
         const Tensor &input_tensor_b_arg,
         const std::optional<MemoryConfig> &memory_config = std::nullopt,
-        const std::optional<const DataType> &dtype = std::nullopt,
+        const std::optional<const DataType> &output_dtype = std::nullopt,
+        std::optional<Tensor> optional_output_tensor = std::nullopt,
         std::optional<std::vector<std::string>> activations = std::nullopt) {
+
+        if (output_dtype.has_value() && optional_output_tensor.has_value()) {
+            TT_FATAL(output_dtype.value() == optional_output_tensor.value().get_dtype(), "If both an output dtype and an output tensor are provided, their dtypes must match");
+        }
+
         auto &&[input_tensor_a, input_tensor_b] = [](const auto &input_tensor_a_arg, const auto &input_tensor_b_arg) {
             const auto input_shape_a = input_tensor_a_arg.get_shape();
             const auto input_shape_b = input_tensor_b_arg.get_shape();
@@ -111,6 +120,7 @@ struct ExecuteBinary {
             }
             return std::make_tuple(input_tensor_a_arg, input_tensor_b_arg);
         }(input_tensor_a_arg, input_tensor_b_arg);
+
         auto output_memory_config = memory_config.value_or(input_tensor_a.memory_config());
 
         // TODO(arakhmati): #7731 - remove this!
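The idiom introduced here keeps the queue-id-free entry points as thin wrappers that forward to the queue-id-first overload with DefaultQueueId, so existing callers keep compiling. A stripped-down, self-contained sketch of the same overload-forwarding pattern (toy function, not the ttnn API):

    #include <cstdint>
    #include <iostream>

    constexpr std::uint8_t DefaultQueueId = 0;

    // Queue-id-first overload: the one that actually does the work.
    int run_op(std::uint8_t queue_id, int a, int b) {
        return a + b + queue_id;  // toy body
    }

    // Back-compat overload: identical arguments minus the queue id; forwards the default.
    inline int run_op(int a, int b) {
        return run_op(DefaultQueueId, a, b);
    }

    int main() { std::cout << run_op(2, 3) << "\n"; }  // implicitly uses queue 0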
@@ -124,15 +134,38 @@ struct ExecuteBinary {
             input_tensor_b = tt::tt_metal::repeat(input_tensor_b, repeats.value(), output_memory_config);
         }
 
-        return operation::run(
-                   Binary{BinaryProgramConfig{
-                       binary_op_type,
-                       in_place,
-                       activations,
-                       output_memory_config,
-                       dtype.value_or(input_tensor_a.get_dtype())}},
-                   {input_tensor_a, input_tensor_b})
-            .at(0);
+        DataType dtype = output_dtype.value_or(input_tensor_a.get_dtype());
+        if (optional_output_tensor.has_value()) {
+            dtype = optional_output_tensor.value().get_dtype();
+        }
+
+        auto output_tensors = operation::run(Binary{BinaryProgramConfig{binary_op_type,
+                                                                        in_place,
+                                                                        activations,
+                                                                        output_memory_config,
+                                                                        dtype}},
+                                             {input_tensor_a, input_tensor_b},
+                                             {},
+                                             {optional_output_tensor},
+                                             queue_id);
+
+        return output_tensors.at(0);
+    }
+
+    template <typename... Args>
+    static auto input_tensors_to_validate(const Tensor &input_tensor_a, const Tensor &input_tensor_b, Args &&...args) {
+        return std::forward_as_tuple(input_tensor_a, input_tensor_b);
+    }
+
+    static Tensor execute_on_worker_thread(
+        const Tensor &input_tensor_a_arg,
+        const Tensor &input_tensor_b_arg,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt,
+        const std::optional<const DataType> &output_dtype = std::nullopt,
+        std::optional<Tensor> optional_output_tensor = std::nullopt,
+        std::optional<std::vector<std::string>> activations = std::nullopt)
+    {
+        return execute_on_worker_thread(DefaultQueueId, input_tensor_a_arg, input_tensor_b_arg, memory_config, output_dtype, optional_output_tensor, activations);
     }
 
     template <typename... Args>
@@ -147,6 +180,24 @@ struct ExecuteBinary {
         const float scalar,
         const std::optional<MemoryConfig> &memory_config = std::nullopt,
         const std::optional<const DataType> &dtype = std::nullopt,
+        const std::optional<Tensor> &optional_output_tensor = std::nullopt,
+        std::optional<std::vector<std::string>> activations = std::nullopt) {
+
+        return ExecuteBinary::execute_on_worker_thread(DefaultQueueId, input_tensor_a, scalar, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, dtype, optional_output_tensor, activations);
+    }
+
+    template <typename... Args>
+    static auto input_tensors_to_validate(uint8_t queue_id, const Tensor &input_tensor_a, const float input_tensor_b, Args &&...args) {
+        return std::forward_as_tuple(input_tensor_a, input_tensor_b);
+    }
+
+    static Tensor execute_on_worker_thread(
+        uint8_t queue_id,
+        const ttnn::Tensor &input_tensor_a,
+        const float scalar,
+        const std::optional<MemoryConfig> &memory_config = std::nullopt,
+        const std::optional<const DataType> &dtype = std::nullopt,
+        const std::optional<Tensor> &optional_output_tensor = std::nullopt,
         std::optional<std::vector<std::string>> activations = std::nullopt) {
         // Cast Float Scalar to a device tensor
         auto host_buffer = owned_buffer::create<::bfloat16>(static_cast<std::size_t>(TILE_HEIGHT * TILE_WIDTH));
@@ -159,7 +210,7 @@ struct ExecuteBinary {
         Tensor scalar_tensor_device = scalar_tensor_host.to(input_tensor_a.device());
         // TODO(arakhmati): #7637 pass in memory_config instead of operation::DEFAULT_OUTPUT_MEMORY_CONFIG
         return ExecuteBinary::execute_on_worker_thread(
-            input_tensor_a, scalar_tensor_device, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, dtype, activations);
+            input_tensor_a, scalar_tensor_device, operation::DEFAULT_OUTPUT_MEMORY_CONFIG, dtype, optional_output_tensor, activations);
     }
 };
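The tensor-scalar path above reuses the tensor-tensor machinery by materializing the scalar as a single 32x32 bfloat16 tile that is then moved to the device. A simplified sketch of that idea, using plain float and std::array as stand-ins for the ttnn owned_buffer API:

    #include <array>
    #include <cstddef>

    constexpr std::size_t TILE_HEIGHT = 32;
    constexpr std::size_t TILE_WIDTH  = 32;

    // Fill one tile's worth of elements with the scalar; the real code wraps a
    // buffer like this in a Tensor and calls .to(device) so the binary op then
    // sees two ordinary tensors.
    std::array<float, TILE_HEIGHT * TILE_WIDTH> make_scalar_tile(float scalar) {
        std::array<float, TILE_HEIGHT * TILE_WIDTH> tile{};
        tile.fill(scalar);
        return tile;
    }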
diff --git a/ttnn/cpp/ttnn/operations/unary.hpp b/ttnn/cpp/ttnn/operations/unary.hpp
index 2ab4686b5f48..2b95096d2fc7 100644
--- a/ttnn/cpp/ttnn/operations/unary.hpp
+++ b/ttnn/cpp/ttnn/operations/unary.hpp
@@ -42,8 +42,9 @@ inline Tensor execute_on_worker_thread(
     const Tensor& input_tensor,
     const std::vector<UnaryWithParam>& op_chain,
     const std::optional<MemoryConfig>& memory_config = std::nullopt) {
-    DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? DataType::UINT32 : input_tensor.get_dtype();
-    bool fp32_dest_acc_en = input_tensor.get_dtype() == DataType::UINT32 or
+    DataType output_dtype = (op_chain[0].op_type == UnaryOpType::TYPECAST) ? static_cast<DataType>(op_chain[0].params[0]) : input_tensor.get_dtype();
+    bool fp32_dest_acc_en = output_dtype == DataType::UINT32 or
+                            input_tensor.get_dtype() == DataType::UINT32 or
                             input_tensor.get_dtype() == DataType::INT32;  // MT: Currently only uint32/int32 is moved to
                                                                           // DST directly, fp32 is converted to fp16b
     return operation::run(
diff --git a/ttnn/ttnn/__init__.py b/ttnn/ttnn/__init__.py
index 889a517af461..ea52b8fb3864 100644
--- a/ttnn/ttnn/__init__.py
+++ b/ttnn/ttnn/__init__.py
@@ -57,7 +57,7 @@ def validate(self, name):
         if self.enable_fast_runtime_mode:
             if self.enable_logging:
                 logger.warning(
-                    "Running in fast runtime mode without logging. Please disable fast runtime mode if you want to enable logging."
+                    "Logging cannot be enabled in fast runtime mode. Please disable fast runtime mode if you want to enable logging."
                 )
 
         if name in {