From 490c8c22a89ce805ecd79949d4749e34a1825e60 Mon Sep 17 00:00:00 2001 From: sabira-mcw Date: Fri, 15 Nov 2024 07:21:58 +0000 Subject: [PATCH] #13398: Update batch_size to 512 for data parallel MNIST --- models/demos/wormhole/mnist/README.md | 4 ++-- models/demos/wormhole/mnist/demo/demo.py | 2 +- models/demos/wormhole/mnist/tests/test_perf_mnist.py | 6 +++--- tests/ttnn/integration_tests/mnist/test_mnist.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/models/demos/wormhole/mnist/README.md b/models/demos/wormhole/mnist/README.md index 4c62dd3c35dc..7d7d5ed92489 100644 --- a/models/demos/wormhole/mnist/README.md +++ b/models/demos/wormhole/mnist/README.md @@ -8,9 +8,9 @@ WH N150, WH N300 The MNIST model uses only fully connected linear layers to classify handwritten digits from the MNIST dataset. Despite the absence of convolutional layers, the model efficiently processes the 28x28 pixel images by flattening them into a 1D vector and passing them through multiple linear layers to predict the corresponding digit (0-9). This approach demonstrates how even simpler architectures can be applied for image classification tasks. -### Batch size: 4 +### Batch size: 512 -Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 4 +Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 512 ## How to Run diff --git a/models/demos/wormhole/mnist/demo/demo.py b/models/demos/wormhole/mnist/demo/demo.py index 0e15557a547e..59e526353cde 100644 --- a/models/demos/wormhole/mnist/demo/demo.py +++ b/models/demos/wormhole/mnist/demo/demo.py @@ -68,7 +68,7 @@ def run_demo_dataset(batch_size, iterations, model_location_generator, mesh_devi @skip_for_grayskull() -@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("batch_size", [512]) @pytest.mark.parametrize("iterations", [1]) def test_demo_dataset( batch_size, diff --git a/models/demos/wormhole/mnist/tests/test_perf_mnist.py b/models/demos/wormhole/mnist/tests/test_perf_mnist.py index a4e06d173b42..31f6648b6331 100644 --- a/models/demos/wormhole/mnist/tests/test_perf_mnist.py +++ b/models/demos/wormhole/mnist/tests/test_perf_mnist.py @@ -29,7 +29,7 @@ def get_expected_times(tt_mnist): if is_wormhole_b0(): return { - tt_mnist: (7.71, 0.0105), + tt_mnist: (10.460, 0.0139), }[tt_mnist] @@ -37,7 +37,7 @@ def get_expected_times(tt_mnist): @pytest.mark.models_performance_virtual_machine @pytest.mark.parametrize( "batch_size", - [32], + [512], ) @pytest.mark.parametrize( "tt_mnist", @@ -98,7 +98,7 @@ def test_performance_mnist(mesh_device, batch_size, tt_mnist, model_location_gen @pytest.mark.parametrize( "batch_size, expected_perf", [ - [32, 143288.92], + [512, 2899420.682], ], ) @pytest.mark.models_device_performance_bare_metal diff --git a/tests/ttnn/integration_tests/mnist/test_mnist.py b/tests/ttnn/integration_tests/mnist/test_mnist.py index ca6889146931..973dcf3a244b 100644 --- a/tests/ttnn/integration_tests/mnist/test_mnist.py +++ b/tests/ttnn/integration_tests/mnist/test_mnist.py @@ -15,7 +15,7 @@ @skip_for_grayskull() @pytest.mark.parametrize( "batch_size", - [32], + [512], ) def test_mnist(mesh_device, reset_seeds, batch_size, model_location_generator): state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist"))