Skip to content

Commit

Permalink
#9479: Update Mixtral perf estimates and clean mixtral unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
mtairum committed Jun 28, 2024
1 parent d3bebbb commit 085fd28
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 8 deletions.
10 changes: 5 additions & 5 deletions models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ def forward(self, x):
@pytest.mark.parametrize(
"generation_start_pos, expected_compile_time, expected_inference_time",
(
(32, 150, 0.058), # FIXME: Perf regression (issue #9479)
(128, 150, 0.058), # FIXME: Perf regression (issue #9479)
(1024, 150, 0.058), # FIXME: Perf regression (issue #9479)
(2048, 150, 0.058), # FIXME: Perf regression (issue #9479)
(32, 150, 0.069),
(128, 150, 0.069),
(1024, 150, 0.069),
(2048, 150, 0.075),
),
)
def test_mixtral_model_perf(
Expand All @@ -61,7 +61,7 @@ def test_mixtral_model_perf(

# Can use dummy_weights=True correctness is not tested, but it is much slower
model_args = TtModelArgs(t3k_device_mesh.get_device(0), dummy_weights=False)
model_args.n_layers = 1
model_args.n_layers = 32

# Clear global profiler state before starting measurements
profiler.clear()
Expand Down
2 changes: 1 addition & 1 deletion tests/scripts/t3000/run_t3000_model_perf_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ run_t3000_mixtral_tests() {

echo "LOG_METAL: Running run_t3000_mixtral_tests"

env pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py::test_mixtral_model_perf[wormhole_b0-True-2048-150-0.058] -m "model_perf_t3000"
env pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_perf.py -m "model_perf_t3000"

# Record the end time
end_time=$(date +%s)
Expand Down
3 changes: 1 addition & 2 deletions tests/scripts/t3000/run_t3000_unit_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ run_t3000_mixtral_tests() {
pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_embedding.py
pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_moe.py
pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_decoder.py
pytest models/demos/t3000/mixtral8x7b/tests/test_mixtral_model.py::test_mixtral_model_inference[wormhole_b0-True-1-1-pcc]

# Record the end time
end_time=$(date +%s)
Expand Down Expand Up @@ -111,7 +110,7 @@ main() {
echo "Script is being sourced, not executing main function"
return 0
fi

if [[ -z "$TT_METAL_HOME" ]]; then
echo "Must provide TT_METAL_HOME in environment" 1>&2
exit 1
Expand Down

0 comments on commit 085fd28

Please sign in to comment.