Skip to content

Commit

Permalink
#13329: Run Mnist perf test for 100 iter and assert accuracy check in…
Browse files Browse the repository at this point in the history
… demo
  • Loading branch information
sabira-mcw committed Nov 22, 2024
1 parent 8b54a20 commit 78c4871
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
2 changes: 0 additions & 2 deletions models/demos/mnist/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,3 @@ The demo receives inputs from respective dataset MNIST.
## Additional Information

If you encounter issues when running the model, ensure that device has support for all required operations.

### Owner: [sabira-mcw](https://github.com/sabira-mcw)
9 changes: 8 additions & 1 deletion models/demos/mnist/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def run_demo_dataset(device, batch_size, iterations, model_location_generator):
ttnn_predictions.append(predicted_label[i])
logger.info(f"Iter: {iters} Sample {i}:")
logger.info(f"Expected Label: {dataset_predictions[i]}")
logger.info(f"Predicted Label: {ttnn_predictions[i]}")
logger.info(f"TT Predicted Label: {ttnn_predictions[i]}")

if dataset_predictions[i] == ttnn_predictions[i]:
dataset_ttnn_correct += 1
Expand All @@ -61,6 +61,7 @@ def run_demo_dataset(device, batch_size, iterations, model_location_generator):

accuracy = correct / (batch_size * iterations)
logger.info(f"ImageNet Inference Accuracy for {batch_size}x{iterations} Samples : {accuracy}")
assert accuracy >= 0.96875, f"Expected accuracy : {0.96875} Actual accuracy: {accuracy}"


@pytest.mark.parametrize("device_params", [{"l1_small_size": 32768}], indirect=True)
Expand All @@ -80,3 +81,9 @@ def test_demo_dataset(
iterations=iterations,
model_location_generator=model_location_generator,
)


# 0.92919921875 - 2048
# 0.943359375 - 1024
# 0.96875 - 256, 64, 128
# 0.875 - 8
8 changes: 4 additions & 4 deletions models/demos/mnist/tests/test_perf_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@
def get_expected_times(tt_mnist):
if is_grayskull():
return {
tt_mnist: (3.54, 0.00905),
tt_mnist: (3.54, 0.005),
}[tt_mnist]
elif is_wormhole_b0():
return {
tt_mnist: (8.14, 0.0081),
tt_mnist: (3.89, 0.005),
}[tt_mnist]


Expand Down Expand Up @@ -63,7 +63,7 @@ def test_performance_mnist(device, batch_size, tt_mnist, model_location_generato

test_input = ttnn.from_torch(x, dtype=ttnn.bfloat16, device=device)
durations = []
for _ in range(2):
for _ in range(100):
start = time.time()

ttnn_output = tt_mnist.mnist(
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_perf_device_bare_metal(batch_size, reset_seeds):
num_iterations = 1
margin = 0.03
if is_grayskull():
expected_perf = 588743.96
expected_perf = 653017.5
elif is_wormhole_b0():
expected_perf = 1338730.2

Expand Down

0 comments on commit 78c4871

Please sign in to comment.