From b101bbc2d3484d0dc1000e09486a0bc7dab4909e Mon Sep 17 00:00:00 2001
From: sabira-mcw
Date: Fri, 15 Nov 2024 07:36:44 +0000
Subject: [PATCH] #13398: Update batch_size to 256 for data parallel MNIST

---
 models/demos/wormhole/mnist/README.md                | 4 ++--
 models/demos/wormhole/mnist/demo/demo.py             | 2 +-
 models/demos/wormhole/mnist/tests/test_perf_mnist.py | 6 +++---
 tests/ttnn/integration_tests/mnist/test_mnist.py     | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/models/demos/wormhole/mnist/README.md b/models/demos/wormhole/mnist/README.md
index 7d7d5ed9248..5faedbd7f5f 100644
--- a/models/demos/wormhole/mnist/README.md
+++ b/models/demos/wormhole/mnist/README.md
@@ -8,9 +8,9 @@ WH N150, WH N300
 
 The MNIST model uses only fully connected linear layers to classify handwritten digits from the MNIST dataset. Despite the absence of convolutional layers, the model efficiently processes the 28x28 pixel images by flattening them into a 1D vector and passing them through multiple linear layers to predict the corresponding digit (0-9). This approach demonstrates how even simpler architectures can be applied for image classification tasks.
 
-### Batch size: 512
+### Batch size: 256
 
-Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 512
+Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 256
 
 ## How to Run
 
diff --git a/models/demos/wormhole/mnist/demo/demo.py b/models/demos/wormhole/mnist/demo/demo.py
index 59e526353cd..74202a5e57a 100644
--- a/models/demos/wormhole/mnist/demo/demo.py
+++ b/models/demos/wormhole/mnist/demo/demo.py
@@ -68,7 +68,7 @@ def run_demo_dataset(batch_size, iterations, model_location_generator, mesh_device
 
 
 @skip_for_grayskull()
-@pytest.mark.parametrize("batch_size", [512])
+@pytest.mark.parametrize("batch_size", [256])
 @pytest.mark.parametrize("iterations", [1])
 def test_demo_dataset(
     batch_size,
diff --git a/models/demos/wormhole/mnist/tests/test_perf_mnist.py b/models/demos/wormhole/mnist/tests/test_perf_mnist.py
index 31f6648b633..5ae426b42e3 100644
--- a/models/demos/wormhole/mnist/tests/test_perf_mnist.py
+++ b/models/demos/wormhole/mnist/tests/test_perf_mnist.py
@@ -29,7 +29,7 @@
 def get_expected_times(tt_mnist):
     if is_wormhole_b0():
         return {
-            tt_mnist: (10.460, 0.0139),
+            tt_mnist: (10.89, 0.0162),
         }[tt_mnist]
 
 
@@ -37,7 +37,7 @@
 @pytest.mark.models_performance_virtual_machine
 @pytest.mark.parametrize(
     "batch_size",
-    [512],
+    [256],
 )
 @pytest.mark.parametrize(
     "tt_mnist",
@@ -98,7 +98,7 @@ def test_performance_mnist(mesh_device, batch_size, tt_mnist, model_location_gen
 @pytest.mark.parametrize(
     "batch_size, expected_perf",
     [
-        [512, 2899420.682],
+        [256, 1520045.60],
     ],
 )
 @pytest.mark.models_device_performance_bare_metal
diff --git a/tests/ttnn/integration_tests/mnist/test_mnist.py b/tests/ttnn/integration_tests/mnist/test_mnist.py
index 973dcf3a244..72a46dbc61d 100644
--- a/tests/ttnn/integration_tests/mnist/test_mnist.py
+++ b/tests/ttnn/integration_tests/mnist/test_mnist.py
@@ -15,7 +15,7 @@
 @skip_for_grayskull()
 @pytest.mark.parametrize(
     "batch_size",
-    [512],
+    [256],
 )
 def test_mnist(mesh_device, reset_seeds, batch_size, model_location_generator):
     state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist"))
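
Note (illustration only, not part of the patch above): the README text touched by this change describes an MLP that flattens each 28x28 MNIST image into a 784-element vector and passes it through fully connected layers to produce logits for digits 0-9. The plain-PyTorch sketch below shows that shape driven with the new batch size of 256; the class name MnistMLP and the layer widths (784 -> 120 -> 84 -> 10) are assumptions for illustration, not values read from models/demos/wormhole/mnist.

import torch
import torch.nn as nn

# Hypothetical sketch: class name and layer widths are assumed, the real
# model is defined under models/demos/wormhole/mnist and may differ.
class MnistMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Flatten(),        # 1x28x28 image -> 784-element vector
            nn.Linear(784, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 10),   # logits for digits 0-9
        )

    def forward(self, x):
        return self.layers(x)

if __name__ == "__main__":
    model = MnistMLP()
    batch = torch.rand(256, 1, 28, 28)  # batch_size = 256, as set by this patch
    logits = model(batch)
    print(logits.shape)                 # torch.Size([256, 10])

In the data-parallel demo this batch of 256 is presumably split across the mesh devices (e.g. on a WH N300), but that device-side sharding is handled by the ttnn demo code and is outside the scope of this sketch.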