-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#10439: ttnn implementation of vgg model
- Loading branch information
1 parent
5448c47
commit 5390c0c
Showing
9 changed files
with
889 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# Introduction | ||
|
||
The VGG model is a popular convolutional neural network architecture introduced by the Visual Geometry Group at Oxford in their paper "Very Deep Convolutional Networks for Large-Scale Image Recognition" (2014). It is widely used for image classification and feature extraction tasks. | ||
|
||
# Platforms: | ||
GS E150, WH N150, WH N300 | ||
|
||
# Model Architectures | ||
- VGG11 | ||
- VGG16 | ||
VGG11 and VGG16 currently supports BATCH_SIZE = 1. | ||
|
||
# How to Run | ||
To run the demo for image classification of the VGG model using ImageNet-1k Validation Dataset, follow these instructions | ||
|
||
- Use the following command to run the model using ttnn_vgg | ||
-VGG11 | ||
``` | ||
pytest models/demos/functional_vgg/demo/demo.py::test_demo_imagenet_vgg11 | ||
``` | ||
- VGG16 | ||
``` | ||
pytest models/demos/functional_vgg/demo/demo.py::test_demo_imagenet_vgg16 | ||
``` | ||
|
||
NOTE: one ttnn.reshape in VGG11 and VGG16 is on host. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
import torch | ||
from loguru import logger | ||
from torchvision import models | ||
from transformers import AutoImageProcessor | ||
import pytest | ||
import tt_lib | ||
import torch.nn as nn | ||
|
||
from models.utility_functions import ( | ||
disable_compilation_reports, | ||
disable_persistent_kernel_cache, | ||
enable_persistent_kernel_cache, | ||
profiler, | ||
) | ||
import ttnn | ||
|
||
from models.demos.functional_vgg.demo_utils import get_data, get_data_loader, get_batch, preprocess | ||
from loguru import logger | ||
from ttnn.model_preprocessing import preprocess_model_parameters | ||
from models.demos.functional_vgg.tt import ttnn_vgg | ||
|
||
vgg_model_config = { | ||
"MATH_FIDELITY": ttnn.MathFidelity.LoFi, | ||
"WEIGHTS_DTYPE": ttnn.bfloat16, | ||
"ACTIVATIONS_DTYPE": ttnn.bfloat16, | ||
} | ||
|
||
|
||
def run_vgg_imagenet_inference_vgg( | ||
batch_size, | ||
iterations, | ||
imagenet_label_dict, | ||
model_location_generator, | ||
model_class, | ||
weights, | ||
device, | ||
model_config=vgg_model_config, | ||
): | ||
disable_persistent_kernel_cache() | ||
disable_compilation_reports() | ||
profiler.clear() | ||
|
||
# Setup model | ||
torch_model = model_class(weights=weights) | ||
torch_model.to(torch.bfloat16) | ||
torch_model.eval() | ||
|
||
parameters = preprocess_model_parameters( | ||
initialize_model=lambda: torch_model, | ||
device=device, | ||
convert_to_ttnn=lambda *_: True, | ||
custom_preprocessor=ttnn_vgg.custom_preprocessor, | ||
) | ||
|
||
if model_class == models.vgg11: | ||
ttnn_model = ttnn_vgg.ttnn_vgg11 | ||
model_name = "VGG11" | ||
else: | ||
ttnn_model = ttnn_vgg.ttnn_vgg16 | ||
model_name = "VGG16" | ||
|
||
# load inputs | ||
logger.info("ImageNet-1k validation Dataset") | ||
input_loc = str(model_location_generator("ImageNet_data")) | ||
data_loader = get_data_loader(input_loc, batch_size, iterations) | ||
|
||
# load ImageNet batch by batch | ||
# and run inference | ||
correct = 0 | ||
for iter in range(iterations): | ||
predictions = [] | ||
torch_predictions = [] | ||
inputs, labels = get_batch(data_loader) | ||
torch_outputs = torch_model(inputs) | ||
permuted_inputs = torch.permute(inputs, (0, 2, 3, 1)) | ||
tt_batched_input_tensor = ttnn.from_torch(permuted_inputs, ttnn.bfloat16) | ||
tt_output = ttnn_model(device, tt_batched_input_tensor, parameters, batch_size, model_config) | ||
tt_output = ttnn.to_torch(tt_output) | ||
prediction = tt_output[:, 0, 0, :].argmax(dim=-1) | ||
torch_prediction = torch_outputs[:, :].argmax(dim=-1) | ||
for i in range(batch_size): | ||
predictions.append(imagenet_label_dict[prediction[i].item()]) | ||
torch_predictions.append(imagenet_label_dict[torch_prediction[i].item()]) | ||
logger.info( | ||
f"Iter: {iter} Sample: {i} - Expected Label: {imagenet_label_dict[labels[i]]} -- \n Torch Predicted label:{predictions[-1]} \tPredicted Label: {predictions[-1]}" | ||
) | ||
if imagenet_label_dict[labels[i]] == predictions[-1]: | ||
correct += 1 | ||
|
||
del tt_output, tt_batched_input_tensor, inputs, labels, predictions | ||
accuracy = correct / (batch_size * iterations) | ||
logger.info(f"Model {model_name}") | ||
logger.info(f"Accuracy for {batch_size}x{iterations} inputs: {accuracy}") | ||
|
||
|
||
@pytest.mark.parametrize("device_params", [{"l1_small_size": 24576}], indirect=True) | ||
@pytest.mark.parametrize( | ||
"batch_size, iterations", | ||
((1, 1),), | ||
) | ||
@pytest.mark.parametrize( | ||
"model_class, weights", | ||
[ | ||
(models.vgg11, models.VGG11_Weights.IMAGENET1K_V1), | ||
(models.vgg16, models.VGG16_Weights.IMAGENET1K_V1), | ||
], | ||
) | ||
def test_demo_imagenet_vgg( | ||
batch_size, iterations, imagenet_label_dict, model_location_generator, model_class, weights, device | ||
): | ||
run_vgg_imagenet_inference_vgg( | ||
batch_size, iterations, imagenet_label_dict, model_location_generator, model_class, weights, device | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from PIL import Image | ||
import torch | ||
import os | ||
import glob | ||
from models.sample_data.huggingface_imagenet_classes import IMAGENET2012_CLASSES | ||
from datasets import load_dataset | ||
from torchvision import models | ||
from PIL import Image | ||
import torchvision.transforms as transforms | ||
import torch | ||
|
||
|
||
class InputExample(object): | ||
def __init__(self, image, label=None): | ||
self.image = image | ||
self.label = label | ||
|
||
|
||
def get_input(image_path): | ||
img = Image.open(image_path) | ||
return img | ||
|
||
|
||
def get_label(image_path): | ||
_, image_name = image_path.rsplit("/", 1) | ||
image_name_exact, _ = image_name.rsplit(".", 1) | ||
_, label_id = image_name_exact.rsplit("_", 1) | ||
label = list(IMAGENET2012_CLASSES).index(label_id) | ||
return label | ||
|
||
|
||
preprocess = transforms.Compose( | ||
[ | ||
transforms.Resize(256), # Resize the shorter side to 256 pixels | ||
transforms.CenterCrop(224), # Crop the center to 224x224 pixels | ||
transforms.ToTensor(), # Convert the image to a tensor | ||
transforms.Normalize( # Normalize using ImageNet's mean and std | ||
mean=[0.485, 0.456, 0.406], # These are the mean values for each channel | ||
std=[0.229, 0.224, 0.225], # These are the std values for each channel | ||
), | ||
] | ||
) | ||
|
||
|
||
def get_batch(data_loader): | ||
loaded_images = next(data_loader) | ||
images = None | ||
labels = [] | ||
transform = transforms.ToTensor() | ||
resize_transform = transforms.Resize((224, 224)) | ||
for image in loaded_images: | ||
img = image.image | ||
labels.append(image.label) | ||
if img.mode == "L": | ||
img = img.convert(mode="RGB") | ||
|
||
img = preprocess(img) | ||
img = img.to(torch.bfloat16) | ||
img = img.unsqueeze(0) | ||
if images is None: | ||
images = img | ||
else: | ||
images = torch.cat((images, img), dim=0) | ||
return images, labels | ||
|
||
|
||
def get_data_loader(input_loc, batch_size, iterations): | ||
img_dir = input_loc + "/" | ||
data_path = os.path.join(img_dir, "*G") | ||
files = glob.glob(data_path) | ||
|
||
def loader(): | ||
examples = [] | ||
for f1 in files: | ||
examples.append( | ||
InputExample( | ||
image=get_input(f1), | ||
label=get_label(f1), | ||
) | ||
) | ||
if len(examples) == batch_size: | ||
yield examples | ||
del examples | ||
examples = [] | ||
|
||
def loader_hf(): | ||
examples = [] | ||
for f1 in files: | ||
examples.append( | ||
InputExample( | ||
image=f1["image"], | ||
label=f1["label"], | ||
) | ||
) | ||
if len(examples) == batch_size: | ||
yield examples | ||
del examples | ||
examples = [] | ||
|
||
if len(files) == 0: | ||
files_raw = iter(load_dataset("imagenet-1k", split="validation", use_auth_token=True, streaming=True)) | ||
files = [] | ||
sample_count = batch_size * iterations | ||
for _ in range(sample_count): | ||
files.append(next(files_raw)) | ||
del files_raw | ||
return loader_hf() | ||
|
||
return loader() | ||
|
||
|
||
def get_data(input_loc): | ||
img_dir = input_loc + "/" | ||
data_path = os.path.join(img_dir, "*G") | ||
files = sorted(glob.glob(data_path)) | ||
examples = [] | ||
for f1 in files: | ||
examples.append( | ||
InputExample( | ||
image=get_input(f1), | ||
label=get_label(f1), | ||
) | ||
) | ||
image_examples = examples | ||
|
||
return image_examples |
Oops, something went wrong.