diff --git a/models/demos/llama3/PERF.md b/models/demos/llama3/PERF.md
index d8e6dbb95ce1..dd060a14c1c7 100644
--- a/models/demos/llama3/PERF.md
+++ b/models/demos/llama3/PERF.md
@@ -20,6 +20,7 @@ This configuration uses bfp4 MLP FF1+FF3 for all models.
 | 8b | N300 | 84 | 98 | 38.6 |
 | 8b | T3K | 84 | 98 | 52.6 |
 | 11b | N300 | 86 | 97 | 38.6 |
+| 11b | T3K | 84 | 98 | 52.6 |
 | 70b | T3K | 95 | 100 | 14.3 |
 
 ## LlamaOptimizations.accuracy
@@ -38,4 +39,5 @@ This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model.
 | 8b | N300 | 90 | 98 | 34.1 |
 | 8b | T3K | 88 | 97 | 49.9 |
 | 11b | N300 | 90 | 97 | 33.8 |
+| 11b | T3K | 88 | 97 | 52.6 |
 | 70b | T3K | 95 | 100 | 14.5 |
diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py
index e5f81a8840e4..8ce5648a8a82 100644
--- a/models/demos/llama3/tests/test_llama_accuracy.py
+++ b/models/demos/llama3/tests/test_llama_accuracy.py
@@ -34,8 +34,6 @@ def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: Ll
     sections = content.split("## ")
     target_section = next(s for s in sections if s.startswith(f"LlamaOptimizations.{optimizations.__name__}\n"))
 
-    print(target_section)
-
     # Parse the table and find the row for our model and device
     rows = [
         line.split("|")[1:]  # Each row starts with a separator
@@ -87,13 +85,6 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
     model_args = TtModelArgs(mesh_device, optimizations=optimizations)
     tokenizer = Tokenizer(model_args.tokenizer_path)
 
-    # Get accuracy thresholds from PERF.md
-    min_top1_acc, min_top5_acc = get_accuracy_thresholds(
-        model_args.model_name,
-        model_args.device_name,
-        optimizations,
-    )
-
     # Load state_dict for TT model
     logger.info("Loading weights...")
     state_dict = model_args.load_state_dict()
@@ -320,6 +311,13 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
             true_word = sanitize(tokenizer.decode([true_token]))
             logger.info(f"{error['position']}: {context}[{incorrect}] != [{expected}], true: [{true_word}]")
 
+    # Get accuracy thresholds from PERF.md
+    min_top1_acc, min_top5_acc = get_accuracy_thresholds(
+        model_args.model_name,
+        model_args.device_name,
+        optimizations,
+    )
+
     logger.info(f"Top-1: {total_top1_acc:.0f}% | Top-5: {total_top5_acc:.0f}%")
     assert total_top1_acc > min_top1_acc, f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >{min_top1_acc}%)"
     assert total_top5_acc > min_top5_acc, f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >{min_top5_acc}%)"