From c507b4d22dba8b50ef78da1a6587eb7c7a9fe07f Mon Sep 17 00:00:00 2001 From: mtairum Date: Fri, 28 Jun 2024 21:01:08 +0000 Subject: [PATCH] #9479: Update Mixtral perf estimates and clean mixtral unit test --- .../demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py | 10 +++++----- tests/scripts/t3000/run_t3000_model_perf_tests.sh | 2 +- tests/scripts/t3000/run_t3000_unit_tests.sh | 3 +-- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py index 3732cb21f0ae..0d9717160338 100644 --- a/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py +++ b/models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py @@ -43,10 +43,10 @@ def forward(self, x): @pytest.mark.parametrize( "generation_start_pos, expected_compile_time, expected_inference_time", ( - (32, 150, 0.058), # FIXME: Perf regression (issue #9479) - (128, 150, 0.058), # FIXME: Perf regression (issue #9479) - (1024, 150, 0.058), # FIXME: Perf regression (issue #9479) - (2048, 150, 0.058), # FIXME: Perf regression (issue #9479) + (32, 150, 0.075), + (128, 150, 0.075), + (1024, 150, 0.075), + (2048, 150, 0.075), ), ) def test_mixtral_model_perf( @@ -61,7 +61,7 @@ def test_mixtral_model_perf( # Can use dummy_weights=True correctness is not tested, but it is much slower model_args = TtModelArgs(t3k_device_mesh.get_device(0), dummy_weights=False) - model_args.n_layers = 1 + model_args.n_layers = 32 # Clear global profiler state before starting measurements profiler.clear() diff --git a/tests/scripts/t3000/run_t3000_model_perf_tests.sh b/tests/scripts/t3000/run_t3000_model_perf_tests.sh index 2cf1dc5dcc4e..6140b9efeafd 100755 --- a/tests/scripts/t3000/run_t3000_model_perf_tests.sh +++ b/tests/scripts/t3000/run_t3000_model_perf_tests.sh @@ -22,7 +22,7 @@ run_t3000_mixtral_tests() { echo "LOG_METAL: Running run_t3000_mixtral_tests" - env pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py::test_mixtral_model_perf[wormhole_b0-True-2048-150-0.058] -m "model_perf_t3000" + env pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py -m "model_perf_t3000" # Record the end time end_time=$(date +%s) diff --git a/tests/scripts/t3000/run_t3000_unit_tests.sh b/tests/scripts/t3000/run_t3000_unit_tests.sh index ea092261a138..a8019137642b 100755 --- a/tests/scripts/t3000/run_t3000_unit_tests.sh +++ b/tests/scripts/t3000/run_t3000_unit_tests.sh @@ -80,7 +80,6 @@ run_t3000_mixtral_tests() { pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_embedding.py pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py - pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py::test_mixtral_model_inference[wormhole_b0-True-1-1-pcc] # Record the end time end_time=$(date +%s) @@ -111,7 +110,7 @@ main() { echo "Script is being sourced, not executing main function" return 0 fi - + if [[ -z "$TT_METAL_HOME" ]]; then echo "Must provide TT_METAL_HOME in environment" 1>&2 exit 1