diff --git a/models/demos/llama3/PERF.md b/models/demos/llama3/PERF.md
index 25835100b2b4..d8e6dbb95ce1 100644
--- a/models/demos/llama3/PERF.md
+++ b/models/demos/llama3/PERF.md
@@ -2,7 +2,11 @@

 Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected from [tests/test_llama_accuracy.py](tests/test_llama_accuracy.py). You can generate this table by running these tests with the `lt` tool (tell it to run `accuracy,demo`) and pressing `m` whilst in the results section to export to markdown.

-4-bit MLP:
+Note that `test_llama_accuracy.py` parses the below to determine expected values.
+
+## LlamaOptimizations.performance
+
+This configuration uses bfp4 MLP FF1+FF3 for all models.

 | Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
 |-------|--------|-----------|-----------|---------------|
@@ -18,7 +22,9 @@ Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected f
 | 11b | N300 | 86 | 97 | 38.6 |
 | 70b | T3K | 95 | 100 | 14.3 |

-Mixed-bit MLP (main):
+## LlamaOptimizations.accuracy
+
+This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model.

 | Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
 |-------|--------|-----------|-----------|---------------|
diff --git a/models/demos/llama3/README.md b/models/demos/llama3/README.md
index 43ad9556ab1a..aff98ce5239e 100644
--- a/models/demos/llama3/README.md
+++ b/models/demos/llama3/README.md
@@ -85,6 +85,13 @@ pytest models/demos/llama3/demo/demo.py -k 'instruct and 1_batch'
 pytest models/demos/llama3/demo/demo.py -k 'general and 2_batch'
 ```

+By default we run the models in `LlamaOptimizations.performance` mode. You can override this by setting the `optimizations` argument in the demo. To compare the two on a long prompt, you can run:
+
+```
+pytest models/demos/llama3/demo/demo.py -k 'long-performance'
+pytest models/demos/llama3/demo/demo.py -k 'long-accuracy'
+```
+
 ### Expected performance and accuracy

-See [PERF.md](PERF.md) for expected performance and accuracy.
+See [PERF.md](PERF.md) for expected performance and accuracy across different configurations.
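
To make the new threshold contract concrete, here is a small sketch (illustrative only, not part of the patch) of how the tables above are consumed by the `get_accuracy_thresholds` helper added to `tests/test_llama_accuracy.py` later in this diff. The numbers come from the 70b / T3K row of the performance table; the helper subtracts 0.5 from each value to allow for rounding:

    from models.demos.llama3.tests.test_llama_accuracy import get_accuracy_thresholds
    from models.demos.llama3.tt.model_config import LlamaOptimizations

    # Looks up the "## LlamaOptimizations.performance" section of PERF.md,
    # finds the "| 70b | T3K | ... |" row, and returns the Top-1/Top-5 values
    # minus the 0.5 rounding allowance (94.5 and 99.5 for the 95 / 100 row above).
    min_top1, min_top5 = get_accuracy_thresholds("3.1-70B", "T3K", LlamaOptimizations.performance)
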
diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py
index 837e03a5dbc8..94fd07d35e5a 100644
--- a/models/demos/llama3/demo/demo.py
+++ b/models/demos/llama3/demo/demo.py
@@ -28,6 +28,7 @@
 from models.perf.benchmarking_utils import BenchmarkProfiler
 from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf
+from models.demos.llama3.tt.model_config import LlamaOptimizations


 def load_and_cache_context(context_url, cache_dir):
@@ -152,7 +153,15 @@ def preprocess_inputs_prefill(


 def run_llama3_demo(
-    user_input, batch_size, single_layer, mesh_device, instruct_mode, is_ci_env, num_batches, print_to_file
+    user_input,
+    batch_size,
+    single_layer,
+    mesh_device,
+    instruct_mode,
+    is_ci_env,
+    num_batches,
+    print_to_file,
+    optimizations,
 ):
     # Creat batch output file
     timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
@@ -189,7 +198,7 @@ def run_llama3_demo(
         batch_prompts.append([input_prompts[(j + i) % len(input_prompts)] for j in range(len(input_prompts))])

     # Load model args, weights, and tokenizer
-    model_args = TtModelArgs(mesh_device, instruct=instruct_mode)
+    model_args = TtModelArgs(mesh_device, instruct=instruct_mode, optimizations=optimizations)
     tokenizer = Tokenizer(model_args.tokenizer_path)

     if single_layer:
@@ -700,21 +709,41 @@ def run_llama3_demo(


 @pytest.mark.parametrize(
-    "input_prompts, instruct_weights, num_batches, single_layer",
+    "input_prompts, instruct_weights, num_batches, single_layer, optimizations",
     [
-        ("models/demos/llama3/demo/input_data_prefill_128.json", False, 1, False),
-        ("models/demos/llama3/demo/input_data_prefill_128.json", False, 2, False),
-        ("models/demos/llama3/demo/input_data_questions_prefill_128.json", True, 1, False),
-        ("models/demos/llama3/demo/input_data_questions_prefill_128.json", True, 2, False),
-        ("models/demos/llama3/demo/input_data_long.json", True, 1, False),
-        ("models/demos/llama3/demo/input_data_questions_prefill_128.json", True, 1, True),
+        ("models/demos/llama3/demo/input_data_prefill_128.json", False, 1, False, LlamaOptimizations.performance),
+        ("models/demos/llama3/demo/input_data_prefill_128.json", False, 2, False, LlamaOptimizations.performance),
+        (
+            "models/demos/llama3/demo/input_data_questions_prefill_128.json",
+            True,
+            1,
+            False,
+            LlamaOptimizations.performance,
+        ),
+        (
+            "models/demos/llama3/demo/input_data_questions_prefill_128.json",
+            True,
+            2,
+            False,
+            LlamaOptimizations.performance,
+        ),
+        ("models/demos/llama3/demo/input_data_long.json", True, 1, False, LlamaOptimizations.performance),
+        ("models/demos/llama3/demo/input_data_long.json", True, 1, False, LlamaOptimizations.accuracy),
+        (
+            "models/demos/llama3/demo/input_data_questions_prefill_128.json",
+            True,
+            1,
+            True,
+            LlamaOptimizations.performance,
+        ),
     ],
     ids=[
         "general_weights-1_batch",
         "general_weights-2_batch",
         "instruct_weights-1_batch",
         "instruct_weights-2_batch",
-        "instruct_weights-long",
+        "instruct_weights-long-performance",
+        "instruct_weights-long-accuracy",
         "single_layer",
     ],
 )
@@ -729,7 +758,15 @@ def run_llama3_demo(
     indirect=True,
 )
 def test_llama_demo(
-    mesh_device, use_program_cache, input_prompts, instruct_weights, is_ci_env, num_batches, single_layer, reset_seeds
+    mesh_device,
+    use_program_cache,
+    input_prompts,
+    instruct_weights,
+    is_ci_env,
+    num_batches,
+    single_layer,
+    optimizations,
+    reset_seeds,
 ):
     if is_ci_env and (instruct_weights == False or "long" in input_prompts or single_layer == True):
         pytest.skip("CI demo test only runs instruct weights to reduce CI pipeline load (both are supported)")
@@ -745,4 +782,5 @@ def test_llama_demo(
         is_ci_env=is_ci_env,
         num_batches=num_batches,
         print_to_file=False,
+        optimizations=optimizations,
     )
diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt
index 75e04befb9ce..23a8fb33f150 100755
--- a/models/demos/llama3/lt
+++ b/models/demos/llama3/lt
@@ -733,9 +733,9 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update):
         "decoder": "pytest models/demos/llama3/tests/test_llama_decoder.py",
         "decoder-prefill": "pytest models/demos/llama3/tests/test_llama_decoder_prefill.py",
         "lm-head": "pytest models/demos/llama3/tests/test_lm_head.py",
-        "model": "pytest models/demos/llama3/tests/test_llama_model.py -k full",
-        "model-quick": "pytest models/demos/llama3/tests/test_llama_model.py -k quick",
-        "model-prefill": "pytest models/demos/llama3/tests/test_llama_model_prefill.py",
+        "model": "pytest models/demos/llama3/tests/test_llama_model.py -k performance-full",
+        "model-quick": "pytest models/demos/llama3/tests/test_llama_model.py -k performance-quick",
+        "model-prefill": "pytest models/demos/llama3/tests/test_llama_model_prefill.py -k performance",
         "vision-mlp": "pytest models/demos/llama3/tests/multimodal/test_llama_image_mlp.py",
         "vision-attn": "pytest models/demos/llama3/tests/multimodal/test_llama_image_attention.py",
         "vision-block": "pytest models/demos/llama3/tests/multimodal/test_llama_image_block.py",
@@ -750,7 +750,7 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update):
         "vision-text-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py",
         "vision-vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py",
         "perf": "pytest models/demos/llama3/tests/test_llama_perf.py -k 1024",
-        "accuracy": "pytest models/demos/llama3/tests/test_llama_accuracy.py",
+        "accuracy": "pytest models/demos/llama3/tests/test_llama_accuracy.py -k performance",
     }

     # Check if the command is a shortcut and replace it if necessary
diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py
index acdcc2579016..e5f81a8840e4 100644
--- a/models/demos/llama3/tests/test_llama_accuracy.py
+++ b/models/demos/llama3/tests/test_llama_accuracy.py
@@ -14,9 +14,48 @@
     HostEmbedding,
 )
 from models.demos.llama3.tt.llama_model import TtTransformer
-from models.demos.llama3.tt.model_config import TtModelArgs
+from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
 from models.demos.llama3.demo.demo import preprocess_inputs_prefill
+from pathlib import Path
+
+
+def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: LlamaOptimizations):
+    """Parse accuracy thresholds from PERF.md for the given model, optimization mode, and device."""
+    # Get model size (e.g., "1b", "3b", etc.)
+    model_size = model_name.split("-")[1].lower()
+
+    # Read PERF.md
+    perf_file = Path(__file__).parent.parent / "PERF.md"
+    with open(perf_file, "r") as f:
+        content = f.read()
+
+    # Split into sections based on optimization mode
+    sections = content.split("## ")
+    target_section = next(s for s in sections if s.startswith(f"LlamaOptimizations.{optimizations.__name__}\n"))
+
+    print(target_section)
+
+    # Parse the table and find the row for our model and device
+    rows = [
+        line.split("|")[1:]  # Each row starts with a separator
+        for line in target_section.split("\n")
+        if f"| {model_size} | {device_name} |" in line
+    ]
+    if not rows:
+        raise ValueError(
+            f"Could not find accuracy data for {model_size} on {device_name} in {optimizations.__name__} mode"
+        )
+
+    assert (
+        len(rows) == 1
+    ), f"Found multiple rows for {model_size} on {device_name} in {optimizations.__name__} mode in PERF.md"
+    row = rows[0]
+    top1_acc = float(row[2].strip())
+    top5_acc = float(row[3].strip())
+
+    # Allow for rounding
+    return top1_acc - 0.5, top5_acc - 0.5


 @torch.no_grad()
@@ -32,17 +71,29 @@
     ],
     indirect=True,
 )
-def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cache, reset_seeds):
+@pytest.mark.parametrize(
+    "optimizations",
+    [
+        pytest.param(LlamaOptimizations.accuracy, id="accuracy"),
+        pytest.param(LlamaOptimizations.performance, id="performance"),
+    ],
+)
+def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cache, reset_seeds, optimizations):
     dtype = ttnn.bfloat8_b
-    min_top1_acc = 75
-    min_top5_acc = 96

     mesh_device.enable_async(True)

     # Load model args and tokenizer
-    model_args = TtModelArgs(mesh_device)
+    model_args = TtModelArgs(mesh_device, optimizations=optimizations)
     tokenizer = Tokenizer(model_args.tokenizer_path)

+    # Get accuracy thresholds from PERF.md
+    min_top1_acc, min_top5_acc = get_accuracy_thresholds(
+        model_args.model_name,
+        model_args.device_name,
+        optimizations,
+    )
+
     # Load state_dict for TT model
     logger.info("Loading weights...")
     state_dict = model_args.load_state_dict()
diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py
index 803381ffce34..45dea505c240 100644
--- a/models/demos/llama3/tests/test_llama_model.py
+++ b/models/demos/llama3/tests/test_llama_model.py
@@ -14,7 +14,7 @@
     HostEmbedding,
 )
 from models.demos.llama3.tt.llama_model import TtTransformer
-from models.demos.llama3.tt.model_config import TtModelArgs
+from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
 from models.utility_functions import (
@@ -36,6 +36,13 @@
     ],
     ids=["quick", "full"],
 )
+@pytest.mark.parametrize(
+    "optimizations",
+    [
+        pytest.param(LlamaOptimizations.accuracy, id="accuracy"),
+        pytest.param(LlamaOptimizations.performance, id="performance"),
+    ],
+)
 @pytest.mark.parametrize(
     "mesh_device",
     [
@@ -45,20 +52,16 @@
     ],
     indirect=True,
 )
-def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, reset_seeds, ensure_gc):
+def test_llama_model_inference(mesh_device, weights, layers, optimizations, use_program_cache, reset_seeds, ensure_gc):
+    mesh_device.enable_async(True)
+
     run_ref_pt = True  # Flag to run reference PyTorch model and compare PCC
     cache_pcc = layers == 1  # Flag to measure KV cache PCC. Avoid running for all layers to speed up test time.

-    dtype = ttnn.bfloat8_b
-
-    mesh_device.enable_async(True)
-
-    # This sets the minimum PCC for each iteration
-    pcc = 0.88 if layers == 1 else 0.94  # TODO For model test quick (1 layer) one iteration might get a worse PCC
-
+    mode_accuracy = optimizations == LlamaOptimizations.accuracy
     instruct = True if weights == "instruct" else False
     dummy_weights = True if weights == "random" else False
-    model_args = TtModelArgs(mesh_device, instruct=instruct, dummy_weights=dummy_weights)
+    model_args = TtModelArgs(mesh_device, instruct=instruct, dummy_weights=dummy_weights, optimizations=optimizations)

     model_name = {
         (16, False): "llama32_1b",
@@ -68,12 +71,19 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
         (80, False): "llama31_70b",
     }[(model_args.n_layers, model_args.is_vision())]

+    # Define minimum PCC for each iteration
+    if layers == 1:
+        pcc = 0.88 if mode_accuracy else 0.86
+    else:
+        pcc = 0.94 if mode_accuracy else 0.94
+
+    # Define tight final PCC thresholds for quick mode
     final_model_pcc = {
-        "llama32_1b": 0.9991,
-        "llama32_3b": 0.9989,
-        "llama31_8b": 0.9987,
-        "llama32_11b": 0.9987,
-        "llama31_70b": 0.9843,
+        "llama32_1b": 0.9991 if mode_accuracy else 0.9864,
+        "llama32_3b": 0.9989 if mode_accuracy else 0.9837,
+        "llama31_8b": 0.9987 if mode_accuracy else 0.9850,
+        "llama32_11b": 0.9987 if mode_accuracy else 0.9850,
+        "llama31_70b": 0.9843 if mode_accuracy else 0.9843,
     }[model_name]

     final_k_cache_pcc = {
@@ -90,6 +100,7 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
         "llama32_11b": 0.9996,
         "llama31_70b": 0.9998,
     }[model_name]
+
     quick_iterations = {"llama32_1b": 2, "llama32_3b": 4, "llama31_8b": 6, "llama32_11b": 6, "llama31_70b": 6}[
         model_name
     ]
@@ -156,6 +167,8 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,

     if run_ref_pt:
         all_tests_pass = True
+        final_tests_pass = True
+        kv_cache_tests_pass = True

     seqlen = 1  # Generating one token per user at a time
     batch = model_args.max_batch_size
@@ -230,6 +243,8 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
         if run_ref_pt:
             if layers == 1 and i == iterations - 1:  # On last iteration in the quick test, set a tighter PCC
                 passing, pcc_message = comp_pcc(ref_output, tt_output_torch, final_model_pcc)
+                if not passing:
+                    final_tests_pass = False
             else:
                 passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc)

@@ -282,9 +297,9 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
                    logger.info(f"V cache output: {output_pcc}")

                if does_pass:
-                    logger.info(f"V Cache Passed!")
+                    logger.info(f"KV Cache Passed!")
                else:
-                    logger.warning(f"V Cache Failed! PCC value is lower than {pcc}")
+                    logger.warning(f"KV Cache Failed! PCC value is lower than {pcc}")
                    all_tests_pass = False

     if not dummy_weights:
@@ -297,4 +312,6 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
         logger.info(f"All {generation_length} Llama decode iterations Passed!")
     else:
         logger.warning("One or more iterations of Llama decode had bad PCC")
+        assert final_tests_pass, f"PCC value is lower than {final_model_pcc} for final output. Check Warnings!"
+        assert kv_cache_tests_pass, f"KV Cache PCC value is lower than expected for some of the outputs. Check Warnings!"
         assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!"
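
With the extra `optimizations` parametrization, the model tests now expose `accuracy` and `performance` variants in their test IDs, which is what the updated `lt` shortcuts above select. The variants can also be picked directly with pytest's `-k` filter, for example (example invocations, not part of the patch):

    pytest models/demos/llama3/tests/test_llama_model.py -k 'performance and quick'
    pytest models/demos/llama3/tests/test_llama_model.py -k 'accuracy and full'
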
diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py
index ca48efd8b118..8eaa7af8ed83 100644
--- a/models/demos/llama3/tests/test_llama_model_prefill.py
+++ b/models/demos/llama3/tests/test_llama_model_prefill.py
@@ -15,7 +15,7 @@
     encode_prompt_llama_instruct,
 )
 from models.demos.llama3.tt.llama_model import TtTransformer
-from models.demos.llama3.tt.model_config import TtModelArgs
+from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer, precompute_freqs_cis
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
 from models.utility_functions import (
@@ -46,19 +46,31 @@
     ],
     indirect=True,
 )
-def test_llama_model_inference(mesh_device, seq_len, use_program_cache, reset_seeds, ensure_gc):
+@pytest.mark.parametrize(
+    "optimizations",
+    [
+        pytest.param(LlamaOptimizations.accuracy, id="accuracy"),
+        pytest.param(LlamaOptimizations.performance, id="performance"),
+    ],
+)
+def test_llama_model_inference(mesh_device, seq_len, optimizations, use_program_cache, reset_seeds, ensure_gc):
     run_ref_pt = True  # Flag to run reference PyTorch model and compare PCC
     cache_pcc = False  # Flag to measure KV cache PCC for all layers

     dtype = ttnn.bfloat8_b
-    pcc = 0.91  # TODO Look on improving PCC
+    # This sets the minimum PCC for each iteration based on optimization mode
+    if optimizations == LlamaOptimizations.accuracy:
+        pcc = 0.91  # TODO Look on improving PCC
+    else:  # performance mode
+        assert optimizations == LlamaOptimizations.performance
+        pcc = 0.91

     mesh_device.enable_async(True)

     # Use instruct weights instead of general weights
     instruct = True

-    model_args = TtModelArgs(mesh_device, instruct=instruct, max_batch_size=1)
+    model_args = TtModelArgs(mesh_device, instruct=instruct, max_batch_size=1, optimizations=optimizations)
     tokenizer = Tokenizer(model_args.tokenizer_path)

     logger.info("Loading weights...")
diff --git a/models/demos/llama3/tests/test_llama_perf.py b/models/demos/llama3/tests/test_llama_perf.py
index c2cda7b346cd..a873449fe2c6 100644
--- a/models/demos/llama3/tests/test_llama_perf.py
+++ b/models/demos/llama3/tests/test_llama_perf.py
@@ -15,7 +15,7 @@
 )
 from models.demos.llama3.tt.llama_model import TtTransformer
 from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding
-from models.demos.llama3.tt.model_config import TtModelArgs
+from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations
 from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer

 from models.perf.perf_utils import prep_perf_report
@@ -50,7 +50,7 @@ def test_llama_model_perf(mesh_device, kv_cache_len, expected_compile_time, use_

     mesh_device.enable_async(True)

-    model_args = TtModelArgs(mesh_device)
+    model_args = TtModelArgs(mesh_device, optimizations=LlamaOptimizations.performance)
     tokenizer = Tokenizer(model_args.tokenizer_path)

     if "3.2-1B" in model_args.DEFAULT_CACHE_PATH:
diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py
index a119825cff87..7be102470a54 100644
--- a/models/demos/llama3/tt/llama_mlp.py
+++ b/models/demos/llama3/tt/llama_mlp.py
@@ -39,8 +39,7 @@ def __init__(
             cache_file_name=cache_name(name),
         )

-        # Set to "self.args.is_large_model" for mixed-mode MLP which is slightly more accurate
-        self.four_bit_mlp = True
+        self.four_bit_mlp = args.optimizations.bfp4_mlp

         # Sharded weights
         self.w1 = as_sharded_tensor(
diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py
index aaf0352c8094..407eda7668ab 100644
--- a/models/demos/llama3/tt/model_config.py
+++ b/models/demos/llama3/tt/model_config.py
@@ -21,6 +21,30 @@
 from models.utility_functions import nearest_32
 from pathlib import Path
 from tqdm import tqdm
+from dataclasses import dataclass
+
+
+@dataclass
+class LlamaOptimizations:
+    bfp4_mlp: bool
+    # Future fields will go here:
+    # bfp8_activations: bool
+    # bfp8_layernorm: bool
+    # bfp8_ccl: bool
+
+    @classmethod
+    def accuracy(cls, model_name):
+        """Configuration optimized for accuracy
+        Only 3.1-70B uses bfp4 MLPs in this configuration
+        """
+        return cls(bfp4_mlp=model_name == "3.1-70B")
+
+    @classmethod
+    def performance(cls, model_name):
+        """Configuration optimized for performance
+        All models use bfp4 MLPs in this configuration
+        """
+        return cls(bfp4_mlp=True)


 class TtModelArgs:
@@ -68,12 +92,17 @@ class TtModelArgs:
         "LLAMA3_1_70B_PARAMS": "models/demos/llama3/model_params/Llama3.1-70B-Instruct",
     }

-    def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_size=1):
-        # Add this near the top of the class, with other class attributes
+    def __init__(
+        self,
+        mesh_device,
+        instruct=False,
+        dummy_weights=False,
+        max_batch_size=1,
+        optimizations=LlamaOptimizations.accuracy,
+    ):
         self.num_devices = mesh_device.get_num_devices() if mesh_device else 0
         self.mesh_device = mesh_device
         self.device_name = {0: "CPU", 1: "N150", 2: "N300", 8: "T3K", 32: "TG"}[self.num_devices]
-        self.is_large_model = False
         self.model_name = "Unknown"  # Llama model name will be dependent on the checkpoint directory

         LLAMA_DIR = os.getenv("LLAMA_DIR")
@@ -126,10 +155,14 @@ def __init__(self, mesh_device, instruct=False, dummy_weights=False, max_batch_s
         elif "3.1-70B" in LLAMA_DIR:
             local_params = "LLAMA3_1_70B_PARAMS"
             self.model_name = "3.1-70B"
-            self.is_large_model = True
         else:
             raise ValueError(f"Unsupported LLAMA model: {LLAMA_DIR}")

+        if callable(optimizations):
+            self.optimizations = optimizations(self.model_name)
+        else:
+            self.optimizations = optimizations
+
         # Load model params
         if not dummy_weights:
             self._set_llama_params(self.DEFAULT_CKPT_DIR)
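
Because `TtModelArgs.__init__` now accepts either one of the `LlamaOptimizations` classmethods or an already-constructed instance (the `callable()` branch above), both calling styles work. A minimal usage sketch (illustrative only, assuming an opened `mesh_device`):

    from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations

    # Pass the preset itself; it is resolved against the detected model name,
    # so e.g. 3.1-70B keeps bfp4 MLPs even in accuracy mode.
    model_args = TtModelArgs(mesh_device, optimizations=LlamaOptimizations.accuracy)

    # Or pass a resolved instance to pin the behaviour explicitly.
    model_args = TtModelArgs(mesh_device, optimizations=LlamaOptimizations(bfp4_mlp=False))
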