Skip to content

Commit

Permalink
#13399: Add data parallel support for convnet mnist model
Browse files Browse the repository at this point in the history
  • Loading branch information
vigneshkeerthivasanx committed Nov 28, 2024
1 parent 1758447 commit b1581f6
Show file tree
Hide file tree
Showing 10 changed files with 570 additions and 48 deletions.
80 changes: 32 additions & 48 deletions models/demos/convnet_mnist/tt/convnet_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@ def convnet_mnist(
device,
):
batch_size = input_tensor.shape[0]
torch_maxpool = True

conv_config = ttnn.Conv2dConfig(
dtype=ttnn.bfloat16,
weights_dtype=ttnn.bfloat16,
math_fidelity=ttnn.MathFidelity.LoFi,
activation="",
shard_layout=ttnn.TensorMemoryLayout.HEIGHT_SHARDED,
math_approx_mode_enabled=True,
fp32_dest_acc_enabled=False,
packer_l1_accum_enabled=False,
Expand All @@ -30,6 +28,7 @@ def convnet_mnist(
reshard_if_not_optimal=True,
deallocate_activation=True,
reallocate_halo_output=True,
shard_layout=(ttnn.TensorMemoryLayout.HEIGHT_SHARDED if True else ttnn.TensorMemoryLayout.BLOCK_SHARDED),
)

x = ttnn.to_layout(input_tensor, layout=ttnn.ROW_MAJOR_LAYOUT)
Expand All @@ -53,28 +52,19 @@ def convnet_mnist(
)
x = ttnn.relu(x)

if torch_maxpool: # Can be removed once issue #12642 is resolved
x = ttnn.to_torch(x)
x = torch.reshape(x, (batch_size, 30, 30, 32))
x = torch.permute(x, (0, 3, 1, 2))
x = F.max_pool2d(x, 2)
x = torch.permute(x, (0, 2, 3, 1))
x = ttnn.from_torch(x, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT)

else:
x = ttnn.sharded_to_interleaved(x, ttnn.L1_MEMORY_CONFIG)
x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.max_pool2d(
input_tensor=x,
batch_size=batch_size,
input_h=30,
input_w=30,
channels=32,
kernel_size=[2, 2],
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
)
x = ttnn.sharded_to_interleaved(x, ttnn.L1_MEMORY_CONFIG)
x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.max_pool2d(
input_tensor=x,
batch_size=batch_size,
input_h=30,
input_w=30,
channels=32,
kernel_size=[2, 2],
stride=[2, 2],
padding=[0, 0],
dilation=[1, 1],
)

[x, out_height, out_width, weights_device, bias_device] = ttnn.conv2d(
input_tensor=x,
Expand All @@ -94,32 +84,26 @@ def convnet_mnist(
debug=False,
groups=1,
)

x = ttnn.relu(x)

if torch_maxpool: # Can be removed once issue #12642 is resolved
x = ttnn.to_torch(x)
x = torch.reshape(x, (batch_size, 13, 13, 64))
x = torch.permute(x, (0, 3, 1, 2))
x = F.max_pool2d(x, 2)
x = ttnn.from_torch(x, device=device, dtype=ttnn.bfloat16, layout=ttnn.ROW_MAJOR_LAYOUT)

else:
x = ttnn.sharded_to_interleaved(x, ttnn.DRAM_MEMORY_CONFIG)
x = ttnn.to_layout(x, ttnn.ROW_MAJOR_LAYOUT)
x = ttnn.max_pool2d(
input_tensor=x,
batch_size=batch_size,
input_h=out_height,
input_w=out_width,
channels=x.shape[-1],
kernel_size=[2, 2],
stride=[1, 1],
padding=[0, 0],
dilation=[1, 1],
)
x = ttnn.from_device(x)
x = ttnn.reshape(x, (x.shape[0], -1))
x = ttnn.max_pool2d(
input_tensor=x,
batch_size=batch_size,
input_h=out_height,
input_w=out_width,
channels=x.shape[-1],
kernel_size=[2, 2],
stride=[2, 2],
padding=[0, 0],
dilation=[1, 1],
)

x = ttnn.sharded_to_interleaved(x, ttnn.L1_MEMORY_CONFIG)

x = ttnn.reshape(x, (batch_size, 6, 6, 64))
x = ttnn.permute(x, (0, 3, 1, 2))

x = ttnn.reshape(x, (batch_size, -1))

x = ttnn.to_device(x, device)
x = ttnn.to_layout(x, ttnn.TILE_LAYOUT)
Expand Down
24 changes: 24 additions & 0 deletions models/demos/wormhole/convnet_mnist/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Introduction

ConvNet MNIST implements a convolutional neural network that classifies handwritten digits from the MNIST dataset. The MNIST dataset contains grayscale images of handwritten digits (0-9), each 28x28 pixels; the demo resizes every image to 32x32 pixels before running the model.

# Platforms:
WH N300

## How to Run

To run the demo for digit classification using the MNIST model, follow these instructions:

- Use the following command to run the MNIST model.

```
pytest models/demos/wormhole/convnet_mnist/demo/demo.py
```

Maxpool and Softmax currently run in torch inside the model.
ISSUES:
#12664 - [softmax](https://github.com/tenstorrent/tt-metal/issues/12664)
#12642 - [maxpool](https://github.com/tenstorrent/tt-metal/issues/12642)


### Owner: [vigneshkumarkeerthivasan](https://github.com/vigneshkeerthivasanx)
17 changes: 17 additions & 0 deletions models/demos/wormhole/convnet_mnist/convnet_mnist_preprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import ttnn


def custom_preprocessor(parameters, device):
    """Move the convolution biases and fully-connected tensors onto ``device``.

    Args:
        parameters: Preprocessed model parameters exposing ``conv1``, ``conv2``,
            ``fc1`` and ``fc2`` attributes.
        device: Target handed unchanged to ``ttnn.to_device``.

    Returns:
        The same ``parameters`` object, updated in place.
    """
    # The previous version moved conv1.bias twice (copy-paste slip) and never
    # moved conv2.bias; each convolution bias is now moved exactly once.
    # Convolution weights are deliberately not touched here, as before.
    parameters.conv1.bias = ttnn.to_device(parameters.conv1.bias, device)
    parameters.conv2.bias = ttnn.to_device(parameters.conv2.bias, device)

    parameters.fc1.weight = ttnn.to_device(parameters.fc1.weight, device)
    parameters.fc1.bias = ttnn.to_device(parameters.fc1.bias, device)
    parameters.fc2.weight = ttnn.to_device(parameters.fc2.weight, device)
    parameters.fc2.bias = ttnn.to_device(parameters.fc2.bias, device)

    return parameters
37 changes: 37 additions & 0 deletions models/demos/wormhole/convnet_mnist/convnet_mnist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import torchvision
import torchvision.transforms as transforms


def get_test_data(batch_size=64):
    """Load the first ``batch_size`` samples of the MNIST test split.

    Args:
        batch_size: Number of leading test samples to fetch.

    Returns:
        A ``(batch, images, outputs)`` tuple: the transformed samples
        concatenated along dim 0, the untransformed dataset images, and the
        dataset labels, all in dataset order.
    """
    preprocess = transforms.Compose(
        [
            transforms.Resize((32, 32)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.05,), std=(0.05,)),
        ]
    )

    # No transform is attached to the dataset itself so that the raw images
    # can be returned alongside the preprocessed tensors.
    mnist_test = torchvision.datasets.MNIST(root="./data", train=False, download=True)

    samples = [mnist_test[index] for index in range(batch_size)]
    images = [picture for picture, _ in samples]
    outputs = [label for _, label in samples]

    # unsqueeze(0) adds the leading dimension that torch.cat joins on.
    batch = torch.cat([preprocess(picture).unsqueeze(0) for picture in images])
    return batch, images, outputs
87 changes: 87 additions & 0 deletions models/demos/wormhole/convnet_mnist/demo/demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc.

# SPDX-License-Identifier: Apache-2.0

import torch
import ttnn
import pytest

from pathlib import Path
from loguru import logger

from models.demos.wormhole.convnet_mnist.tt.convnet_mnist import (
convnet_mnist,
custom_preprocessor,
)
from models.demos.wormhole.convnet_mnist import convnet_mnist_preprocessing
from models.demos.wormhole.convnet_mnist.convnet_mnist_utils import get_test_data
from models.experimental.convnet_mnist.reference.convnet import ConvNet
from ttnn.model_preprocessing import preprocess_model_parameters
from models.utility_functions import is_wormhole_b0, skip_for_grayskull


def model_location_generator(rel_path):
    """Resolve ``rel_path`` against the model root that is available.

    Args:
        rel_path: Path of the model artifact relative to the chosen root.

    Returns:
        ``/mnt/MLPerf/<rel_path>`` when ``/mnt/MLPerf/bit_error_tests`` exists,
        otherwise ``/opt/tt-metal-models/<rel_path>``.
    """
    weka_root = Path("/mnt/MLPerf")
    fallback_root = Path("/opt/tt-metal-models")

    # The presence of "bit_error_tests" decides which root is used.
    chosen_root = weka_root if (weka_root / "bit_error_tests").exists() else fallback_root
    return chosen_root / rel_path


@skip_for_grayskull()
@pytest.mark.parametrize("batch_size", (16,))
@pytest.mark.parametrize("device_params", [{"l1_small_size": 16384}], indirect=True)
def test_convnet_mnist(mesh_device, batch_size, reset_seeds):
    """Run the ConvNet MNIST demo on a mesh device and log its accuracy.

    The reference torch model and the ttnn model are both evaluated on the
    same test samples; predictions from each are logged, along with the ttnn
    accuracy against the dataset labels. Nothing is asserted.
    """
    weights_file = model_location_generator("tt_dnn-models/ConvNetMNIST/") / "convnet_mnist.pt"
    state_dict = torch.load(str(weights_file))

    test_input, _images, output = get_test_data(batch_size)

    # Reference (torch) forward pass.
    model = ConvNet()
    model.load_state_dict(state_dict)
    model.eval()
    torch_output = model(test_input)
    batch_size = len(test_input)

    # Inputs are sharded along dim 0 across the mesh; outputs are concatenated
    # back along dim 0.
    inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0)
    output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0)

    with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)):
        parameters = preprocess_model_parameters(
            initialize_model=lambda: model,
            convert_to_ttnn=lambda *_: True,
            custom_preprocessor=custom_preprocessor,
        )
        parameters = convnet_mnist_preprocessing.custom_preprocessor(parameters, device=mesh_device)

    # Reorder the torch input dims with (0, 2, 3, 1) before handing it to ttnn.
    permuted_input = torch.permute(test_input, (0, 2, 3, 1))
    ttnn_input = ttnn.from_torch(
        permuted_input,
        dtype=ttnn.bfloat16,
        layout=ttnn.ROW_MAJOR_LAYOUT,
        mesh_mapper=inputs_mesh_mapper,
    )

    device_output = convnet_mnist(
        input_tensor=ttnn_input,
        device=mesh_device,
        parameters=parameters,
        mesh_mapper=inputs_mesh_mapper,
        mesh_composer=output_mesh_composer,
    )
    ttnn_output = ttnn.to_torch(device_output, mesh_composer=output_mesh_composer)

    # Index of the maximum along the last dim is the predicted class.
    torch_predicted = torch.max(torch_output.data, -1)[1]
    ttnn_predicted = torch.max(ttnn_output.data, -1)[1]

    correct = sum(1 for i in range(batch_size) if output[i] == ttnn_predicted[i])
    accuracy = correct / (batch_size)

    logger.info(f" Accuracy for {batch_size} Samples : {accuracy}")
    logger.info(f"torch_predicted {torch_predicted.squeeze()}")
    logger.info(f"ttnn_predicted {ttnn_predicted.squeeze()}")
Loading

0 comments on commit b1581f6

Please sign in to comment.