From 19cba6142ed70e6f37eaaade7892e2953598787c Mon Sep 17 00:00:00 2001 From: Sean Nijjar Date: Thu, 12 Sep 2024 00:04:03 -0400 Subject: [PATCH] #12559: add ttnn implementation for convnet_mnist model --- models/demos/convnet_mnist/README.md | 24 +++ .../convnet_mnist_preprocessing.py | 17 ++ .../convnet_mnist/convnet_mnist_utils.py | 37 +++++ models/demos/convnet_mnist/demo/demo.py | 69 ++++++++ .../convnet_mnist/tests/test_performance.py | 138 ++++++++++++++++ .../demos/convnet_mnist/tt/convnet_mnist.py | 154 ++++++++++++++++++ tests/scripts/run_performance.sh | 4 + .../single_card/run_single_card_demo_tests.sh | 3 + .../convnet_mnist/test_convnet_mnist.py | 58 +++++++ 9 files changed, 504 insertions(+) create mode 100644 models/demos/convnet_mnist/README.md create mode 100644 models/demos/convnet_mnist/convnet_mnist_preprocessing.py create mode 100644 models/demos/convnet_mnist/convnet_mnist_utils.py create mode 100644 models/demos/convnet_mnist/demo/demo.py create mode 100644 models/demos/convnet_mnist/tests/test_performance.py create mode 100644 models/demos/convnet_mnist/tt/convnet_mnist.py create mode 100644 tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist.py diff --git a/models/demos/convnet_mnist/README.md b/models/demos/convnet_mnist/README.md new file mode 100644 index 000000000000..71ffd750cba4 --- /dev/null +++ b/models/demos/convnet_mnist/README.md @@ -0,0 +1,24 @@ +# Introduction + +Convnet Mnist implements a Convolutions to classify handwritten digits from the MNIST dataset. The MNIST dataset contains grayscale images of handwritten digits (0-9), each of size 32x32 pixels. + +# Platforms: + GS E150, WH N150, WH N300 + +## How to Run + +To run the demo for digit classification using the MNIST model, follow these instructions: + +- Use the following command to run the MNIST model. + +``` +pytest models/demos/convnet_mnist/demo/demo.py +``` + +Maxpool and Softmax are used in torch inside the model. +ISSUES: + #12664 - [softmax](https://github.com/tenstorrent/tt-metal/issues/12664) + #12642 - [maxpool](https://github.com/tenstorrent/tt-metal/issues/12642) + + +### Owner: [vigneshkumarkeerthivasan](https://github.com/vigneshkeerthivasanx) diff --git a/models/demos/convnet_mnist/convnet_mnist_preprocessing.py b/models/demos/convnet_mnist/convnet_mnist_preprocessing.py new file mode 100644 index 000000000000..99681afebe95 --- /dev/null +++ b/models/demos/convnet_mnist/convnet_mnist_preprocessing.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import ttnn + + +def custom_preprocessor(parameters, device): + parameters.conv1.bias = ttnn.to_device(parameters.conv1.bias, device) + parameters.conv1.bias = ttnn.to_device(parameters.conv1.bias, device) + + parameters.fc1.weight = ttnn.to_device(parameters.fc1.weight, device) + parameters.fc1.bias = ttnn.to_device(parameters.fc1.bias, device) + parameters.fc2.weight = ttnn.to_device(parameters.fc2.weight, device) + parameters.fc2.bias = ttnn.to_device(parameters.fc2.bias, device) + + return parameters diff --git a/models/demos/convnet_mnist/convnet_mnist_utils.py b/models/demos/convnet_mnist/convnet_mnist_utils.py new file mode 100644 index 000000000000..74755f817b3c --- /dev/null +++ b/models/demos/convnet_mnist/convnet_mnist_utils.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torchvision +import torchvision.transforms as transforms + + +def get_test_data(batch_size=64): + transform = transforms.Compose( + [ + transforms.Resize((32, 32)), + transforms.ToTensor(), + transforms.Normalize(mean=(0.05,), std=(0.05,)), + ] + ) + + test_dataset = torchvision.datasets.MNIST( + root="./data", + train=False, + download=True, + ) + + batch = [] + images = [] + outputs = [] + + for i in range(batch_size): + img, output = test_dataset[i] + tensor = transform(img).unsqueeze(0) + batch.append(tensor) + images.append(img) + outputs.append(output) + + batch = torch.cat(batch) + return batch, images, outputs diff --git a/models/demos/convnet_mnist/demo/demo.py b/models/demos/convnet_mnist/demo/demo.py new file mode 100644 index 000000000000..828f259f8b0b --- /dev/null +++ b/models/demos/convnet_mnist/demo/demo.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +import pytest + +from pathlib import Path +from loguru import logger + +from models.demos.convnet_mnist.tt.convnet_mnist import convnet_mnist, custom_preprocessor +from models.demos.convnet_mnist import convnet_mnist_preprocessing +from models.demos.convnet_mnist.convnet_mnist_utils import get_test_data +from models.experimental.convnet_mnist.reference.convnet import ConvNet +from ttnn.model_preprocessing import preprocess_model_parameters + + +def model_location_generator(rel_path): + internal_weka_path = Path("/mnt/MLPerf") + has_internal_weka = (internal_weka_path / "bit_error_tests").exists() + + if has_internal_weka: + return Path("/mnt/MLPerf") / rel_path + else: + return Path("/opt/tt-metal-models") / rel_path + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +def test_convnet_mnist(device, reset_seeds): + model_path = model_location_generator("tt_dnn-models/ConvNetMNIST/") + state_dict = str(model_path / "convnet_mnist.pt") + state_dict = torch.load(state_dict) + + test_input, images, output = get_test_data(8) + + model = ConvNet() + model.load_state_dict(state_dict) + model.eval() + torch_output = model(test_input) + batch_size = len(test_input) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, convert_to_ttnn=lambda *_: True, custom_preprocessor=custom_preprocessor + ) + parameters = convnet_mnist_preprocessing.custom_preprocessor(parameters, device=device) + + ttnn_input = torch.permute(test_input, (0, 2, 3, 1)) + ttnn_input = ttnn.from_torch(ttnn_input, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT) + + ttnn_output = convnet_mnist( + input_tensor=ttnn_input, + device=device, + parameters=parameters, + ) + ttnn_output = ttnn.to_torch(ttnn_output) + + _, torch_predicted = torch.max(torch_output.data, -1) + _, ttnn_predicted = torch.max(ttnn_output.data, -1) + + correct = 0 + for i in range(batch_size): + if output[i] == ttnn_predicted[i]: + correct += 1 + accuracy = correct / (batch_size) + + logger.info(f" Accuracy for {batch_size} Samples : {accuracy}") + logger.info(f"torch_predicted {torch_predicted.squeeze()}") + logger.info(f"ttnn_predicted {ttnn_predicted.squeeze()}") diff --git a/models/demos/convnet_mnist/tests/test_performance.py b/models/demos/convnet_mnist/tests/test_performance.py new file mode 100644 index 000000000000..e0aa4b052ab2 --- /dev/null +++ b/models/demos/convnet_mnist/tests/test_performance.py @@ -0,0 +1,138 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +import time +from pathlib import Path + +from torchvision import models +from loguru import logger +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from models.utility_functions import ( + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, +) +from models.perf.perf_utils import prep_perf_report +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report +from models.demos.convnet_mnist.tt.convnet_mnist import convnet_mnist, custom_preprocessor +from models.demos.convnet_mnist import convnet_mnist_preprocessing +from models.experimental.convnet_mnist.reference.convnet import ConvNet +from models.utility_functions import is_grayskull + + +def get_expected_times(convnet_mnist): + return (15.0, 9.2) + + +def model_location_generator(rel_path): + internal_weka_path = Path("/mnt/MLPerf") + has_internal_weka = (internal_weka_path / "bit_error_tests").exists() + + if has_internal_weka: + return Path("/mnt/MLPerf") / rel_path + else: + return Path("/opt/tt-metal-models") / rel_path + + +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "batch_size, act_dtype, weight_dtype, math_fidelity", ((1, ttnn.bfloat16, ttnn.bfloat16, ttnn.MathFidelity.LoFi),) +) +@pytest.mark.parametrize( + "input_shape", + [ + (1, 1, 32, 32), + ], +) +def test_convnet_mnist( + device, + input_shape, + batch_size, + act_dtype, + weight_dtype, + math_fidelity, +): + disable_persistent_kernel_cache() + model_path = model_location_generator("tt_dnn-models/ConvNetMNIST/") + state_dict = str(model_path / "convnet_mnist.pt") + state_dict = torch.load(state_dict) + + input_tensor = torch.randn(input_shape, dtype=torch.bfloat16) + input_tensor = torch.permute(input_tensor, (0, 2, 3, 1)) + + model = ConvNet() + model.load_state_dict(state_dict) + model.eval() + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, convert_to_ttnn=lambda *_: True, custom_preprocessor=custom_preprocessor + ) + parameters = convnet_mnist_preprocessing.custom_preprocessor(parameters, device=device) + + durations = [] + for i in range(2): + start = time.time() + ttnn_input = ttnn.from_torch(input_tensor, dtype=ttnn.bfloat16) + + ttnn_output = convnet_mnist( + input_tensor=ttnn_input, + device=device, + parameters=parameters, + ) + output = ttnn.from_device(ttnn_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times("convnet_mnist") + prep_perf_report( + model_name="convnet_mnist", + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + +@pytest.mark.parametrize( + "batch_size, expected_perf", + [ + [1, 105.710], + ], +) +@pytest.mark.models_device_performance_bare_metal +def test_perf_device_bare_metal_convnet_mnist(batch_size, expected_perf): + subdir = "ttnn_convnet_mnist" + num_iterations = 1 + margin = 0.03 + expected_perf = 1753.5 if is_grayskull() else 2705.5 + + command = f"pytest tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist.py" + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + prep_device_perf_report( + model_name=f"ttnn_functional_convnet_mnist{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments="", + ) diff --git a/models/demos/convnet_mnist/tt/convnet_mnist.py b/models/demos/convnet_mnist/tt/convnet_mnist.py new file mode 100644 index 000000000000..a98d0db0ed02 --- /dev/null +++ b/models/demos/convnet_mnist/tt/convnet_mnist.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +import torch.nn.functional as F +from torch import nn + + +def convnet_mnist( + input_tensor, + parameters, + device, +): + batch_size = input_tensor.shape[0] + torch_maxpool = True + + conv_config = ttnn.Conv2dConfig( + dtype=ttnn.bfloat16, + weights_dtype=ttnn.bfloat16, + math_fidelity=ttnn.MathFidelity.LoFi, + activation="", + shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED, + math_approx_mode_enabled=True, + fp32_dest_acc_enabled=False, + packer_l1_accum_enabled=False, + input_channels_alignment=32, + transpose_shards=False, + reshard_if_not_optimal=True, + deallocate_activation=True, + reallocate_halo_output=True, + ) + + x = ttnn.to_layout(input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT) + [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + input_tensor=x, + weight_tensor=parameters.conv1.weight, + in_channels=1, + out_channels=32, + device=device, + bias_tensor=parameters.conv1.bias, + kernel_size=(3, 3), + stride=(1, 1), + padding=(0, 0), + batch_size=batch_size, + input_height=input_tensor.shape[1], + input_width=input_tensor.shape[2], + conv_config=conv_config, + conv_op_cache={}, + debug=True, + groups=1, + ) + x = ttnn.relu(x) + + if torch_maxpool: # Can be removed once issue #12642 is resolved + x = ttnn.to_torch(x) + x = torch.reshape(x, (batch_size, 30, 30, 32)) + x = torch.permute(x, (0, 3, 1, 2)) + x = F.max_pool2d(x, 2) + x = torch.permute(x, (0, 2, 3, 1)) + x = ttnn.from_torch(x, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT) + + else: + x = ttnn.sharded_to_interleaved(x, ttnn.L1_MEMORY_CONFIG) + x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) + x = ttnn.max_pool2d( + input_tensor=x, + batch_size=batch_size, + input_h=30, + input_w=30, + channels=32, + kernel_size=[2, 2], + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + device=device, + ) + + [x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + input_tensor=x, + weight_tensor=parameters.conv2.weight, + in_channels=32, + out_channels=64, + device=device, + bias_tensor=parameters.conv2.bias, + kernel_size=(3, 3), + stride=(1, 1), + padding=(0, 0), + batch_size=batch_size, + input_height=15, + input_width=15, + conv_config=conv_config, + conv_op_cache={}, + debug=False, + groups=1, + ) + + x = ttnn.relu(x) + + if torch_maxpool: # Can be removed once issue #12642 is resolved + x = ttnn.to_torch(x) + x = torch.reshape(x, (batch_size, 13, 13, 64)) + x = torch.permute(x, (0, 3, 1, 2)) + x = F.max_pool2d(x, 2) + x = ttnn.from_torch(x, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT) + + else: + x = ttnn.sharded_to_interleaved(x, ttnn.DRAM_MEMORY_CONFIG) + x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT) + x = ttnn.max_pool2d( + input_tensor=x, + batch_size=batch_size, + input_h=out_height, + input_w=out_width, + channels=x.shape[-1], + kernel_size=[2, 2], + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + device=device, + ) + x = ttnn.from_device(x) + x = ttnn.reshape(x, (x.shape[0], -1)) + + x = ttnn.to_device(x, device) + x = ttnn.to_layout(x, ttnn.TILE_LAYOUT) + x = ttnn.linear(x, parameters.fc1.weight, bias=parameters.fc1.bias, activation="relu") + + x = ttnn.linear(x, parameters.fc2.weight, bias=parameters.fc2.bias) + + output = torch.softmax(ttnn.to_torch(x), dim=-1) + output = ttnn.from_torch(output, device=device, dtype=ttnn.bfloat16) + return output + + +def preprocess_conv_parameter(parameter, *, dtype): + parameter = ttnn.from_torch(parameter, dtype=dtype) + return parameter + + +def custom_preprocessor(model, device): + parameters = {} + if isinstance(model, nn.Conv2d): + weight = model.weight + bias = model.bias + while weight.dim() < 4: + weight = weight.unsqueeze(0) + while bias.dim() < 4: + bias = bias.unsqueeze(0) + parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.bfloat16) + parameters["bias"] = preprocess_conv_parameter(bias, dtype=ttnn.bfloat16) + + return parameters diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index dbe5becdaca3..95f8c7ba4bd7 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -29,6 +29,8 @@ run_perf_models_other() { env pytest -n auto models/demos/metal_BERT_large_11/tests -m $test_marker + env pytest -n auto models/demos/convnet_mnist/tests -m $test_marker + ## Merge all the generated reports env python models/perf/merge_perf_results.py } @@ -74,6 +76,8 @@ run_device_perf_models() { env pytest models/demos/distilbert/tests -m $test_marker + env pytest models/demos/convnet_mnist/tests/ -m $test_marker + if [ "$tt_arch" == "grayskull" ]; then #TODO(MO): Until #6560 is fixed, GS device profiler test are grouped with #Model Device perf regression tests to make sure thy run on no-soft-reset BMs diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index c10b13dc5402..b629e7295689 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -22,6 +22,9 @@ run_common_func_tests() { # Distilbert pytest --disable-warnings models/demos/distilbert/demo/demo.py --timeout 600; fail+=$? + # ConvNet Mnist + pytest --disable-warnings models/demos/convnet_mnist/demo/demo.py --timeout 600; fail+=$? + return $fail } diff --git a/tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist.py b/tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist.py new file mode 100644 index 000000000000..f09722d5a655 --- /dev/null +++ b/tests/ttnn/integration_tests/convnet_mnist/test_convnet_mnist.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import ttnn +import pytest + +from pathlib import Path + +from models.demos.convnet_mnist.tt.convnet_mnist import convnet_mnist, custom_preprocessor +from models.demos.convnet_mnist import convnet_mnist_preprocessing +from models.demos.convnet_mnist.convnet_mnist_utils import get_test_data +from models.experimental.convnet_mnist.reference.convnet import ConvNet +from ttnn.model_preprocessing import preprocess_model_parameters +from tests.ttnn.utils_for_testing import assert_with_pcc + + +def model_location_generator(rel_path): + internal_weka_path = Path("/mnt/MLPerf") + has_internal_weka = (internal_weka_path / "bit_error_tests").exists() + + if has_internal_weka: + return Path("/mnt/MLPerf") / rel_path + else: + return Path("/opt/tt-metal-models") / rel_path + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True) +def test_convnet_mnist(device, reset_seeds): + model_path = model_location_generator("tt_dnn-models/ConvNetMNIST/") + state_dict = str(model_path / "convnet_mnist.pt") + state_dict = torch.load(state_dict) + + test_input, images, outputs = get_test_data(8) + + model = ConvNet() + model.load_state_dict(state_dict) + model.eval() + + torch_output = model(test_input) + + parameters = preprocess_model_parameters( + initialize_model=lambda: model, convert_to_ttnn=lambda *_: True, custom_preprocessor=custom_preprocessor + ) + parameters = convnet_mnist_preprocessing.custom_preprocessor(parameters, device=device) + + ttnn_input = torch.permute(test_input, (0, 2, 3, 1)) + ttnn_input = ttnn.from_torch(ttnn_input, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT) + + ttnn_output = convnet_mnist( + input_tensor=ttnn_input, + device=device, + parameters=parameters, + ) + ttnn_output = ttnn.to_torch(ttnn_output) + + assert_with_pcc(torch_output, ttnn_output, 0.99)