From af5cc2f64eca4304436a33e67c800c6eecefe075 Mon Sep 17 00:00:00 2001 From: sabira-mcw Date: Fri, 15 Nov 2024 07:36:44 +0000 Subject: [PATCH] #13398: Update batch_size to 256 for data parallel MNIST --- models/demos/wormhole/mnist/README.md | 4 +-- models/demos/wormhole/mnist/demo/demo.py | 8 +++--- ...st_perf_mnist.py => test_perf_mnist_wh.py} | 27 +++++++++++-------- tests/scripts/run_performance.sh | 3 ++- .../mnist/{test_mnist.py => test_mnist_wh.py} | 6 +++-- 5 files changed, 29 insertions(+), 19 deletions(-) rename models/demos/wormhole/mnist/tests/{test_perf_mnist.py => test_perf_mnist_wh.py} (83%) rename tests/ttnn/integration_tests/mnist/{test_mnist.py => test_mnist_wh.py} (91%) diff --git a/models/demos/wormhole/mnist/README.md b/models/demos/wormhole/mnist/README.md index 7d7d5ed92489..5faedbd7f5f9 100644 --- a/models/demos/wormhole/mnist/README.md +++ b/models/demos/wormhole/mnist/README.md @@ -8,9 +8,9 @@ WH N150, WH N300 The MNIST model uses only fully connected linear layers to classify handwritten digits from the MNIST dataset. Despite the absence of convolutional layers, the model efficiently processes the 28x28 pixel images by flattening them into a 1D vector and passing them through multiple linear layers to predict the corresponding digit (0-9). This approach demonstrates how even simpler architectures can be applied for image classification tasks. -### Batch size: 512 +### Batch size: 256 -Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 512 +Batch Size determines the number of input sequences processed simultaneously during training or inference, impacting computational efficiency and memory usage. It's recommended to set the batch_size to 256 ## How to Run diff --git a/models/demos/wormhole/mnist/demo/demo.py b/models/demos/wormhole/mnist/demo/demo.py index 59e526353cde..bc6fc5763340 100644 --- a/models/demos/wormhole/mnist/demo/demo.py +++ b/models/demos/wormhole/mnist/demo/demo.py @@ -12,7 +12,7 @@ from torch.utils.data import DataLoader from models.demos.wormhole.mnist.reference.mnist import MnistModel from models.demos.wormhole.mnist.tt import tt_mnist - +from models.utility_functions import disable_persistent_kernel_cache from ttnn.model_preprocessing import preprocess_model_parameters from models.utility_functions import is_wormhole_b0, skip_for_grayskull @@ -25,7 +25,8 @@ def run_demo_dataset(batch_size, iterations, model_location_generator, mesh_devi state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) model = MnistModel(state_dict) model = model.eval() - + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size if mesh_device_flag else batch_size / 2 inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): @@ -68,7 +69,7 @@ def run_demo_dataset(batch_size, iterations, model_location_generator, mesh_devi @skip_for_grayskull() -@pytest.mark.parametrize("batch_size", [512]) +@pytest.mark.parametrize("batch_size", [256]) @pytest.mark.parametrize("iterations", [1]) def test_demo_dataset( batch_size, @@ -76,6 +77,7 @@ def test_demo_dataset( model_location_generator, mesh_device, ): + disable_persistent_kernel_cache() return run_demo_dataset( batch_size=batch_size, iterations=iterations, diff --git a/models/demos/wormhole/mnist/tests/test_perf_mnist.py b/models/demos/wormhole/mnist/tests/test_perf_mnist_wh.py similarity index 83% rename from models/demos/wormhole/mnist/tests/test_perf_mnist.py rename to models/demos/wormhole/mnist/tests/test_perf_mnist_wh.py index 31f6648b6331..541f4c5baa95 100644 --- a/models/demos/wormhole/mnist/tests/test_perf_mnist.py +++ b/models/demos/wormhole/mnist/tests/test_perf_mnist_wh.py @@ -29,7 +29,7 @@ def get_expected_times(tt_mnist): if is_wormhole_b0(): return { - tt_mnist: (10.460, 0.0139), + tt_mnist: (10.89, 0.017), }[tt_mnist] @@ -37,7 +37,7 @@ def get_expected_times(tt_mnist): @pytest.mark.models_performance_virtual_machine @pytest.mark.parametrize( "batch_size", - [512], + [256], ) @pytest.mark.parametrize( "tt_mnist", @@ -52,7 +52,8 @@ def test_performance_mnist(mesh_device, batch_size, tt_mnist, model_location_gen test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True) dataloader = DataLoader(test_dataset, batch_size=batch_size) x, labels = next(iter(dataloader)) - + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size if mesh_device_flag else batch_size / 2 inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) @@ -71,34 +72,38 @@ def test_performance_mnist(mesh_device, batch_size, tt_mnist, model_location_gen ttnn_output = tt_mnist.mnist(mesh_device, batch_size, x, parameters) end = time.time() durations.append(end - start) - # enable_persistent_kernel_cache() + enable_persistent_kernel_cache() inference_and_compile_time, *inference_times = durations - average_inference_time = sum(inference_times) / len(inference_times) + inference_time = sum(inference_times) / len(inference_times) expected_compile_time, expected_inference_time = get_expected_times(tt_mnist) prep_perf_report( model_name="MNIST", batch_size=batch_size, inference_and_compile_time=inference_and_compile_time, - inference_time=average_inference_time, + inference_time=inference_time, expected_compile_time=expected_compile_time, expected_inference_time=expected_inference_time, comments="", inference_time_cpu=0.0, ) - logger.info(f"Compile time: {inference_and_compile_time - average_inference_time}") - logger.info(f"Inference time: {average_inference_time}") + logger.info(f"Compile time: {inference_and_compile_time - inference_time}") + logger.info(f"Inference time: {inference_time}") logger.info(f"Inference times: {inference_times}") - logger.info(f"Sample(s) per second: {1 / average_inference_time * batch_size}") + logger.info(f"Sample(s) per second: {1 / inference_time * batch_size}") + assert ( + inference_time < expected_inference_time + ), f"Expected inference time: {expected_inference_time} Actual inference time: {inference_time}" + logger.info("Exit MNIST perf test") @skip_for_grayskull() @pytest.mark.parametrize( "batch_size, expected_perf", [ - [512, 2899420.682], + [256, 1520045.60], ], ) @pytest.mark.models_device_performance_bare_metal @@ -107,7 +112,7 @@ def test_perf_device_bare_metal(batch_size, expected_perf): num_iterations = 1 margin = 0.03 - command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist.py" + command = f"pytest tests/ttnn/integration_tests/mnist/test_mnist_wh.py" cols = ["DEVICE FW", "DEVICE KERNEL", "DEVICE BRISC KERNEL"] inference_time_key = "AVG DEVICE KERNEL SAMPLES/S" diff --git a/tests/scripts/run_performance.sh b/tests/scripts/run_performance.sh index a4cf4bad30bc..73d0fe6e5f83 100755 --- a/tests/scripts/run_performance.sh +++ b/tests/scripts/run_performance.sh @@ -113,6 +113,8 @@ run_device_perf_models() { fi if [ "$tt_arch" == "wormhole_b0" ]; then + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yam pytets models/demos/wormhole/mnist/tests -m $test_marker + env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/wormhole/resnet50/tests -m $test_marker env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/experimental/functional_unet/tests/test_unet_perf.py -m $test_marker @@ -123,7 +125,6 @@ run_device_perf_models() { env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml pytest models/demos/falcon7b_common/tests -m $test_marker - env WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yam pytets models/demos/wormhole/mnist/tests/test_perf_mnist.py::test_performance_mnist -m $test_marker fi ## Merge all the generated reports diff --git a/tests/ttnn/integration_tests/mnist/test_mnist.py b/tests/ttnn/integration_tests/mnist/test_mnist_wh.py similarity index 91% rename from tests/ttnn/integration_tests/mnist/test_mnist.py rename to tests/ttnn/integration_tests/mnist/test_mnist_wh.py index 973dcf3a244b..c94eae22c3b7 100644 --- a/tests/ttnn/integration_tests/mnist/test_mnist.py +++ b/tests/ttnn/integration_tests/mnist/test_mnist_wh.py @@ -15,7 +15,7 @@ @skip_for_grayskull() @pytest.mark.parametrize( "batch_size", - [512], + [256], ) def test_mnist(mesh_device, reset_seeds, batch_size, model_location_generator): state_dict = torch.load(model_location_generator("mnist_model.pt", model_subdir="mnist")) @@ -26,10 +26,12 @@ def test_mnist(mesh_device, reset_seeds, batch_size, model_location_generator): dataloader = DataLoader(test_dataset, batch_size=batch_size) x, labels = next(iter(dataloader)) torch_output = model(x) + mesh_device_flag = is_wormhole_b0() and ttnn.GetNumAvailableDevices() == 2 + batch_size = batch_size if mesh_device_flag else batch_size / 2 inputs_mesh_mapper = ttnn.ShardTensorToMesh(mesh_device, dim=0) weights_mesh_mapper = ttnn.ReplicateTensorToMesh(mesh_device) output_mesh_composer = ttnn.ConcatMeshToTensor(mesh_device, dim=0) - mesh_device_flag = True + with ttnn.distribute(ttnn.ReplicateTensorToMesh(mesh_device)): parameters = preprocess_model_parameters(initialize_model=lambda: model, device=mesh_device)