diff --git a/.github/workflows/fast-dispatch-full-regressions-and-models.yaml b/.github/workflows/fast-dispatch-full-regressions-and-models.yaml index ba92052af86..60aafb71475 100644 --- a/.github/workflows/fast-dispatch-full-regressions-and-models.yaml +++ b/.github/workflows/fast-dispatch-full-regressions-and-models.yaml @@ -73,14 +73,14 @@ jobs: arch: wormhole_b0, runs-on: ["cloud-virtual-machine", "N300", "in-service"], cmd: tests/scripts/single_card/nightly/run_wh_b0_only.sh, - timeout: 80 + timeout: 100 }, { name: "N150 WH-only models", arch: wormhole_b0, runs-on: ["cloud-virtual-machine", "N150", "in-service"], cmd: tests/scripts/single_card/nightly/run_wh_b0_only.sh, - timeout: 80 + timeout: 100 }, { name: "API tests GS", diff --git a/models/experimental/functional_unet/tests/test_unet_model.py b/models/experimental/functional_unet/tests/test_unet_model.py index 84f66d9c856..d2b0154bac4 100644 --- a/models/experimental/functional_unet/tests/test_unet_model.py +++ b/models/experimental/functional_unet/tests/test_unet_model.py @@ -26,7 +26,7 @@ def test_unet_model(batch, groups, device, use_program_cache, reset_seeds): ttnn_model = unet_shallow_ttnn.UNet(parameters, device) torch_output_tensor = model(torch_input) - output_tensor = ttnn_model(ttnn_input, list(torch_input.shape)) + output_tensor = ttnn_model(ttnn_input) B, C, H, W = torch_output_tensor.shape ttnn_tensor = ttnn.to_torch(output_tensor).reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2) diff --git a/models/experimental/functional_unet/tests/test_unet_multi_device.py b/models/experimental/functional_unet/tests/test_unet_multi_device.py index 774d121772b..91c3ae1010c 100644 --- a/models/experimental/functional_unet/tests/test_unet_multi_device.py +++ b/models/experimental/functional_unet/tests/test_unet_multi_device.py @@ -49,6 +49,6 @@ def test_unet_multi_device_model(batch, groups, mesh_device, use_program_cache, ) torch_output_tensor = model(torch_input) - output_tensor = ttnn_model(ttnn_input, list(torch_input.shape)) + output_tensor = ttnn_model(ttnn_input) check_pcc_conv(torch_output_tensor, output_tensor, mesh_composer=output_mesh_composer, pcc=0.99) diff --git a/models/experimental/functional_unet/tests/test_unet_perf.py b/models/experimental/functional_unet/tests/test_unet_perf.py new file mode 100644 index 00000000000..fc0adb77a99 --- /dev/null +++ b/models/experimental/functional_unet/tests/test_unet_perf.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn +import pytest + +from loguru import logger + +from tests.ttnn.utils_for_testing import assert_with_pcc + +from models.experimental.functional_unet.tt.model_preprocessing import ( + create_unet_input_tensors, + create_unet_model_parameters, +) +from models.experimental.functional_unet.tt import unet_shallow_torch +from models.experimental.functional_unet.tt import unet_shallow_ttnn +from models.experimental.functional_unet.tests.common import ( + check_pcc_conv, + is_n300_with_eth_dispatch_cores, +) + +from models.perf.perf_utils import prep_perf_report +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report +from models.utility_functions import ( + profiler, + skip_for_grayskull, +) + + +@skip_for_grayskull("UNet not currently supported on GS") +@pytest.mark.models_device_performance_bare_metal +@pytest.mark.parametrize( + "batch, groups, expected_device_perf_fps", + ((2, 1, 443.0),), +) +def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float, reset_seeds): + command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]" + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE FW SAMPLES/S" + post_processed_results = run_device_perf( + command, subdir="unet_shallow", num_iterations=1, cols=cols, batch_size=batch + ) + expected_perf_cols = {inference_time_key: expected_device_perf_fps} + expected_results = check_device_perf( + post_processed_results, margin=0.02, expected_perf_cols=expected_perf_cols, assert_on_fail=True + ) + prep_device_perf_report( + model_name=f"unet-shallow_batch-{batch}_groups-{groups}", + batch_size=batch, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments="", + ) + + +@skip_for_grayskull("UNet not currently supported on GS") +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize("device_params", [{"l1_small_size": 64768}], indirect=True) +@pytest.mark.parametrize( + "batch, groups, iterations, expected_compile_time, expected_inference_time_ms", + ((2, 1, 16, 16.0, 39.0),), +) +def test_unet_perf_e2e( + batch: int, + groups: int, + iterations: int, + expected_compile_time: float, + expected_inference_time_ms: float, + device, + use_program_cache, + reset_seeds, +): + profiler.clear() + + torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True) + + profiler.start(f"initialize_ref_model") + model = unet_shallow_torch.UNet.from_random_weights(groups=1) + profiler.end(f"initialize_ref_model") + + profiler.start(f"initialize_model") + parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device) + ttnn_model = unet_shallow_ttnn.UNet(parameters, device) + profiler.end(f"initialize_model") + + torch_output_tensor = model(torch_input) + + logger.info(f"Compiling model with warmup run") + profiler.start(f"inference_and_compile_time") + output_tensor = ttnn_model(ttnn_input) + profiler.end(f"inference_and_compile_time") + + inference_and_compile_time = profiler.get("inference_and_compile_time") + logger.info(f"Model compiled with warmup run in {(inference_and_compile_time):.2f} s") + + logger.info(f"Running inference for {iterations} iterations") + for idx in range(iterations): + profiler.start("inference_time") + profiler.start(f"inference_time_{idx}") + output_tensor = ttnn_model(ttnn_input) + profiler.end(f"inference_time_{idx}") + profiler.end("inference_time") + + mean_inference_time = profiler.get("inference_time") + inference_time = profiler.get(f"inference_time_{iterations - 1}") + compile_time = inference_and_compile_time - inference_time + logger.info(f"Model compilation took {compile_time:.1f} s") + logger.info(f"Inference time on last iterations was completed in {(inference_time * 1000.0):.2f} ms") + logger.info( + f"Mean inference time for {batch} (batch) images was {(mean_inference_time * 1000.0):.2f} ms ({batch / mean_inference_time:.2f} fps)" + ) + + expected_inference_time = expected_inference_time_ms * 1e-3 + prep_perf_report( + model_name=f"unet_shallow", + batch_size=batch, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + ) + + logger.info(f"Running sanity check against reference model output") + B, C, H, W = torch_output_tensor.shape + ttnn_tensor = ttnn.to_torch(output_tensor).reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2) + assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.99) + + +@skip_for_grayskull("UNet not currently supported on GS") +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize("device_params", [{"l1_small_size": 64768}], indirect=True) +@pytest.mark.parametrize( + "batch, groups, iterations, expected_compile_time, expected_inference_time_ms", + ((2, 1, 16, 16.0, 39.0),), +) +def test_unet_data_parallel_perf_e2e( + batch: int, + groups: int, + iterations: int, + expected_compile_time: float, + expected_inference_time_ms: float, + mesh_device, + use_program_cache, + reset_seeds, +): + if not is_n300_with_eth_dispatch_cores(mesh_device): + pytest.skip("Test is only valid for N300") + + profiler.clear() + + inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) + weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) + output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) + + torch_input, ttnn_input = create_unet_input_tensors(mesh_device, batch, groups, pad_input=True) + + profiler.start(f"initialize_ref_model") + model = unet_shallow_torch.UNet.from_random_weights(groups=groups) + profiler.end(f"initialize_ref_model") + + profiler.start(f"initialize_model") + parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=mesh_device) + ttnn_model = unet_shallow_ttnn.UNet(parameters, device=mesh_device, mesh_mapper=weights_mesh_mapper) + profiler.end(f"initialize_model") + + num_devices = len(mesh_device.get_device_ids()) + total_batch = num_devices * batch + torch_input, ttnn_input = create_unet_input_tensors( + mesh_device, total_batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper + ) + logger.info(f"Created reference input tensors: {list(torch_input.shape)}") + logger.info( + f"Created multi-device input tensors: shape={list(ttnn_input.shape)} on devices={mesh_device.get_device_ids()}" + ) + + torch_output_tensor = model(torch_input) + + logger.info(f"Compiling model with warmup run") + profiler.start(f"inference_and_compile_time") + output_tensor = ttnn_model(ttnn_input) + profiler.end(f"inference_and_compile_time") + + inference_and_compile_time = profiler.get("inference_and_compile_time") + logger.info(f"Model compiled with warmup run in {(inference_and_compile_time):.2f} s") + + logger.info(f"Running inference for {iterations} iterations") + for idx in range(iterations): + profiler.start("inference_time") + profiler.start(f"inference_time_{idx}") + output_tensor = ttnn_model(ttnn_input) + profiler.end(f"inference_time_{idx}") + profiler.end("inference_time") + + mean_inference_time = profiler.get("inference_time") + inference_time = profiler.get(f"inference_time_{iterations - 1}") + compile_time = inference_and_compile_time - inference_time + logger.info(f"Model compilation took {compile_time:.1f} s") + logger.info(f"Inference time on last iterations was completed in {(inference_time * 1000.0):.2f} ms") + logger.info( + f"Mean inference time for {total_batch} (batch) images was {(mean_inference_time * 1000.0):.2f} ms ({total_batch / mean_inference_time:.2f} fps)" + ) + + expected_inference_time = expected_inference_time_ms * 1e-3 + prep_perf_report( + model_name=f"unet_shallow-data_parallel", + batch_size=total_batch, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="batch_{total_batch}-num_devices_{num_devices}", + ) + + logger.info(f"Running sanity check against reference model output") + check_pcc_conv(torch_output_tensor, output_tensor, mesh_composer=output_mesh_composer, pcc=0.99) diff --git a/models/experimental/functional_unet/tt/model_preprocessing.py b/models/experimental/functional_unet/tt/model_preprocessing.py index c92848c42a7..62583bb4dfd 100644 --- a/models/experimental/functional_unet/tt/model_preprocessing.py +++ b/models/experimental/functional_unet/tt/model_preprocessing.py @@ -66,34 +66,65 @@ def create_unet_model_parameters(model: unet_shallow_torch.UNet, input_tensor: t } parameters.c1["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32} + parameters.c1["use_split_reader"] = True + parameters.c1["use_activation_double_buffer"] = True parameters.c1_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32} + parameters.c1_2["use_split_reader"] = True + parameters.c1_2["use_activation_double_buffer"] = True parameters.c2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} parameters.c2_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} + parameters.c2_2["use_activation_double_buffer"] = True + parameters.c3["conv_blocking_and_parallelization_config_override"] = None + parameters.c3["use_split_reader"] = True + parameters.c3["use_activation_double_buffer"] = True parameters.c3_2["conv_blocking_and_parallelization_config_override"] = None + parameters.c3_2["use_split_reader"] = True + parameters.c3_2["use_activation_double_buffer"] = True + parameters.c4["conv_blocking_and_parallelization_config_override"] = None + parameters.c4["use_activation_double_buffer"] = True parameters.c4_2["conv_blocking_and_parallelization_config_override"] = None + parameters.c4_2["use_activation_double_buffer"] = True parameters.bnc["conv_blocking_and_parallelization_config_override"] = None + parameters.bnc["use_activation_double_buffer"] = True parameters.bnc_2["conv_blocking_and_parallelization_config_override"] = None + parameters.bnc_2["use_activation_double_buffer"] = True parameters.c5["conv_blocking_and_parallelization_config_override"] = None + parameters.c5["use_activation_double_buffer"] = True parameters.c5_2["conv_blocking_and_parallelization_config_override"] = None + parameters.c5_2["use_activation_double_buffer"] = True parameters.c5_3["conv_blocking_and_parallelization_config_override"] = None + parameters.c5_3["use_activation_double_buffer"] = True parameters.c6["conv_blocking_and_parallelization_config_override"] = None + parameters.c6["use_split_reader"] = True + parameters.c6["use_activation_double_buffer"] = True parameters.c6_2["conv_blocking_and_parallelization_config_override"] = None + parameters.c6_2["use_split_reader"] = True + parameters.c6_2["use_activation_double_buffer"] = True parameters.c6_3["conv_blocking_and_parallelization_config_override"] = None + parameters.c6_3["use_split_reader"] = True + parameters.c6_3["use_activation_double_buffer"] = True - parameters.c7["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32} + parameters.c7["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} + parameters.c7["use_activation_double_buffer"] = True parameters.c7_2["conv_blocking_and_parallelization_config_override"] = None + parameters.c7_2["use_split_reader"] = True + parameters.c7_2["use_activation_double_buffer"] = True parameters.c7_3["conv_blocking_and_parallelization_config_override"] = None - - parameters.c8["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32} - parameters.c8["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32} - parameters.c8_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32} - parameters.c8_3["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32} + parameters.c7_3["use_split_reader"] = True + parameters.c7_3["use_activation_double_buffer"] = True + + parameters.c8["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} + parameters.c8["use_activation_double_buffer"] = True + parameters.c8_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} + parameters.c8_2["use_activation_double_buffer"] = True + parameters.c8_3["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32} + parameters.c8_3["use_activation_double_buffer"] = True parameters.output_layer["conv_blocking_and_parallelization_config_override"] = None diff --git a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py index 8f4142a2287..e7efbb22f68 100644 --- a/models/experimental/functional_unet/tt/unet_shallow_ttnn.py +++ b/models/experimental/functional_unet/tt/unet_shallow_ttnn.py @@ -6,6 +6,8 @@ import ttnn import torch +from typing import List + from ttnn.operations.conv2d import determine_parallel_config, create_sharded_memory_config_from_parallel_config from models.utility_functions import nearest_32 @@ -23,6 +25,7 @@ def determine_num_cores_for_upsample(nhw: int, width: int, max_cores=64) -> int: return cores +# TODO: Make this valid over any num_cores def get_core_grid_from_num_cores(num_cores: int): if num_cores == 44: return ttnn.CoreRangeSet( @@ -42,38 +45,37 @@ def get_core_grid_from_num_cores(num_cores: int): raise RuntimeError(f"Could not get core grid given num_cores={num_cores}") -def unet_concat(ttnn_tensors, dim=-1): - assert len(ttnn_tensors) > 0 +def unet_concat(inputs: List, dim=-1): + assert len(inputs) > 0 assert dim < 0 - ttlib_tensors = ttnn_tensors - all_sharded = all(t.is_sharded() for t in ttlib_tensors) + all_sharded = all(tensor.is_sharded() for tensor in inputs) if all_sharded: - max_idx, output_mem_config = max( - ((i, t.memory_config()) for i, t in enumerate(ttlib_tensors)), key=lambda m: m[1].shard_spec.num_cores() + max_idx, memory_config = max( + ((i, t.memory_config()) for i, t in enumerate(inputs)), key=lambda m: m[1].shard_spec.num_cores() ) - for i in range(0, len(ttlib_tensors)): + for i in range(0, len(inputs)): if i == max_idx: continue - t = ttlib_tensors[i] + t = inputs[i] t_mem_config = t.memory_config() t_shard_shape = t_mem_config.shard_spec.shape - output_shard_shape = output_mem_config.shard_spec.shape + output_shard_shape = memory_config.shard_spec.shape output_shard_shape[dim] += t_shard_shape[dim] - output_mem_config.shard_spec.shape = output_shard_shape + memory_config.shard_spec.shape = output_shard_shape reshard_shape = output_shard_shape reshard_shape[dim] = t_shard_shape[dim] if reshard_shape != t_shard_shape: t_mem_config.shard_spec.shape = reshard_shape - t_mem_config.shard_spec.grid = output_mem_config.shard_spec.grid - t_mem_config.shard_spec.orientation = output_mem_config.shard_spec.orientation - ttlib_tensors[i] = ttnn.reshard(t, t_mem_config) + t_mem_config.shard_spec.grid = memory_config.shard_spec.grid + t_mem_config.shard_spec.orientation = memory_config.shard_spec.orientation + inputs[i] = ttnn.reshard(t, t_mem_config) else: - output_mem_config = ttnn.DRAM_MEMORY_CONFIG - for i in range(0, len(ttlib_tensors)): - if ttlib_tensors[i].is_sharded(): - ttlib_tensors[i] = ttnn.to_memory_config(ttlib_tensors[i], output_mem_config) - return ttnn.concat(ttlib_tensors, dim=dim, memory_config=output_mem_config) + memory_config = ttnn.DRAM_MEMORY_CONFIG + for i in range(0, len(inputs)): + if inputs[i].is_sharded(): + inputs[i] = ttnn.to_memory_config(inputs[i], memory_config) + return ttnn.concat(inputs, dim=dim, memory_config=memory_config) class UNetPointwiseConv2D: @@ -106,7 +108,14 @@ def __init__( def __call__(self, x): x = ttnn.to_layout(x, ttnn.TILE_LAYOUT) - x = ttnn.linear(x, self.weight, bias=self.bias, dtype=self.activation_dtype) + x = ttnn.linear( + x, + self.weight, + memory_config=ttnn.L1_MEMORY_CONFIG, + bias=self.bias, + dtype=self.activation_dtype, + core_grid=ttnn.CoreGrid(y=8, x=8), + ) return x @@ -150,8 +159,10 @@ def __init__( deallocate_activation=self.deallocate_activation, fp32_dest_acc_enabled=True, packer_l1_accum_enabled=False, - enable_act_double_buffer=False, - enable_split_reader=False, + enable_act_double_buffer=( + conv.use_activation_double_buffer if "use_activation_double_buffer" in conv else False + ), + enable_split_reader=conv.use_split_reader if "use_split_reader" in conv else False, enable_subblock_padding=False, activation=activation, output_layout=ttnn.TILE_LAYOUT, @@ -226,8 +237,20 @@ def __init__( should_reshard=False, mesh_mapper=None, ): - self.conv1 = UNetConv2D(conv1, bn=bn1, device=device, cache=conv_cache, mesh_mapper=mesh_mapper) - self.conv2 = UNetConv2D(conv2, bn=bn2, device=device, cache=conv_cache, mesh_mapper=mesh_mapper) + self.conv1 = UNetConv2D( + conv1, + bn=bn1, + device=device, + cache=conv_cache, + mesh_mapper=mesh_mapper, + ) + self.conv2 = UNetConv2D( + conv2, + bn=bn2, + device=device, + cache=conv_cache, + mesh_mapper=mesh_mapper, + ) self.pool1 = UNetMaxPool2D(pool, conv2.out_channels, device=device) self.should_reshard = should_reshard @@ -447,7 +470,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b6_3, device, conv_cache=self.conv_cache, - should_reshard=True, + should_reshard=False, mesh_mapper=mesh_mapper, ) self.upblock3 = UNetUpblock( @@ -471,7 +494,7 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: parameters.b8_3, device, conv_cache=self.conv_cache, - should_reshard=True, + should_reshard=False, mesh_mapper=mesh_mapper, ) @@ -483,15 +506,16 @@ def __init__(self, parameters: ParameterDict, device, mesh_mapper=None) -> None: def bottleneck(self, x): if x.is_sharded(): - x = ttnn.sharded_to_interleaved(x, ttnn.L1_MEMORY_CONFIG) - x = ttnn.interleaved_to_sharded( - x, - self.bnc_sharded_memory_config, - ) + x = ttnn.reshard(x, self.bnc_sharded_memory_config) + else: + x = ttnn.interleaved_to_sharded( + x, + self.bnc_sharded_memory_config, + ) x = self.bnc(x) return self.bnc2(x) - def __call__(self, x, original_shape, perf_mode=False): + def __call__(self, x): assert len(x.shape) == 4, f"Expected UNet input tensors to be rank 4 (was {len(x.shape)})" x = x.to(self.device, ttnn.L1_MEMORY_CONFIG) diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index fd892a9f4cb..a7285e69426 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -53,6 +53,7 @@ run_perf_models_cnn_javelin() { local test_marker=$2 # Run tests + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/experimental/functional_unet/tests/test_unet_perf.py -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto tests/device_perf_tests/stable_diffusion -m $test_marker --timeout=480 ## Merge all the generated reports @@ -82,6 +83,8 @@ run_device_perf_models() { if [ "$tt_arch" == "wormhole_b0" ]; then env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/resnet50/tests -m $test_marker + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/experimental/functional_unet/tests/test_unet_perf.py -m $test_marker + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/mamba/tests -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/metal_BERT_large_11/tests -m $test_marker