Skip to content

Commit

Permalink
#13405: Replaced one torch maxpool with ttnn maxpool
Browse files Browse the repository at this point in the history
  • Loading branch information
sabira-mcw committed Dec 3, 2024
1 parent 4fbf718 commit 7ac95f3
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 53 deletions.
2 changes: 0 additions & 2 deletions models/demos/lenet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,3 @@ This command will initiate the test for the demo dataset, allowing you to observ
## Inputs

The demo accepts inputs from the MNIST dataset, which consists of a large collection of labeled handwritten digits. The dataset provides a diverse range of examples, enabling the model to learn and generalize effectively. Each input consists of a grayscale image of a handwritten digit, which is processed through the model to produce a predicted classification.

### Owner: [sabira-mcw](https://github.com/sabira-mcw)
9 changes: 5 additions & 4 deletions models/demos/lenet/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@
import pytest
import torch
import ttnn

from torchvision import transforms, datasets
from loguru import logger

from torch.utils.data import DataLoader

from models.utility_functions import (
disable_persistent_kernel_cache,
)
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.lenet.tt import tt_lenet
from models.demos.lenet import lenet_utils
Expand Down Expand Up @@ -50,6 +49,7 @@ def run_demo_dataset(device, batch_size, iterations, model_location_generator, r

accuracy = correct / (batch_size * iterations)
logger.info(f"Dataset Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}")
assert accuracy >= 1.0, f"Expected accuracy : {1.0} Actual accuracy: {accuracy}"


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
Expand All @@ -62,6 +62,7 @@ def test_demo_dataset(
model_location_generator,
reset_seeds,
):
disable_persistent_kernel_cache()
return run_demo_dataset(
reset_seeds=reset_seeds,
device=device,
Expand Down
39 changes: 20 additions & 19 deletions models/demos/lenet/tests/test_perf_lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,35 +2,32 @@

# SPDX-License-Identifier: Apache-2.0

import torch
import pytest
import ttnn
import time
from pathlib import Path

from torchvision import models
from loguru import logger
import ttnn

from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import (
enable_persistent_kernel_cache,
disable_persistent_kernel_cache,
)
from models.demos.lenet.tt import tt_lenet
from models.utility_functions import is_grayskull, is_wormhole_b0
from models.demos.lenet import lenet_utils
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report
from models.perf.perf_utils import prep_perf_report
from models.demos.lenet import lenet_utils
from models.demos.lenet.tt import tt_lenet


def get_expected_times(tt_lenet):
if is_grayskull():
return {
tt_lenet: (7.525, 0.9495),
tt_lenet: (5.94, 0.63291),
}[tt_lenet]
elif is_wormhole_b0():
return {
tt_lenet: (9.52, 0.91),
tt_lenet: (8.14, 0.8243),
}[tt_lenet]


Expand Down Expand Up @@ -59,13 +56,12 @@ def test_perf_lenet(device, batch_size, tt_lenet, model_location_generator, rese
custom_preprocessor=lenet_utils.custom_preprocessor,
)
parameters = lenet_utils.custom_preprocessor_device(parameters, device)

x = test_input.permute(0, 2, 3, 1)
x = ttnn.from_torch(x, dtype=ttnn.bfloat16)
durations = []
for _ in range(2):
start = time.time()

for _ in range(100):
start = time.time()
ttnn_output = tt_lenet.Lenet(
device=device,
model=model,
Expand All @@ -77,26 +73,31 @@ def test_perf_lenet(device, batch_size, tt_lenet, model_location_generator, rese
)
end = time.time()
durations.append(end - start)
enable_persistent_kernel_cache()

inference_and_compile_time, *inference_times = durations
average_inference_time = sum(inference_times) / len(inference_times)
inference_time = sum(inference_times) / len(inference_times)
expected_compile_time, expected_inference_time = get_expected_times(tt_lenet)

prep_perf_report(
model_name="tt_lenet",
batch_size=batch_size,
inference_and_compile_time=inference_and_compile_time,
inference_time=average_inference_time,
inference_time=inference_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments="",
inference_time_cpu=0.0,
)

logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}")
logger.info(f"Inference time: {average_inference_time}")
logger.info(f"Compile time: {inference_and_compile_time - inference_time}")
logger.info(f"Inference time: {inference_time}")
logger.info(f"Inference times: {inference_times}")
logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}")
logger.info(f"Sample(s) per second: {1 / inference_time * batch_size}")
assert (
inference_time < expected_inference_time
), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}"
logger.info("Exit Lenet perf test")


@pytest.mark.parametrize(
Expand All @@ -109,9 +110,9 @@ def test_perf_device_bare_metal(batch_size, reset_seeds):
num_iterations = 1
margin = 0.03
if is_grayskull():
expected_perf = 6330.022
expected_perf = 193314.92814121
elif is_wormhole_b0():
expected_perf = 20028.54
expected_perf = 113208.6151

command = f"pytest tests/ttnn/integration_tests/lenet/test_lenet.py"
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"]
Expand All @@ -120,7 +121,7 @@ def test_perf_device_bare_metal(batch_size, reset_seeds):
expected_perf_cols = {inference_time_key: expected_perf}

post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols)
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols, assert_on_fail=True)
prep_device_perf_report(
model_name=f"tt_lenet{batch_size}",
batch_size=batch_size,
Expand Down
36 changes: 17 additions & 19 deletions models/demos/lenet/tt/tt_lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import ttnn
import torch.nn as nn
import torch


def conv(device, input_tensor, batch_size, parameters):
Expand All @@ -26,7 +25,6 @@ def conv(device, input_tensor, batch_size, parameters):
deallocate_activation=True,
reallocate_halo_output=True,
)
# x = ttnn.to_layout(input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
[x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
input_tensor=input_tensor,
weight_tensor=weight,
Expand All @@ -50,10 +48,7 @@ def conv(device, input_tensor, batch_size, parameters):

def Lenet(input_tensor, model, batch_size, num_classes, device, parameters, reset_seeds):
conv_1, out_height, out_width = conv(device, input_tensor, batch_size, parameters.layer1)

conv_1 = ttnn.from_device(conv_1)
conv_1 = ttnn.to_layout(conv_1, layout=ttnn.TILE_LAYOUT)
conv_1 = ttnn.to_device(conv_1, device=device)
conv_1 = ttnn.sharded_to_interleaved(conv_1, ttnn.L1_MEMORY_CONFIG)
conv_1 = ttnn.reshape(conv_1, (batch_size, out_height, out_width, conv_1.shape[-1]))
conv_1 = ttnn.permute(conv_1, (0, 3, 1, 2))
conv_1 = ttnn.to_torch(conv_1)
Expand All @@ -67,21 +62,24 @@ def Lenet(input_tensor, model, batch_size, num_classes, device, parameters, rese

conv_2, out_height, out_width = conv(device, maxpool_1, batch_size, parameters.layer2)

conv_2 = ttnn.from_device(conv_2)
conv_2 = ttnn.to_layout(conv_2, layout=ttnn.TILE_LAYOUT)
conv_2 = ttnn.to_device(conv_2, device=device)
conv_2 = ttnn.reshape(conv_2, (batch_size, out_height, out_width, conv_2.shape[-1]))
conv_2 = ttnn.permute(conv_2, (0, 3, 1, 2))
conv_2 = ttnn.to_torch(conv_2)

max = nn.MaxPool2d(kernel_size=2, stride=2)
maxpool_2 = max(conv_2)

maxpool_2 = ttnn.from_torch(maxpool_2, dtype=ttnn.bfloat16)
conv_2 = ttnn.to_layout(conv_2, layout=ttnn.ROW_MAJOR_LAYOUT)
maxpool_2 = ttnn.max_pool2d(
input_tensor=conv_2,
batch_size=batch_size,
input_h=out_height,
input_w=out_width,
channels=conv_2.shape[3],
kernel_size=[2, 2],
stride=[2, 2],
padding=[0, 0],
dilation=[1, 1],
)

maxpool_2 = ttnn.reshape(maxpool_2, (maxpool_2.shape[0], -1))
maxpool_2 = ttnn.to_device(maxpool_2, device=device)
maxpool_2 = ttnn.sharded_to_interleaved(maxpool_2, ttnn.L1_MEMORY_CONFIG)
maxpool_2 = ttnn.to_layout(maxpool_2, layout=ttnn.TILE_LAYOUT)
maxpool_2 = ttnn.reshape(maxpool_2, (batch_size, 5, 5, maxpool_2.shape[3]))
maxpool_2 = ttnn.permute(maxpool_2, (0, 3, 1, 2))
maxpool_2 = ttnn.reshape(maxpool_2, (maxpool_2.shape[0], -1))

linear_1 = ttnn.linear(
maxpool_2,
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/run_performance.sh
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ run_device_perf_models() {

env pytest models/demos/mnist/tests -m $test_marker

env pytest -n auto models/demos/lenet/tests -m $test_marker
env pytest models/demos/lenet/tests -m $test_marker

if [ "$tt_arch" == "grayskull" ]; then
#TODO(MO): Until #6560 is fixed, GS device profiler test are grouped with
Expand Down
5 changes: 2 additions & 3 deletions tests/scripts/single_card/run_single_card_demo_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,11 @@ run_common_func_tests() {
# ConvNet Mnist
pytest --disable-warnings models/demos/convnet_mnist/demo/demo.py --timeout 600; fail+=$?

<<<<<<< HEAD
# Mnist
pytest --disable-warnings models/demos/mnist/demo/demo.py --timeout 600; fail+=$?
=======

# Lenet
pytest --disable-warnings models/demos/lenet/demo/demo.py --timeout 600; fail+=$?
>>>>>>> #13405: TTNN implementation of LENET model

return $fail
}
Expand Down
6 changes: 1 addition & 5 deletions tests/ttnn/integration_tests/lenet/test_lenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

import pytest
import ttnn
import torch
import torch.nn as nn

from tests.ttnn.utils_for_testing import assert_with_pcc
from ttnn.model_preprocessing import preprocess_model_parameters
from models.demos.lenet.tt import tt_lenet
Expand All @@ -25,15 +24,12 @@ def test_lenet_inference(device, batch_size, model_location_generator, reset_see
torch_LeNet, state_dict = lenet_utils.load_torch_lenet(pt_model_path, num_classes)
model = torch_LeNet.float()
model = torch_LeNet.eval()

torch_output = model(test_input)

parameters = preprocess_model_parameters(
initialize_model=lambda: torch_LeNet,
custom_preprocessor=lenet_utils.custom_preprocessor,
)
parameters = lenet_utils.custom_preprocessor_device(parameters, device)
x = test_input
x = test_input.permute(0, 2, 3, 1)
x = ttnn.from_torch(
x, dtype=ttnn.bfloat16, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.L1_MEMORY_CONFIG
Expand Down

0 comments on commit 7ac95f3

Please sign in to comment.