#12196: Use split readers wherever possible in UNet Shallow
This change enables the split reader and activation double buffering on some
convolutions (a configuration sketch follows the model_preprocessing.py diff
below). This commit also:

* Fuses the output layer's bias with the matmul (see the first sketch after the file summary below).

* Uses ttnn.reshard instead of ttnn.sharded_to_interleaved + ttnn.interleaved_to_sharded in the bottleneck layer (see the second sketch below).

* Adds device and end-to-end performance tests to CI (a sample invocation follows the new test file's diff).
esmalTT committed Sep 12, 2024
1 parent afdd15d commit dabe71f
Showing 7 changed files with 317 additions and 41 deletions.
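Two of the changes above are easiest to see as before/after pairs. First, the bias fusion: a minimal sketch, assuming the output layer bottoms out in a ttnn.linear call (the tensor names x, w, and b are illustrative, not taken from the model code):

    # Before: matmul followed by a separate bias add
    out = ttnn.linear(x, w)
    out = ttnn.add(out, b)

    # After: the bias is applied inside the matmul op
    out = ttnn.linear(x, w, bias=b)

Second, the bottleneck-layer data movement: rather than round-tripping through an interleaved layout, a single reshard moves the tensor directly between shard configurations (out_mem_config is an illustrative name for the target sharded memory config):

    # Before: two ops and an intermediate interleaved tensor
    x = ttnn.sharded_to_interleaved(x, ttnn.L1_MEMORY_CONFIG)
    x = ttnn.interleaved_to_sharded(x, out_mem_config)

    # After: one op, no interleaved intermediate
    x = ttnn.reshard(x, out_mem_config)

Skipping the intermediate interleaved tensor saves a full round of data movement per call.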
@@ -73,14 +73,14 @@ jobs:
arch: wormhole_b0,
runs-on: ["cloud-virtual-machine", "N300", "in-service"],
cmd: tests/scripts/single_card/nightly/run_wh_b0_only.sh,
- timeout: 80
+ timeout: 100
},
{
name: "N150 WH-only models",
arch: wormhole_b0,
runs-on: ["cloud-virtual-machine", "N150", "in-service"],
cmd: tests/scripts/single_card/nightly/run_wh_b0_only.sh,
- timeout: 80
+ timeout: 100
},
{
name: "API tests GS",
@@ -26,7 +26,7 @@ def test_unet_model(batch, groups, device, use_program_cache, reset_seeds):
ttnn_model = unet_shallow_ttnn.UNet(parameters, device)

torch_output_tensor = model(torch_input)
- output_tensor = ttnn_model(ttnn_input, list(torch_input.shape))
+ output_tensor = ttnn_model(ttnn_input)

B, C, H, W = torch_output_tensor.shape
ttnn_tensor = ttnn.to_torch(output_tensor).reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2)
@@ -49,6 +49,6 @@ def test_unet_multi_device_model(batch, groups, mesh_device, use_program_cache,
)

torch_output_tensor = model(torch_input)
- output_tensor = ttnn_model(ttnn_input, list(torch_input.shape))
+ output_tensor = ttnn_model(ttnn_input)

check_pcc_conv(torch_output_tensor, output_tensor, mesh_composer=output_mesh_composer, pcc=0.99)
218 changes: 218 additions & 0 deletions models/experimental/functional_unet/tests/test_unet_perf.py
@@ -0,0 +1,218 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn
import pytest

from loguru import logger

from tests.ttnn.utils_for_testing import assert_with_pcc

from models.experimental.functional_unet.tt.model_preprocessing import (
create_unet_input_tensors,
create_unet_model_parameters,
)
from models.experimental.functional_unet.tt import unet_shallow_torch
from models.experimental.functional_unet.tt import unet_shallow_ttnn
from models.experimental.functional_unet.tests.common import (
check_pcc_conv,
is_n300_with_eth_dispatch_cores,
)

from models.perf.perf_utils import prep_perf_report
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
from models.utility_functions import (
profiler,
skip_for_grayskull,
)


@skip_for_grayskull("UNet not currently supported on GS")
@pytest.mark.models_device_performance_bare_metal
@pytest.mark.parametrize(
"batch, groups, expected_device_perf_fps",
((2, 1, 443.0),),
)
def test_unet_perf_device(batch: int, groups: int, expected_device_perf_fps: float, reset_seeds):
command = f"pytest models/experimental/functional_unet/tests/test_unet_model.py::test_unet_model[device_params0-{groups}-{batch}]"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]

inference_time_key = "AVG DEVICE FW SAMPLES/S"
post_processed_results = run_device_perf(
command, subdir="unet_shallow", num_iterations=1, cols=cols, batch_size=batch
)
expected_perf_cols = {inference_time_key: expected_device_perf_fps}
expected_results = check_device_perf(
post_processed_results, margin=0.02, expected_perf_cols=expected_perf_cols, assert_on_fail=True
)
prep_device_perf_report(
model_name=f"unet-shallow_batch-{batch}_groups-{groups}",
batch_size=batch,
post_processed_results=post_processed_results,
expected_results=expected_results,
comments="",
)


@skip_for_grayskull("UNet not currently supported on GS")
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize("device_params", [{"l1_small_size": 64768}], indirect=True)
@pytest.mark.parametrize(
"batch, groups, iterations, expected_compile_time, expected_inference_time_ms",
((2, 1, 16, 16.0, 39.0),),
)
def test_unet_perf_e2e(
batch: int,
groups: int,
iterations: int,
expected_compile_time: float,
expected_inference_time_ms: float,
device,
use_program_cache,
reset_seeds,
):
profiler.clear()

torch_input, ttnn_input = create_unet_input_tensors(device, batch, groups, pad_input=True)

profiler.start(f"initialize_ref_model")
model = unet_shallow_torch.UNet.from_random_weights(groups=1)
profiler.end(f"initialize_ref_model")

profiler.start(f"initialize_model")
parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=device)
ttnn_model = unet_shallow_ttnn.UNet(parameters, device)
profiler.end(f"initialize_model")

torch_output_tensor = model(torch_input)

logger.info(f"Compiling model with warmup run")
profiler.start(f"inference_and_compile_time")
output_tensor = ttnn_model(ttnn_input)
profiler.end(f"inference_and_compile_time")

inference_and_compile_time = profiler.get("inference_and_compile_time")
logger.info(f"Model compiled with warmup run in {(inference_and_compile_time):.2f} s")

logger.info(f"Running inference for {iterations} iterations")
for idx in range(iterations):
profiler.start("inference_time")
profiler.start(f"inference_time_{idx}")
output_tensor = ttnn_model(ttnn_input)
profiler.end(f"inference_time_{idx}")
profiler.end("inference_time")

mean_inference_time = profiler.get("inference_time")
inference_time = profiler.get(f"inference_time_{iterations - 1}")
compile_time = inference_and_compile_time - inference_time
logger.info(f"Model compilation took {compile_time:.1f} s")
logger.info(f"Inference time on last iterations was completed in {(inference_time * 1000.0):.2f} ms")
logger.info(
f"Mean inference time for {batch} (batch) images was {(mean_inference_time * 1000.0):.2f} ms ({batch / mean_inference_time:.2f} fps)"
)

expected_inference_time = expected_inference_time_ms * 1e-3
prep_perf_report(
model_name=f"unet_shallow",
batch_size=batch,
inference_and_compile_time=inference_and_compile_time,
inference_time=inference_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments="",
)

logger.info(f"Running sanity check against reference model output")
B, C, H, W = torch_output_tensor.shape
ttnn_tensor = ttnn.to_torch(output_tensor).reshape(B, H, W, -1)[:, :, :, :C].permute(0, 3, 1, 2)
assert_with_pcc(torch_output_tensor, ttnn_tensor, 0.99)


@skip_for_grayskull("UNet not currently supported on GS")
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize("device_params", [{"l1_small_size": 64768}], indirect=True)
@pytest.mark.parametrize(
"batch, groups, iterations, expected_compile_time, expected_inference_time_ms",
((2, 1, 16, 16.0, 39.0),),
)
def test_unet_data_parallel_perf_e2e(
batch: int,
groups: int,
iterations: int,
expected_compile_time: float,
expected_inference_time_ms: float,
mesh_device,
use_program_cache,
reset_seeds,
):
if not is_n300_with_eth_dispatch_cores(mesh_device):
pytest.skip("Test is only valid for N300")

profiler.clear()

inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device)
output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

torch_input, ttnn_input = create_unet_input_tensors(mesh_device, batch, groups, pad_input=True)

profiler.start(f"initialize_ref_model")
model = unet_shallow_torch.UNet.from_random_weights(groups=groups)
profiler.end(f"initialize_ref_model")

profiler.start(f"initialize_model")
parameters = create_unet_model_parameters(model, torch_input, groups=groups, device=mesh_device)
ttnn_model = unet_shallow_ttnn.UNet(parameters, device=mesh_device, mesh_mapper=weights_mesh_mapper)
profiler.end(f"initialize_model")

num_devices = len(mesh_device.get_device_ids())
total_batch = num_devices * batch
torch_input, ttnn_input = create_unet_input_tensors(
mesh_device, total_batch, groups, pad_input=True, mesh_mapper=inputs_mesh_mapper
)
logger.info(f"Created reference input tensors: {list(torch_input.shape)}")
logger.info(
f"Created multi-device input tensors: shape={list(ttnn_input.shape)} on devices={mesh_device.get_device_ids()}"
)

torch_output_tensor = model(torch_input)

logger.info(f"Compiling model with warmup run")
profiler.start(f"inference_and_compile_time")
output_tensor = ttnn_model(ttnn_input)
profiler.end(f"inference_and_compile_time")

inference_and_compile_time = profiler.get("inference_and_compile_time")
logger.info(f"Model compiled with warmup run in {(inference_and_compile_time):.2f} s")

logger.info(f"Running inference for {iterations} iterations")
for idx in range(iterations):
profiler.start("inference_time")
profiler.start(f"inference_time_{idx}")
output_tensor = ttnn_model(ttnn_input)
profiler.end(f"inference_time_{idx}")
profiler.end("inference_time")

mean_inference_time = profiler.get("inference_time")
inference_time = profiler.get(f"inference_time_{iterations - 1}")
compile_time = inference_and_compile_time - inference_time
logger.info(f"Model compilation took {compile_time:.1f} s")
logger.info(f"Inference time on last iterations was completed in {(inference_time * 1000.0):.2f} ms")
logger.info(
f"Mean inference time for {total_batch} (batch) images was {(mean_inference_time * 1000.0):.2f} ms ({total_batch / mean_inference_time:.2f} fps)"
)

expected_inference_time = expected_inference_time_ms * 1e-3
prep_perf_report(
model_name=f"unet_shallow-data_parallel",
batch_size=total_batch,
inference_and_compile_time=inference_and_compile_time,
inference_time=inference_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments="batch_{total_batch}-num_devices_{num_devices}",
)

logger.info(f"Running sanity check against reference model output")
check_pcc_conv(torch_output_tensor, output_tensor, mesh_composer=output_mesh_composer, pcc=0.99)
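For reference, the new end-to-end tests can be run locally with something like the following (the marker name comes from the decorators above; the exact CI wiring lives in the workflow file):

    pytest models/experimental/functional_unet/tests/test_unet_perf.py -m models_performance_bare_metal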
43 changes: 37 additions & 6 deletions models/experimental/functional_unet/tt/model_preprocessing.py
@@ -66,34 +66,65 @@ def create_unet_model_parameters(model: unet_shallow_torch.UNet, input_tensor: t
}

parameters.c1["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32}
parameters.c1["use_split_reader"] = True
parameters.c1["use_activation_double_buffer"] = True
parameters.c1_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 8 * 32}
parameters.c1_2["use_split_reader"] = True
parameters.c1_2["use_activation_double_buffer"] = True

parameters.c2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32}
parameters.c2_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32}
parameters.c2_2["use_activation_double_buffer"] = True

parameters.c3["conv_blocking_and_parallelization_config_override"] = None
parameters.c3["use_split_reader"] = True
parameters.c3["use_activation_double_buffer"] = True
parameters.c3_2["conv_blocking_and_parallelization_config_override"] = None
parameters.c3_2["use_split_reader"] = True
parameters.c3_2["use_activation_double_buffer"] = True

parameters.c4["conv_blocking_and_parallelization_config_override"] = None
parameters.c4["use_activation_double_buffer"] = True
parameters.c4_2["conv_blocking_and_parallelization_config_override"] = None
parameters.c4_2["use_activation_double_buffer"] = True

parameters.bnc["conv_blocking_and_parallelization_config_override"] = None
parameters.bnc["use_activation_double_buffer"] = True
parameters.bnc_2["conv_blocking_and_parallelization_config_override"] = None
parameters.bnc_2["use_activation_double_buffer"] = True

parameters.c5["conv_blocking_and_parallelization_config_override"] = None
parameters.c5["use_activation_double_buffer"] = True
parameters.c5_2["conv_blocking_and_parallelization_config_override"] = None
parameters.c5_2["use_activation_double_buffer"] = True
parameters.c5_3["conv_blocking_and_parallelization_config_override"] = None
parameters.c5_3["use_activation_double_buffer"] = True

parameters.c6["conv_blocking_and_parallelization_config_override"] = None
parameters.c6["use_split_reader"] = True
parameters.c6["use_activation_double_buffer"] = True
parameters.c6_2["conv_blocking_and_parallelization_config_override"] = None
parameters.c6_2["use_split_reader"] = True
parameters.c6_2["use_activation_double_buffer"] = True
parameters.c6_3["conv_blocking_and_parallelization_config_override"] = None
parameters.c6_3["use_split_reader"] = True
parameters.c6_3["use_activation_double_buffer"] = True

parameters.c7["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32}
parameters.c7["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32}
parameters.c7["use_activation_double_buffer"] = True
parameters.c7_2["conv_blocking_and_parallelization_config_override"] = None
parameters.c7_2["use_split_reader"] = True
parameters.c7_2["use_activation_double_buffer"] = True
parameters.c7_3["conv_blocking_and_parallelization_config_override"] = None

parameters.c8["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32}
parameters.c8["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32}
parameters.c8_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32}
parameters.c8_3["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 32}
parameters.c7_3["use_split_reader"] = True
parameters.c7_3["use_activation_double_buffer"] = True

parameters.c8["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32}
parameters.c8["use_activation_double_buffer"] = True
parameters.c8_2["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32}
parameters.c8_2["use_activation_double_buffer"] = True
parameters.c8_3["conv_blocking_and_parallelization_config_override"] = {"act_block_h": 5 * 32}
parameters.c8_3["use_activation_double_buffer"] = True

parameters.output_layer["conv_blocking_and_parallelization_config_override"] = None

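The per-layer flags above are what the commit message refers to: "use_split_reader" splits activation reads across both data-movement cores of a Tensix core, and "use_activation_double_buffer" double-buffers the activation circular buffer so reads overlap with compute at the cost of extra L1. A minimal sketch of how such a dict might map onto a ttnn conv config, assuming ttnn.Conv2dConfig exposes enable_split_reader, enable_act_double_buffer, and act_block_h_override fields (the helper build_conv_config and the dict-style access are illustrative; the model's actual plumbing lives in unet_shallow_ttnn.py, which is not shown here):

    import ttnn

    def build_conv_config(layer_params):
        # Illustrative: translate one per-layer parameter dict into a conv config.
        override = layer_params.get("conv_blocking_and_parallelization_config_override")
        return ttnn.Conv2dConfig(
            # Read activations with two reader kernels instead of one.
            enable_split_reader=layer_params.get("use_split_reader", False),
            # Overlap activation reads with compute; costs extra L1 space.
            enable_act_double_buffer=layer_params.get("use_activation_double_buffer", False),
            # Respect a layer-specific act_block_h when one is set (0 = auto).
            act_block_h_override=override["act_block_h"] if override else 0,
        )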
