-
Notifications
You must be signed in to change notification settings - Fork 91
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
#12558: TTNN implementation of MNIST model
- Loading branch information
1 parent
28e3825
commit b259292
Showing
8 changed files
with
353 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# MNIST | ||
|
||
## Platforms | ||
|
||
GS E150, WH N150, WH N300 | ||
|
||
## Introduction | ||
|
||
The MNIST model uses only fully connected linear layers to classify handwritten digits from the MNIST dataset. Despite the absence of convolutional layers, the model efficiently processes the 28x28 pixel images by flattening them into a 1D vector and passing them through multiple linear layers to predict the corresponding digit (0-9). This approach demonstrates how even simpler architectures can be applied for image classification tasks. | ||
|
||
### Batch size: 8 | ||
|
||
Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 8 | ||
|
||
## How to Run | ||
|
||
To run the demo for digit classification using the MNIST model, follow these instructions: | ||
|
||
- Use the following command to run the MNIST model. | ||
``` | ||
pytest models/demos/mnist/demo/demo.py::test_demo_dataset | ||
``` | ||
|
||
## Inputs | ||
|
||
The demo receives inputs from respective dataset MNIST. | ||
|
||
## Additional Information | ||
|
||
Please note that input tensor for the reshape op used in this model is not supported on device. If you encounter issues when running the model, ensure that device has support for all required operations. | ||
|
||
### Owner: [sabira-mcw](https://github.com/sabira-mcw) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import pytest | ||
import torch | ||
import ttnn | ||
|
||
from torchvision import transforms, datasets | ||
from loguru import logger | ||
|
||
from torch.utils.data import DataLoader | ||
from models.demos.mnist.reference.mnist import MnistModel | ||
from models.demos.mnist.tt import tt_functional_mnist | ||
|
||
from ttnn.model_preprocessing import preprocess_model_parameters | ||
|
||
|
||
def run_demo_dataset(device, batch_size, iterations, model_location_generator): | ||
transform = transforms.Compose([transforms.ToTensor()]) | ||
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True) | ||
|
||
state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) | ||
model = MnistModel(state_dict) | ||
model = model.eval() | ||
|
||
parameters = preprocess_model_parameters( | ||
initialize_model=lambda: model, | ||
convert_to_ttnn=lambda *_: True, | ||
device=device, | ||
) | ||
correct = 0 | ||
for iters in range(iterations): | ||
dataloader = DataLoader(test_dataset, batch_size=batch_size) | ||
x, labels = next(iter(dataloader)) | ||
dataset_predictions = [] | ||
ttnn_predictions = [] | ||
dataset_ttnn_correct = 0 | ||
x = ttnn.from_torch(x, dtype=ttnn.bfloat16) | ||
tt_output = tt_functional_mnist.mnist(device, batch_size, x, parameters) | ||
tt_output = ttnn.to_torch(tt_output).permute(1, 2, 0, 3).squeeze(0).squeeze(0) | ||
predicted_probabilities = torch.nn.functional.softmax(tt_output, dim=1) | ||
_, predicted_label = torch.max(predicted_probabilities, 1) | ||
tt_output = tt_output | ||
for i in range(batch_size): | ||
dataset_predictions.append(labels[i]) | ||
ttnn_predictions.append(predicted_label[i]) | ||
logger.info(f"Iter: {iters} Sample {i}:") | ||
logger.info(f"Expected Label: {dataset_predictions[i]}") | ||
logger.info(f"Predicted Label: {ttnn_predictions[i]}") | ||
|
||
if dataset_predictions[i] == ttnn_predictions[i]: | ||
dataset_ttnn_correct += 1 | ||
correct += 1 | ||
dataset_ttnn_accuracy = dataset_ttnn_correct / (batch_size) | ||
logger.info( | ||
f"ImageNet Inference Accuracy for iter {iters} of {batch_size} input samples : {dataset_ttnn_accuracy}" | ||
) | ||
|
||
accuracy = correct / (batch_size * iterations) | ||
logger.info(f"ImageNet Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}") | ||
|
||
|
||
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) | ||
@pytest.mark.parametrize("batch_size", [8]) | ||
@pytest.mark.parametrize("iterations", [1]) | ||
def test_demo_dataset( | ||
device, | ||
batch_size, | ||
iterations, | ||
model_location_generator, | ||
): | ||
return run_demo_dataset( | ||
device=device, | ||
batch_size=batch_size, | ||
iterations=iterations, | ||
model_location_generator=model_location_generator, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
|
||
|
||
class MnistModel(torch.nn.Module): | ||
def __init__(self, state_dict): | ||
super().__init__() | ||
|
||
self.fc1 = torch.nn.Linear(784, 120) | ||
self.fc2 = torch.nn.Linear(120, 84) | ||
self.fc3 = torch.nn.Linear(84, 10) | ||
|
||
self.load_state_dict(state_dict) | ||
|
||
def forward(self, x): | ||
x = x.view(x.shape[0], -1) | ||
|
||
x = self.fc1(x) | ||
x = torch.nn.functional.relu(x) | ||
|
||
x = self.fc2(x) | ||
x = torch.nn.functional.relu(x) | ||
|
||
x = self.fc3(x) | ||
x = torch.nn.functional.relu(x) | ||
|
||
return torch.nn.functional.softmax(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import ttnn | ||
import time | ||
import pytest | ||
import torch | ||
from loguru import logger | ||
from torchvision import transforms, datasets | ||
from torch.utils.data import DataLoader | ||
from models.utility_functions import ( | ||
enable_persistent_kernel_cache, | ||
disable_persistent_kernel_cache, | ||
) | ||
from models.perf.perf_utils import prep_perf_report | ||
from models.demos.mnist.tt import tt_functional_mnist | ||
from ttnn.model_preprocessing import preprocess_model_parameters | ||
from models.demos.mnist.reference.mnist import MnistModel | ||
from models.utility_functions import is_grayskull, is_wormhole_b0 | ||
from models.perf.device_perf_utils import run_device_perf, check_device_perf, prep_device_perf_report | ||
|
||
transform = transforms.Compose([transforms.ToTensor()]) | ||
test_dataset = datasets.MNIST(root="./data", train=False, transform=None, download=True) | ||
|
||
|
||
def get_expected_times(functional_mnist): | ||
if is_grayskull(): | ||
return { | ||
tt_functional_mnist: (2, 0.0041), | ||
}[functional_mnist] | ||
elif is_wormhole_b0(): | ||
return { | ||
tt_functional_mnist: (2.8, 0.004), | ||
}[functional_mnist] | ||
|
||
|
||
@pytest.mark.models_performance_bare_metal | ||
@pytest.mark.models_performance_virtual_machine | ||
@pytest.mark.parametrize( | ||
"batch_size", | ||
[8], | ||
) | ||
@pytest.mark.parametrize( | ||
"functional_mnist", | ||
[tt_functional_mnist], | ||
) | ||
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) | ||
def test_performance_mnist(device, batch_size, functional_mnist, model_location_generator, reset_seeds): | ||
state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) | ||
model = MnistModel(state_dict) | ||
model = model.eval() | ||
disable_persistent_kernel_cache() | ||
parameters = preprocess_model_parameters( | ||
initialize_model=lambda: model, | ||
convert_to_ttnn=lambda *_: True, | ||
device=device, | ||
) | ||
transform = transforms.Compose([transforms.ToTensor()]) | ||
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True) | ||
dataloader = DataLoader(test_dataset, batch_size=batch_size) | ||
x, labels = next(iter(dataloader)) | ||
|
||
test_input = ttnn.from_torch(x, dtype=ttnn.bfloat16) | ||
durations = [] | ||
for _ in range(2): | ||
start = time.time() | ||
|
||
ttnn_output = tt_functional_mnist.mnist( | ||
device=device, | ||
x=test_input, | ||
batch_size=batch_size, | ||
parameters=parameters, | ||
) | ||
end = time.time() | ||
durations.append(end - start) | ||
|
||
inference_and_compile_time, *inference_times = durations | ||
average_inference_time = sum(inference_times) / len(inference_times) | ||
expected_compile_time, expected_inference_time = get_expected_times(functional_mnist) | ||
|
||
prep_perf_report( | ||
model_name="MNIST", | ||
batch_size=batch_size, | ||
inference_and_compile_time=inference_and_compile_time, | ||
inference_time=average_inference_time, | ||
expected_compile_time=expected_compile_time, | ||
expected_inference_time=expected_inference_time, | ||
comments="", | ||
inference_time_cpu=0.0, | ||
) | ||
|
||
logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}") | ||
logger.info(f"Inference time: {average_inference_time}") | ||
logger.info(f"Inference times: {inference_times}") | ||
logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}") | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"batch_size", | ||
[8], | ||
) | ||
@pytest.mark.models_device_performance_bare_metal | ||
def test_perf_device_bare_metal(batch_size, reset_seeds): | ||
subdir = "ttnn_mnist" | ||
num_iterations = 1 | ||
margin = 0.03 | ||
if is_grayskull(): | ||
expected_perf = 44041.949957334364 | ||
elif is_wormhole_b0(): | ||
expected_perf = 51071 | ||
|
||
command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist.py::test_mnist" | ||
cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] | ||
|
||
inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" | ||
expected_perf_cols = {inference_time_key: expected_perf} | ||
|
||
post_processed_results = run_device_perf(command, subdir, num_iterations, cols, batch_size) | ||
expected_results = check_device_perf(post_processed_results, margin, expected_perf_cols) | ||
prep_device_perf_report( | ||
model_name=f"tt_functional_mnist{batch_size}", | ||
batch_size=batch_size, | ||
post_processed_results=post_processed_results, | ||
expected_results=expected_results, | ||
comments="", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import ttnn | ||
import torch | ||
|
||
|
||
def mnist(device, batch_size, x, parameters): | ||
x = ttnn.reshape(x, (x.shape[0], 1, 1, 784)) | ||
|
||
x = ttnn.to_device(x, device=device) | ||
x = ttnn.to_layout(x, layout=ttnn.TILE_LAYOUT) | ||
x = ttnn.linear( | ||
x, parameters.fc1.weight, bias=parameters.fc1.bias, memory_config=ttnn.L1_MEMORY_CONFIG, activation="relu" | ||
) | ||
x = ttnn.linear( | ||
x, | ||
parameters.fc2.weight, | ||
bias=parameters.fc2.bias, | ||
memory_config=ttnn.L1_MEMORY_CONFIG, | ||
activation="relu", | ||
) | ||
x = ttnn.linear( | ||
x, | ||
parameters.fc3.weight, | ||
bias=parameters.fc3.bias, | ||
memory_config=ttnn.L1_MEMORY_CONFIG, | ||
activation="relu", | ||
) | ||
|
||
x = ttnn.softmax(x) | ||
|
||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. | ||
|
||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import torch | ||
import ttnn | ||
import pytest | ||
from tests.ttnn.utils_for_testing import assert_with_pcc | ||
from ttnn.model_preprocessing import preprocess_model_parameters | ||
from models.demos.mnist.reference.mnist import MnistModel | ||
from models.demos.mnist.tt import tt_functional_mnist | ||
from torch.utils.data import DataLoader | ||
from torchvision import transforms, datasets | ||
|
||
|
||
@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True) | ||
@pytest.mark.parametrize( | ||
"batch_size", | ||
[4], | ||
) | ||
def test_mnist(reset_seeds, device, batch_size, model_location_generator): | ||
state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) | ||
model = MnistModel(state_dict) | ||
model = model.eval() | ||
transform = transforms.Compose([transforms.ToTensor()]) | ||
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True) | ||
dataloader = DataLoader(test_dataset, batch_size=batch_size) | ||
|
||
x, labels = next(iter(dataloader)) | ||
|
||
torch_output = model(x) | ||
|
||
parameters = preprocess_model_parameters( | ||
initialize_model=lambda: model, | ||
convert_to_ttnn=lambda *_: True, | ||
device=device, | ||
) | ||
x = ttnn.from_torch(x, dtype=ttnn.bfloat16) | ||
|
||
tt_output = tt_functional_mnist.mnist(device, batch_size, x, parameters) | ||
|
||
tt_output = ttnn.to_torch(tt_output).permute(1, 2, 0, 3).squeeze(0).squeeze(0) | ||
|
||
assert_with_pcc(torch_output, tt_output, 0.99) |