diff --git a/models/demos/functional_vgg/README.md b/models/demos/functional_vgg/README.md new file mode 100644 index 000000000000..3ad535ae0e64 --- /dev/null +++ b/models/demos/functional_vgg/README.md @@ -0,0 +1,26 @@ +# Introduction + +The VGG model is a popular convolutional neural network architecture introduced by the Visual Geometry Group at Oxford in their paper "Very Deep Convolutional Networks for Large-Scale Image Recognition" (2014). It is widely used for image classification and feature extraction tasks. + +# Platforms: + GS E150, WH N150, WH N300 + +# Model Architectures +- VGG11 +- VGG16 +VGG11 and VGG16 currently supports BATCH_SIZE = 1. + +# How to Run +To run the demo for image classification of the VGG model using ImageNet-1k Validation Dataset, follow these instructions + +- Use the following command to run the model using ttnn_vgg +-VGG11 +``` +pytest models/demos/functional_vgg/demo/demo.py::test_demo_imagenet_vgg11 +``` +- VGG16 +``` +pytest models/demos/functional_vgg/demo/demo.py::test_demo_imagenet_vgg16 +``` + +NOTE: one ttnn.reshape in VGG11 and VGG16 is on host. diff --git a/models/demos/functional_vgg/demo/demo.py b/models/demos/functional_vgg/demo/demo.py new file mode 100644 index 000000000000..8deab60b43e8 --- /dev/null +++ b/models/demos/functional_vgg/demo/demo.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + + +import torch +from loguru import logger +from torchvision import models +from transformers import AutoImageProcessor +import pytest +import tt_lib +import torch.nn as nn + +from models.utility_functions import ( + disable_compilation_reports, + disable_persistent_kernel_cache, + enable_persistent_kernel_cache, + profiler, +) +import ttnn + +from models.demos.functional_vgg.demo_utils import get_data, get_data_loader, get_batch, preprocess +from loguru import logger +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.functional_vgg.tt import ttnn_vgg + +vgg_model_config = { + "MATH_FIDELITY": ttnn.MathFidelity.LoFi, + "WEIGHTS_DTYPE": ttnn.bfloat16, + "ACTIVATIONS_DTYPE": ttnn.bfloat16, +} + + +def run_vgg_imagenet_inference_vgg( + batch_size, + iterations, + imagenet_label_dict, + model_location_generator, + model_class, + weights, + device, + model_config=vgg_model_config, +): + disable_persistent_kernel_cache() + disable_compilation_reports() + profiler.clear() + + # Setup model + torch_model = model_class(weights=weights) + torch_model.to(torch.bfloat16) + torch_model.eval() + + parameters = preprocess_model_parameters( + initialize_model=lambda: torch_model, + device=device, + convert_to_ttnn=lambda *_: True, + custom_preprocessor=ttnn_vgg.custom_preprocessor, + ) + + if model_class == models.vgg11: + ttnn_model = ttnn_vgg.ttnn_vgg11 + model_name = "VGG11" + else: + ttnn_model = ttnn_vgg.ttnn_vgg16 + model_name = "VGG16" + + # load inputs + logger.info("ImageNet-1k validation Dataset") + input_loc = str(model_location_generator("ImageNet_data")) + data_loader = get_data_loader(input_loc, batch_size, iterations) + + # load ImageNet batch by batch + # and run inference + correct = 0 + for iter in range(iterations): + predictions = [] + torch_predictions = [] + inputs, labels = get_batch(data_loader) + torch_outputs = torch_model(inputs) + permuted_inputs = torch.permute(inputs, (0, 2, 3, 1)) + tt_batched_input_tensor = ttnn.from_torch(permuted_inputs, ttnn.bfloat16) + tt_output = ttnn_model(device, tt_batched_input_tensor, parameters, batch_size, 
model_config) + tt_output = ttnn.to_torch(tt_output) + prediction = tt_output[:, 0, 0, :].argmax(dim=-1) + torch_prediction = torch_outputs[:, :].argmax(dim=-1) + for i in range(batch_size): + predictions.append(imagenet_label_dict[prediction[i].item()]) + torch_predictions.append(imagenet_label_dict[torch_prediction[i].item()]) + logger.info( + f"Iter: {iter} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- \n Torch Predicted label:{predictions[-1]} \tPredicted Label: {predictions[-1]}" + ) + if imagenet_label_dict[labels[i]] == predictions[-1]: + correct += 1 + + del tt_output, tt_batched_input_tensor, inputs, labels, predictions + accuracy = correct / (batch_size * iterations) + logger.info(f"Model {model_name}") + logger.info(f"Accuracy for {batch_size}x{iterations} inputs: {accuracy}") + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "batch_size, iterations", + ((1, 1),), +) +@pytest.mark.parametrize( + "model_class, weights", + [ + (models.vgg11, models.VGG11_Weights.IMAGENET1K_V1), + (models.vgg16, models.VGG16_Weights.IMAGENET1K_V1), + ], +) +def test_demo_imagenet_vgg( + batch_size, iterations, imagenet_label_dict, model_location_generator, model_class, weights, device +): + run_vgg_imagenet_inference_vgg( + batch_size, iterations, imagenet_label_dict, model_location_generator, model_class, weights, device + ) diff --git a/models/demos/functional_vgg/demo_utils.py b/models/demos/functional_vgg/demo_utils.py new file mode 100644 index 000000000000..32296462f017 --- /dev/null +++ b/models/demos/functional_vgg/demo_utils.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +from PIL import Image +import torch +import os +import glob +from models.sample_data.huggingface_imagenet_classes import IMAGENET2012_CLASSES +from datasets import load_dataset +from torchvision import models +from PIL import Image +import torchvision.transforms as transforms +import torch + + +class InputExample(object): + def __init__(self, image, label=None): + self.image = image + self.label = label + + +def get_input(image_path): + img = Image.open(image_path) + return img + + +def get_label(image_path): + _, image_name = image_path.rsplit("/", 1) + image_name_exact, _ = image_name.rsplit(".", 1) + _, label_id = image_name_exact.rsplit("_", 1) + label = list(IMAGENET2012_CLASSES).index(label_id) + return label + + +preprocess = transforms.Compose( + [ + transforms.Resize(256), # Resize the shorter side to 256 pixels + transforms.CenterCrop(224), # Crop the center to 224x224 pixels + transforms.ToTensor(), # Convert the image to a tensor + transforms.Normalize( # Normalize using ImageNet's mean and std + mean=[0.485, 0.456, 0.406], # These are the mean values for each channel + std=[0.229, 0.224, 0.225], # These are the std values for each channel + ), + ] +) + + +def get_batch(data_loader): + loaded_images = next(data_loader) + images = None + labels = [] + transform = transforms.ToTensor() + resize_transform = transforms.Resize((224, 224)) + for image in loaded_images: + img = image.image + labels.append(image.label) + if img.mode == "L": + img = img.convert(mode="RGB") + + img = preprocess(img) + img = img.to(torch.bfloat16) + img = img.unsqueeze(0) + if images is None: + images = img + else: + images = torch.cat((images, img), dim=0) + return images, labels + + +def get_data_loader(input_loc, batch_size, iterations): + img_dir = input_loc + "/" + data_path = 
os.path.join(img_dir, "*G") + files = glob.glob(data_path) + + def loader(): + examples = [] + for f1 in files: + examples.append( + InputExample( + image=get_input(f1), + label=get_label(f1), + ) + ) + if len(examples) == batch_size: + yield examples + del examples + examples = [] + + def loader_hf(): + examples = [] + for f1 in files: + examples.append( + InputExample( + image=f1["image"], + label=f1["label"], + ) + ) + if len(examples) == batch_size: + yield examples + del examples + examples = [] + + if len(files) == 0: + files_raw = iter(load_dataset("imagenet-1k", split="validation", use_auth_token=True, streaming=True)) + files = [] + sample_count = batch_size * iterations + for _ in range(sample_count): + files.append(next(files_raw)) + del files_raw + return loader_hf() + + return loader() + + +def get_data(input_loc): + img_dir = input_loc + "/" + data_path = os.path.join(img_dir, "*G") + files = sorted(glob.glob(data_path)) + examples = [] + for f1 in files: + examples.append( + InputExample( + image=get_input(f1), + label=get_label(f1), + ) + ) + image_examples = examples + + return image_examples diff --git a/models/demos/functional_vgg/tests/test_perf_vgg.py b/models/demos/functional_vgg/tests/test_perf_vgg.py new file mode 100644 index 000000000000..ed4c660e9e7d --- /dev/null +++ b/models/demos/functional_vgg/tests/test_perf_vgg.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn +import time + +from torchvision import models +from loguru import logger +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.functional_vgg.tt import ttnn_vgg +from models.utility_functions import ( + enable_persistent_kernel_cache, + disable_persistent_kernel_cache, +) +from models.perf.perf_utils import prep_perf_report +from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report +from models.utility_functions import is_grayskull + + +def get_expected_times(functional_vgg): + return (15.0, 10.5) + + +@pytest.mark.models_performance_bare_metal +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "batch_size, act_dtype, weight_dtype, math_fidelity", ((1, ttnn.bfloat16, ttnn.bfloat16, ttnn.MathFidelity.LoFi),) +) +@pytest.mark.parametrize( + "input_shape", + [ + (1, 3, 224, 224), + ], +) +@pytest.mark.parametrize( + "model_class, weights", + [ + (models.vgg11, models.VGG11_Weights.IMAGENET1K_V1), + (models.vgg16, models.VGG16_Weights.IMAGENET1K_V1), + ], +) +def test_vgg( + device, + input_shape, + batch_size, + act_dtype, + weight_dtype, + math_fidelity, + model_class, + weights, +): + disable_persistent_kernel_cache() + torch_model = model_class(weights=weights) + torch_model.to(torch.bfloat16) + torch_model.eval() + torch_input_tensor_nchw = torch.rand(input_shape, dtype=torch.bfloat16) + + parameters = preprocess_model_parameters( + initialize_model=lambda: torch_model, + device=device, + convert_to_ttnn=lambda *_: True, + custom_preprocessor=ttnn_vgg.custom_preprocessor, + ) + + if model_class == models.vgg11: + ttnn_model = ttnn_vgg.ttnn_vgg11 + model_name = "VGG11" + else: + ttnn_model = ttnn_vgg.ttnn_vgg16 + model_name = "VGG16" + + model_config = { + "MATH_FIDELITY": math_fidelity, + "WEIGHTS_DTYPE": weight_dtype, + "ACTIVATIONS_DTYPE": act_dtype, + } + + conv_config = ttnn.Conv2dConfig( + dtype=model_config["ACTIVATIONS_DTYPE"], + 
weights_dtype=model_config["WEIGHTS_DTYPE"], + math_fidelity=model_config["MATH_FIDELITY"], + activation="relu", + deallocate_activation=True, + input_channels_alignment=16, + act_block_h_override=0, + transpose_shards=True, + ) + + torch_batched_tensor = torch_input_tensor_nchw.repeat(batch_size, 1, 1, 1) + torch_input_tensor = torch.permute(torch_batched_tensor, (0, 2, 3, 1)) + tt_batched_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) + + durations = [] + for i in range(2): + tt_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) + start = time.time() + ttnn_output = ttnn_model(device, tt_batched_input_tensor, parameters, batch_size, model_config) + output = ttnn.from_device(ttnn_output) + end = time.time() + durations.append(end - start) + enable_persistent_kernel_cache() + + inference_and_compile_time, inference_time, *_ = durations + + expected_compile_time, expected_inference_time = get_expected_times("vgg16") + prep_perf_report( + model_name=model_name, + batch_size=batch_size, + inference_and_compile_time=inference_and_compile_time, + inference_time=inference_time, + expected_compile_time=expected_compile_time, + expected_inference_time=expected_inference_time, + comments="", + inference_time_cpu=0.0, + ) + + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") + logger.info(f"Samples per second: {1 / inference_time * batch_size}") + + +@pytest.mark.parametrize( + "batch_size, model_name", + [ + (1, "ttnn_vgg11"), + (1, "ttnn_vgg16"), + ], +) +@pytest.mark.models_device_performance_bare_metal +def test_perf_device_bare_metal_vgg(batch_size, model_name): + subdir = model_name + num_iterations = 1 + margin = 0.03 + + if model_name == "ttnn_vgg11": + expected_perf = 79.3 if is_grayskull() else 105.7 + command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg11.py" + else: + expected_perf = 73.6 if is_grayskull() else 92.6 + command = f"pytest tests/ttnn/integration_tests/vgg/test_ttnn_vgg16.py" + + cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] + + inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" + expected_perf_cols = {inference_time_key: expected_perf} + + post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) + expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) + prep_device_perf_report( + model_name=f"ttnn_functional_{model_name}_{batch_size}", + batch_size=batch_size, + post_processed_results=post_processed_results, + expected_results=expected_results, + comments="", + ) diff --git a/models/demos/functional_vgg/tt/ttnn_vgg.py b/models/demos/functional_vgg/tt/ttnn_vgg.py new file mode 100644 index 000000000000..2fdddfef6050 --- /dev/null +++ b/models/demos/functional_vgg/tt/ttnn_vgg.py @@ -0,0 +1,312 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
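+ #
+ # Module overview (descriptive comments added for readability; behavior unchanged):
+ # - cfgs holds torchvision's layer configurations "A" (VGG11) and "D" (VGG16).
+ # - conv_ttnn_params / conv_ttnn_params_2 list [in_channels, out_channels, H, W]
+ #   per conv layer, and conv_feature_ids map each conv to its torchvision
+ #   `features` index.
+ # - ttnn_vgg16 / ttnn_vgg11 run the conv + max_pool2d feature stack with
+ #   ttnn.conv2d, then flatten and apply the three ttnn.linear classifier layers.
+ # - custom_preprocessor converts nn.Conv2d weights and biases to bfloat16 ttnn
+ #   tensors in TILE_LAYOUT.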
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch.nn as nn + +from typing import List, Union, Dict + +import ttnn + +import math + +cfgs: Dict[str, List[Union[str, int]]] = { + "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "D": [ + 64, + 64, + "M", + 128, + 128, + "M", + 256, + 256, + 256, + "M", + 512, + 512, + 512, + "M", + 512, + 512, + 512, + "M", + ], +} +conv_ttnn_params = [ + [3, 64, 224, 224], + [64, 64, 224, 224], + [64, 128, 112, 112], + [128, 128, 112, 112], + [128, 256, 56, 56], + [256, 256, 56, 56], + [256, 256, 56, 56], + [256, 512, 28, 28], + [512, 512, 28, 28], + [512, 512, 28, 28], + [512, 512, 14, 14], + [512, 512, 14, 14], + [512, 512, 14, 14], +] +conv_feature_ids = [0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28] +classifier_ids = [0, 3, 6] +h_override = [128, 128, 128, 64, 32, 32, 32, 32, 32, 32, 32, 32, 32] + + +def ttnn_vgg16( + device, + tt_x, + parameters, + batch_size, + model_config, +): + iter_conv_id = 0 + for itr, v in enumerate(cfgs["D"]): + if v == "M": + l = list(tt_x.shape) + in_n, in_c, in_h, in_w = list(tt_x.shape) + + tt_x = ttnn.to_layout(tt_x, ttnn.ROW_MAJOR_LAYOUT) + ttact_d = ttnn.to_device(tt_x, device) + tt_x = ttnn.max_pool2d( + input_tensor=ttact_d, + batch_size=batch_size, + input_h=int(math.sqrt(in_h / batch_size)), + input_w=int(math.sqrt(in_h / batch_size)), + channels=l[3], + kernel_size=[2, 2], + stride=[2, 2], + padding=[0, 0], + dilation=[1, 1], + device=device, + ) + ttnn.deallocate(ttact_d) + tt_x = ttnn.from_device(tt_x) + + else: + h_sharding = True + + if conv_ttnn_params[iter_conv_id][0] > 128: + h_sharding = False + conv_config = ttnn.Conv2dConfig( + dtype=model_config["ACTIVATIONS_DTYPE"], + weights_dtype=model_config["WEIGHTS_DTYPE"], + math_fidelity=model_config["MATH_FIDELITY"], + math_approx_mode_enabled=True, + fp32_dest_acc_enabled=False, + packer_l1_accum_enabled=False, + activation="relu", + deallocate_activation=False, + input_channels_alignment=32, + reallocate_halo_output=False, + act_block_h_override=h_override[iter_conv_id], + transpose_shards=True, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if h_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), + reshard_if_not_optimal=True, + ) + + tt_weight = parameters.features[conv_feature_ids[iter_conv_id]].weight + tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT) + tt_bias = parameters.features[conv_feature_ids[iter_conv_id]].bias + # Call ttnn.conv + conv_op_cache = {} + [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + input_tensor=tt_x, + weight_tensor=tt_weight, + in_channels=conv_ttnn_params[iter_conv_id][0], + out_channels=conv_ttnn_params[iter_conv_id][1], + device=device, + bias_tensor=tt_bias, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + batch_size=batch_size, + input_height=conv_ttnn_params[iter_conv_id][2], + input_width=conv_ttnn_params[iter_conv_id][3], + conv_config=conv_config, + conv_op_cache=conv_op_cache, + ) + tt_x = ttnn.from_device(tt_output_tensor_on_device) + ttnn.deallocate(tt_output_tensor_on_device) + iter_conv_id += 1 + + tt_x = ttnn.to_device(tt_x, device) + tt_x = ttnn.to_layout(tt_x, ttnn.TILE_LAYOUT) + tt_x = ttnn.permute(tt_x, (0, 3, 1, 2)) + tt_x = ttnn.from_device(tt_x) + tt_x = ttnn.to_layout(tt_x, ttnn.ROW_MAJOR_LAYOUT) + tt_x = ttnn.reshape(tt_x, (batch_size, 1, 1, -1)) + tt_x = ttnn.to_layout(tt_x, layout=ttnn.TILE_LAYOUT) + tt_x = ttnn.to_device(tt_x, device) + + # Linear 1 + tt_x = 
ttnn.linear( + tt_x, + parameters["classifier"][classifier_ids[0]]["weight"], + bias=parameters["classifier"][classifier_ids[0]]["bias"], + activation="relu", + ) + + # Linear 2 + tt_x = ttnn.linear( + tt_x, + parameters["classifier"][classifier_ids[1]]["weight"], + bias=parameters["classifier"][classifier_ids[1]]["bias"], + activation="relu", + ) + + # Linear 3 + tt_x = ttnn.linear( + tt_x, + parameters["classifier"][classifier_ids[2]]["weight"], + bias=parameters["classifier"][classifier_ids[2]]["bias"], + ) + return tt_x + + +conv_feature_ids_2 = [0, 3, 6, 8, 11, 13, 16, 18] +conv_ttnn_params_2 = [ + [3, 64, 224, 224], + [64, 128, 112, 112], + [128, 256, 56, 56], + [256, 256, 56, 56], + [256, 512, 28, 28], + [512, 512, 28, 28], + [512, 512, 14, 14], + [512, 512, 14, 14], +] +height_override_11 = [128, 128, 32, 32, 32, 32, 32, 32] + + +def ttnn_vgg11( + device, + tt_x, + parameters, + batch_size, + model_config, +): + iter_conv_id = 0 + for itr, v in enumerate(cfgs["A"]): + if v == "M": + l = list(tt_x.shape) + + in_n, in_c, in_h, in_w = list(tt_x.shape) + + tt_x = ttnn.to_layout(tt_x, ttnn.ROW_MAJOR_LAYOUT) + ttact_d = ttnn.to_device(tt_x, device) + tt_x = ttnn.max_pool2d( + input_tensor=ttact_d, + batch_size=batch_size, + input_h=int(math.sqrt(in_h / batch_size)), + input_w=int(math.sqrt(in_h / batch_size)), + channels=l[3], + kernel_size=[2, 2], + stride=[2, 2], + padding=[0, 0], + dilation=[1, 1], + device=device, + ) + tt_x = ttnn.from_device(tt_x) + ttnn.deallocate(ttact_d) + + else: + h_sharding = True + if conv_ttnn_params_2[iter_conv_id][0] > 128: + h_sharding = False + conv_config = ttnn.Conv2dConfig( + dtype=model_config["ACTIVATIONS_DTYPE"], + weights_dtype=model_config["WEIGHTS_DTYPE"], + math_fidelity=model_config["MATH_FIDELITY"], + math_approx_mode_enabled=True, + fp32_dest_acc_enabled=True, + activation="relu", + deallocate_activation=False, + input_channels_alignment=32, + reallocate_halo_output=False, + act_block_h_override=height_override_11[iter_conv_id], + transpose_shards=True, + shard_layout=( + ttnn.TensorMemoryLayout.HEIGHT_SHARDED if h_sharding else ttnn.TensorMemoryLayout.BLOCK_SHARDED + ), + ) + + tt_weight = parameters.features[conv_feature_ids_2[iter_conv_id]].weight + tt_weight = ttnn.to_layout(ttnn.from_device(tt_weight), layout=ttnn.ROW_MAJOR_LAYOUT) + tt_bias = parameters.features[conv_feature_ids_2[iter_conv_id]].bias + + # Call ttnn.conv + conv_op_cache = {} + [tt_output_tensor_on_device, out_height, out_width, weights_device, bias_device] = ttnn.conv2d( + input_tensor=tt_x, + weight_tensor=tt_weight, + in_channels=conv_ttnn_params_2[iter_conv_id][0], + out_channels=conv_ttnn_params_2[iter_conv_id][1], + device=device, + bias_tensor=tt_bias, + kernel_size=(3, 3), + stride=(1, 1), + padding=(1, 1), + batch_size=batch_size, + input_height=conv_ttnn_params_2[iter_conv_id][2], + input_width=conv_ttnn_params_2[iter_conv_id][3], + conv_config=conv_config, + conv_op_cache=conv_op_cache, + ) + tt_x = ttnn.from_device(tt_output_tensor_on_device) + ttnn.deallocate(tt_output_tensor_on_device) + iter_conv_id += 1 + + tt_x = ttnn.to_device(tt_x, device) + tt_x = ttnn.to_layout(tt_x, ttnn.TILE_LAYOUT) + tt_x = ttnn.permute(tt_x, (0, 3, 1, 2)) + tt_x = ttnn.from_device(tt_x) + tt_x = ttnn.to_layout(tt_x, ttnn.ROW_MAJOR_LAYOUT) + tt_x = ttnn.reshape(tt_x, (batch_size, 1, 1, -1)) + tt_x = ttnn.to_layout(tt_x, layout=ttnn.TILE_LAYOUT) + tt_x = ttnn.to_device(tt_x, device) + + # Linear 1 + tt_x = ttnn.linear( + tt_x, + 
parameters["classifier"][classifier_ids[0]]["weight"], + bias=parameters["classifier"][classifier_ids[0]]["bias"], + ) + tt_x = ttnn.relu(tt_x) + + # Linear 2 + tt_x = ttnn.linear( + tt_x, + parameters["classifier"][classifier_ids[1]]["weight"], + bias=parameters["classifier"][classifier_ids[1]]["bias"], + ) + tt_x = ttnn.relu(tt_x) + + # Linear 3 + tt_x = ttnn.linear( + tt_x, + parameters["classifier"][classifier_ids[2]]["weight"], + bias=parameters["classifier"][classifier_ids[2]]["bias"], + ) + + return tt_x + + +def preprocess_conv_parameter(parameter, *, dtype): + parameter = ttnn.from_torch(parameter, dtype=dtype, layout=ttnn.TILE_LAYOUT) + return parameter + + +def custom_preprocessor(model, name): + parameters = {} + if isinstance(model, nn.Conv2d): + weight = model.weight + bias = model.bias + while weight.dim() < 4: + weight = weight.unsqueeze(0) + while bias.dim() < 4: + bias = bias.unsqueeze(0) + parameters["weight"] = preprocess_conv_parameter(weight, dtype=ttnn.bfloat16) + parameters["bias"] = preprocess_conv_parameter(bias, dtype=ttnn.bfloat16) + return parameters diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index dbe5becdaca3..8f1400d2b9fd 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -29,6 +29,8 @@ run_perf_models_other() { env pytest -n auto models/demos/metal_BERT_large_11/tests -m $test_marker + env pytest -n auto models/demos/functional_vgg/tests/test_perf_vgg.py -m $test_marker + ## Merge all the generated reports env python models/perf/merge_perf_results.py } @@ -74,6 +76,8 @@ run_device_perf_models() { env pytest models/demos/distilbert/tests -m $test_marker + env pytest models/demos/functional_vgg/tests/ -m $test_marker + if [ "$tt_arch" == "grayskull" ]; then #TODO(MO): Until #6560 is fixed, GS device profiler test are grouped with #Model Device perf regression tests to make sure thy run on no-soft-reset BMs diff --git a/tests/scripts/single_card/run_single_card_demo_tests.sh b/tests/scripts/single_card/run_single_card_demo_tests.sh index c10b13dc5402..66d564036ea6 100755 --- a/tests/scripts/single_card/run_single_card_demo_tests.sh +++ b/tests/scripts/single_card/run_single_card_demo_tests.sh @@ -12,6 +12,9 @@ run_common_func_tests() { # Mistral7B WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto models/demos/wormhole/mistral7b/demo/demo.py --timeout 420; fail+=$? + #VGG11/VGG16 + WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto pytest models/demos/functional_vgg/demo/demo.py --timeout 600; fail+=$? + # Bert pytest -n auto --disable-warnings models/demos/metal_BERT_large_11/demo/demo.py -k batch_7; fail+=$? WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest -n auto --disable-warnings models/demos/metal_BERT_large_11/demo/demo.py -k batch_8; fail+=$? diff --git a/tests/ttnn/integration_tests/vgg/test_ttnn_vgg11.py b/tests/ttnn/integration_tests/vgg/test_ttnn_vgg11.py new file mode 100644 index 000000000000..9a0b3dd6dca9 --- /dev/null +++ b/tests/ttnn/integration_tests/vgg/test_ttnn_vgg11.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn + +from torchvision import models +from loguru import logger +from tests.ttnn.utils_for_testing import check_with_pcc_without_tensor_printout +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.functional_vgg.tt import ttnn_vgg + +from PIL import Image +import torchvision.transforms as transforms + + +def imagenet_sample_input(): + path = "models/sample_data/ILSVRC2012_val_00048736.JPEG" + im = Image.open(path) + im = im.resize((224, 224)) + return transforms.ToTensor()(im).unsqueeze(0) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "batch_size, act_dtype, weight_dtype, math_fidelity", ((1, ttnn.bfloat16, ttnn.bfloat16, ttnn.MathFidelity.LoFi),) +) +def test_vgg11( + device, + batch_size, + act_dtype, + weight_dtype, + math_fidelity, +): + torch_model = models.vgg11(weights=models.VGG11_Weights.IMAGENET1K_V1) + torch_model.to(torch.bfloat16) + torch_model.eval() + torch_input_tensor_nchw = imagenet_sample_input().to(torch.bfloat16) + torch_batched_tensor = torch_input_tensor_nchw.repeat(batch_size, 1, 1, 1) + + parameters = preprocess_model_parameters( + initialize_model=lambda: torch_model, + device=device, + convert_to_ttnn=lambda *_: True, + custom_preprocessor=ttnn_vgg.custom_preprocessor, + ) + + model_config = { + "MATH_FIDELITY": math_fidelity, + "WEIGHTS_DTYPE": weight_dtype, + "ACTIVATIONS_DTYPE": act_dtype, + } + + torch_batched_tensor = torch_input_tensor_nchw.repeat(batch_size, 1, 1, 1) + torch_input_tensor = torch.permute(torch_batched_tensor, (0, 2, 3, 1)) + tt_batched_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) + + ttnn_output = ttnn_vgg.ttnn_vgg11(device, tt_batched_input_tensor, parameters, batch_size, model_config) + torch_output_tensor = ttnn.to_torch(ttnn_output) + golden_output = torch_model(torch_batched_tensor) + + passing, pcc_msg = check_with_pcc_without_tensor_printout( + (torch_output_tensor.squeeze(1)).squeeze(1), golden_output, pcc=0.99 + ) + logger.info(f"PCC: {pcc_msg}") diff --git a/tests/ttnn/integration_tests/vgg/test_ttnn_vgg16.py b/tests/ttnn/integration_tests/vgg/test_ttnn_vgg16.py new file mode 100644 index 000000000000..1d74df4621f5 --- /dev/null +++ b/tests/ttnn/integration_tests/vgg/test_ttnn_vgg16.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+ +# SPDX-License-Identifier: Apache-2.0 + +import torch +import pytest +import ttnn + +from torchvision import models +from loguru import logger +from tests.ttnn.utils_for_testing import check_with_pcc_without_tensor_printout +import ttnn +from ttnn.model_preprocessing import preprocess_model_parameters +from models.demos.functional_vgg.tt import ttnn_vgg + +from PIL import Image +import torchvision.transforms as transforms + + +def imagenet_sample_input(): + path = "models/sample_data/ILSVRC2012_val_00048736.JPEG" + im = Image.open(path) + im = im.resize((224, 224)) + return transforms.ToTensor()(im).unsqueeze(0) + + +@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) +@pytest.mark.parametrize( + "batch_size, act_dtype, weight_dtype, math_fidelity", ((1, ttnn.bfloat16, ttnn.bfloat16, ttnn.MathFidelity.LoFi),) +) +def test_vgg16( + device, + batch_size, + act_dtype, + weight_dtype, + math_fidelity, + reset_seeds, +): + torch_model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1) + torch_model.to(torch.bfloat16) + torch_model.eval() + torch_input_tensor_nchw = imagenet_sample_input().to(torch.bfloat16) + torch_batched_tensor = torch_input_tensor_nchw.repeat(batch_size, 1, 1, 1) + golden_output = torch_model(torch_batched_tensor) + + parameters = preprocess_model_parameters( + initialize_model=lambda: torch_model, + device=device, + convert_to_ttnn=lambda *_: True, + custom_preprocessor=ttnn_vgg.custom_preprocessor, + ) + + model_config = { + "MATH_FIDELITY": math_fidelity, + "WEIGHTS_DTYPE": weight_dtype, + "ACTIVATIONS_DTYPE": act_dtype, + } + + torch_batched_tensor = torch_input_tensor_nchw.repeat(batch_size, 1, 1, 1) + torch_input_tensor = torch.permute(torch_batched_tensor, (0, 2, 3, 1)) + tt_batched_input_tensor = ttnn.from_torch(torch_input_tensor, ttnn.bfloat16) + + ttnn_output = ttnn_vgg.ttnn_vgg16(device, tt_batched_input_tensor, parameters, batch_size, model_config) + torch_output_tensor = ttnn.to_torch(ttnn_output) + + passing, pcc_msg = check_with_pcc_without_tensor_printout( + (torch_output_tensor.squeeze(1)).squeeze(1), golden_output, pcc=0.99 + ) + logger.info(f"PCC: {pcc_msg}")
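+
+    # Optional (added) strictness: fail the test when the PCC comparison falls
+    # below the 0.99 threshold used above, instead of only logging the result.
+    assert passing, f"VGG16 PCC check failed: {pcc_msg}"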