#0: Add performance configs and update key test PCCs for each
yieldthought committed Nov 28, 2024
1 parent 2ce545a commit 1885691
Showing 10 changed files with 215 additions and 52 deletions.
10 changes: 8 additions & 2 deletions models/demos/llama3/PERF.md
@@ -2,7 +2,11 @@

Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected from [tests/test_llama_accuracy.py](tests/test_llama_accuracy.py). You can generate this table by running these tests with the `lt` tool (tell it to run `accuracy,demo`) and pressing `m` whilst in the results section to export to markdown.

4-bit MLP:
Note that `test_llama_accuracy.py` parses the tables below to determine its expected values.

## LlamaOptimizations.performance

This configuration uses bfp4 MLP FF1+FF3 for all models.

| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|-------|--------|-----------|-----------|---------------|
Expand All @@ -18,7 +22,9 @@ Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected f
| 11b | N300 | 86 | 97 | 38.6 |
| 70b | T3K | 95 | 100 | 14.3 |

Mixed-bit MLP (main):
## LlamaOptimizations.accuracy

This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model.

| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|-------|--------|-----------|-----------|---------------|
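
As the note at the top of this file says, `test_llama_accuracy.py` reads its expected Top-1/Top-5 values from these tables. Below is a minimal, illustrative sketch of that lookup; the implementation this commit actually adds is `get_accuracy_thresholds` in `tests/test_llama_accuracy.py`, shown further down.

```
# Illustrative sketch of reading one row from the PERF.md tables above
# (not the repo implementation).
def lookup_thresholds(perf_md_text: str, section: str, model: str, device: str):
    # Sections are introduced by "## LlamaOptimizations.<mode>" headings.
    table = next(s for s in perf_md_text.split("## ") if s.startswith(section))
    for line in table.splitlines():
        if f"| {model} | {device} |" in line:
            cols = [c.strip() for c in line.split("|")[1:]]
            return float(cols[2]), float(cols[3])  # Top-1 (%) and Top-5 (%) columns
    raise ValueError(f"No row for {model} on {device}")
```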
9 changes: 8 additions & 1 deletion models/demos/llama3/README.md
@@ -85,6 +85,13 @@ pytest models/demos/llama3/demo/demo.py -k 'instruct and 1_batch'
pytest models/demos/llama3/demo/demo.py -k 'general and 2_batch'
```

By default the demo runs the models in `LlamaOptimizations.performance` mode. You can override this by setting the `optimizations` argument when invoking the demo. To compare the two modes on a long prompt, run:

```
pytest models/demos/llama3/demo/demo.py -k 'long-performance'
pytest models/demos/llama3/demo/demo.py -k 'long-accuracy'
```
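
The same switch can be made programmatically when constructing the model arguments; here is a minimal sketch based on the `TtModelArgs` signature this commit introduces (`mesh_device` is assumed to be an already-opened TT mesh device from the existing fixtures):

```
from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations

# Accuracy mode instead of the default performance mode; mesh_device comes from
# the surrounding demo/test harness.
model_args = TtModelArgs(mesh_device, instruct=True, optimizations=LlamaOptimizations.accuracy)
```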

### Expected performance and accuracy

See [PERF.md](PERF.md) for expected performance and accuracy.
See [PERF.md](PERF.md) for expected performance and accuracy across different configurations.
60 changes: 49 additions & 11 deletions models/demos/llama3/demo/demo.py
@@ -28,6 +28,7 @@

from models.perf.benchmarking_utils import BenchmarkProfiler
from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf
from models.demos.llama3.tt.model_config import LlamaOptimizations


def load_and_cache_context(context_url, cache_dir):
@@ -152,7 +153,15 @@ def preprocess_inputs_prefill(


def run_llama3_demo(
user_input, batch_size, single_layer, mesh_device, instruct_mode, is_ci_env, num_batches, print_to_file
user_input,
batch_size,
single_layer,
mesh_device,
instruct_mode,
is_ci_env,
num_batches,
print_to_file,
optimizations,
):
    # Create batch output file
timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
@@ -189,7 +198,7 @@ def run_llama3_demo(
batch_prompts.append([input_prompts[(j + i) % len(input_prompts)] for j in range(len(input_prompts))])

# Load model args, weights, and tokenizer
model_args = TtModelArgs(mesh_device, instruct=instruct_mode)
model_args = TtModelArgs(mesh_device, instruct=instruct_mode, optimizations=optimizations)
tokenizer = Tokenizer(model_args.tokenizer_path)

if single_layer:
@@ -700,21 +709,41 @@ def run_llama3_demo(


@pytest.mark.parametrize(
"input_prompts, instruct_weights, num_batches, single_layer",
"input_prompts, instruct_weights, num_batches, single_layer, optimizations",
[
("models/demos/llama3/demo/input_data_prefill_128.json", False, 1, False),
("models/demos/llama3/demo/input_data_prefill_128.json", False, 2, False),
("models/demos/llama3/demo/input_data_questions_prefill_128.json", True, 1, False),
("models/demos/llama3/demo/input_data_questions_prefill_128.json", True, 2, False),
("models/demos/llama3/demo/input_data_long.json", True, 1, False),
("models/demos/llama3/demo/input_data_questions_prefill_128.json", True, 1, True),
("models/demos/llama3/demo/input_data_prefill_128.json", False, 1, False, LlamaOptimizations.performance),
("models/demos/llama3/demo/input_data_prefill_128.json", False, 2, False, LlamaOptimizations.performance),
(
"models/demos/llama3/demo/input_data_questions_prefill_128.json",
True,
1,
False,
LlamaOptimizations.performance,
),
(
"models/demos/llama3/demo/input_data_questions_prefill_128.json",
True,
2,
False,
LlamaOptimizations.performance,
),
("models/demos/llama3/demo/input_data_long.json", True, 1, False, LlamaOptimizations.performance),
("models/demos/llama3/demo/input_data_long.json", True, 1, False, LlamaOptimizations.accuracy),
(
"models/demos/llama3/demo/input_data_questions_prefill_128.json",
True,
1,
True,
LlamaOptimizations.performance,
),
],
ids=[
"general_weights-1_batch",
"general_weights-2_batch",
"instruct_weights-1_batch",
"instruct_weights-2_batch",
"instruct_weights-long",
"instruct_weights-long-performance",
"instruct_weights-long-accuracy",
"single_layer",
],
)
Expand All @@ -729,7 +758,15 @@ def run_llama3_demo(
indirect=True,
)
def test_llama_demo(
mesh_device, use_program_cache, input_prompts, instruct_weights, is_ci_env, num_batches, single_layer, reset_seeds
mesh_device,
use_program_cache,
input_prompts,
instruct_weights,
is_ci_env,
num_batches,
single_layer,
optimizations,
reset_seeds,
):
if is_ci_env and (instruct_weights == False or "long" in input_prompts or single_layer == True):
pytest.skip("CI demo test only runs instruct weights to reduce CI pipeline load (both are supported)")
Expand All @@ -745,4 +782,5 @@ def test_llama_demo(
is_ci_env=is_ci_env,
num_batches=num_batches,
print_to_file=False,
optimizations=optimizations,
)
8 changes: 4 additions & 4 deletions models/demos/llama3/lt
@@ -733,9 +733,9 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update):
"decoder": "pytest models/demos/llama3/tests/test_llama_decoder.py",
"decoder-prefill": "pytest models/demos/llama3/tests/test_llama_decoder_prefill.py",
"lm-head": "pytest models/demos/llama3/tests/test_lm_head.py",
"model": "pytest models/demos/llama3/tests/test_llama_model.py -k full",
"model-quick": "pytest models/demos/llama3/tests/test_llama_model.py -k quick",
"model-prefill": "pytest models/demos/llama3/tests/test_llama_model_prefill.py",
"model": "pytest models/demos/llama3/tests/test_llama_model.py -k performance-full",
"model-quick": "pytest models/demos/llama3/tests/test_llama_model.py -k performance-quick",
"model-prefill": "pytest models/demos/llama3/tests/test_llama_model_prefill.py -k performance",
"vision-mlp": "pytest models/demos/llama3/tests/multimodal/test_llama_image_mlp.py",
"vision-attn": "pytest models/demos/llama3/tests/multimodal/test_llama_image_attention.py",
"vision-block": "pytest models/demos/llama3/tests/multimodal/test_llama_image_block.py",
Expand All @@ -750,7 +750,7 @@ def run_entry_command(entry, screen_lock, output_entries, screen_needs_update):
"vision-text-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py",
"vision-vision-xfmr": "pytest models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_vision.py",
"perf": "pytest models/demos/llama3/tests/test_llama_perf.py -k 1024",
"accuracy": "pytest models/demos/llama3/tests/test_llama_accuracy.py",
"accuracy": "pytest models/demos/llama3/tests/test_llama_accuracy.py -k performance",
}

# Check if the command is a shortcut and replace it if necessary
61 changes: 56 additions & 5 deletions models/demos/llama3/tests/test_llama_accuracy.py
@@ -14,9 +14,48 @@
HostEmbedding,
)
from models.demos.llama3.tt.llama_model import TtTransformer
from models.demos.llama3.tt.model_config import TtModelArgs
from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
from models.demos.llama3.demo.demo import preprocess_inputs_prefill
from pathlib import Path


def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: LlamaOptimizations):
"""Parse accuracy thresholds from PERF.md for the given model, optimization mode, and device."""
# Get model size (e.g., "1b", "3b", etc.)
model_size = model_name.split("-")[1].lower()

# Read PERF.md
perf_file = Path(__file__).parent.parent / "PERF.md"
with open(perf_file, "r") as f:
content = f.read()

# Split into sections based on optimization mode
sections = content.split("## ")
target_section = next(s for s in sections if s.startswith(f"LlamaOptimizations.{optimizations.__name__}\n"))

print(target_section)

# Parse the table and find the row for our model and device
rows = [
line.split("|")[1:] # Each row starts with a separator
for line in target_section.split("\n")
if f"| {model_size} | {device_name} |" in line
]
if not rows:
raise ValueError(
f"Could not find accuracy data for {model_size} on {device_name} in {optimizations.__name__} mode"
)

assert (
len(rows) == 1
), f"Found multiple rows for {model_size} on {device_name} in {optimizations.__name__} mode in PERF.md"
row = rows[0]
top1_acc = float(row[2].strip())
top5_acc = float(row[3].strip())

# Allow for rounding
return top1_acc - 0.5, top5_acc - 0.5
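# Usage sketch (hypothetical model/device strings; the test below passes
# model_args.model_name and model_args.device_name instead):
#   min_top1, min_top5 = get_accuracy_thresholds("3.2-1B", "N150", LlamaOptimizations.performance)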


@torch.no_grad()
Expand All @@ -32,17 +71,29 @@
],
indirect=True,
)
def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cache, reset_seeds):
@pytest.mark.parametrize(
"optimizations",
[
pytest.param(LlamaOptimizations.accuracy, id="accuracy"),
pytest.param(LlamaOptimizations.performance, id="performance"),
],
)
def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cache, reset_seeds, optimizations):
dtype = ttnn.bfloat8_b
min_top1_acc = 75
min_top5_acc = 96

mesh_device.enable_async(True)

# Load model args and tokenizer
model_args = TtModelArgs(mesh_device)
model_args = TtModelArgs(mesh_device, optimizations=optimizations)
tokenizer = Tokenizer(model_args.tokenizer_path)

# Get accuracy thresholds from PERF.md
min_top1_acc, min_top5_acc = get_accuracy_thresholds(
model_args.model_name,
model_args.device_name,
optimizations,
)

# Load state_dict for TT model
logger.info("Loading weights...")
state_dict = model_args.load_state_dict()
51 changes: 34 additions & 17 deletions models/demos/llama3/tests/test_llama_model.py
@@ -14,7 +14,7 @@
HostEmbedding,
)
from models.demos.llama3.tt.llama_model import TtTransformer
from models.demos.llama3.tt.model_config import TtModelArgs
from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer
from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
from models.utility_functions import (
Expand All @@ -36,6 +36,13 @@
],
ids=["quick", "full"],
)
@pytest.mark.parametrize(
"optimizations",
[
pytest.param(LlamaOptimizations.accuracy, id="accuracy"),
pytest.param(LlamaOptimizations.performance, id="performance"),
],
)
@pytest.mark.parametrize(
"mesh_device",
[
Expand All @@ -45,20 +52,16 @@
],
indirect=True,
)
def test_llama_model_inference(mesh_device, weights, layers, use_program_cache, reset_seeds, ensure_gc):
def test_llama_model_inference(mesh_device, weights, layers, optimizations, use_program_cache, reset_seeds, ensure_gc):
mesh_device.enable_async(True)

run_ref_pt = True # Flag to run reference PyTorch model and compare PCC
cache_pcc = layers == 1 # Flag to measure KV cache PCC. Avoid running for all layers to speed up test time.

dtype = ttnn.bfloat8_b

mesh_device.enable_async(True)

# This sets the minimum PCC for each iteration
pcc = 0.88 if layers == 1 else 0.94 # TODO For model test quick (1 layer) one iteration might get a worse PCC

mode_accuracy = optimizations == LlamaOptimizations.accuracy
instruct = True if weights == "instruct" else False
dummy_weights = True if weights == "random" else False
model_args = TtModelArgs(mesh_device, instruct=instruct, dummy_weights=dummy_weights)
model_args = TtModelArgs(mesh_device, instruct=instruct, dummy_weights=dummy_weights, optimizations=optimizations)

model_name = {
(16, False): "llama32_1b",
Expand All @@ -68,12 +71,19 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
(80, False): "llama31_70b",
}[(model_args.n_layers, model_args.is_vision())]

# Define minimum PCC for each iteration
if layers == 1:
pcc = 0.88 if mode_accuracy else 0.86
else:
pcc = 0.94 if mode_accuracy else 0.94

# Define tight final PCC thresholds for quick mode
final_model_pcc = {
"llama32_1b": 0.9991,
"llama32_3b": 0.9989,
"llama31_8b": 0.9987,
"llama32_11b": 0.9987,
"llama31_70b": 0.9843,
"llama32_1b": 0.9991 if mode_accuracy else 0.9864,
"llama32_3b": 0.9989 if mode_accuracy else 0.9837,
"llama31_8b": 0.9987 if mode_accuracy else 0.9850,
"llama32_11b": 0.9987 if mode_accuracy else 0.9850,
"llama31_70b": 0.9843 if mode_accuracy else 0.9843,
}[model_name]

final_k_cache_pcc = {
Expand All @@ -90,6 +100,7 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
"llama32_11b": 0.9996,
"llama31_70b": 0.9998,
}[model_name]

quick_iterations = {"llama32_1b": 2, "llama32_3b": 4, "llama31_8b": 6, "llama32_11b": 6, "llama31_70b": 6}[
model_name
]
@@ -156,6 +167,8 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,

if run_ref_pt:
all_tests_pass = True
final_tests_pass = True
kv_cache_tests_pass = True

seqlen = 1 # Generating one token per user at a time
batch = model_args.max_batch_size
@@ -230,6 +243,8 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
if run_ref_pt:
if layers == 1 and i == iterations - 1: # On last iteration in the quick test, set a tighter PCC
passing, pcc_message = comp_pcc(ref_output, tt_output_torch, final_model_pcc)
if not passing:
final_tests_pass = False
else:
passing, pcc_message = comp_pcc(ref_output, tt_output_torch, pcc)

@@ -282,9 +297,9 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
logger.info(f"V cache output: {output_pcc}")

if does_pass:
logger.info(f"V Cache Passed!")
logger.info(f"KV Cache Passed!")
else:
logger.warning(f"V Cache Failed! PCC value is lower than {pcc}")
logger.warning(f"KV Cache Failed! PCC value is lower than {pcc}")
all_tests_pass = False

if not dummy_weights:
Expand All @@ -297,4 +312,6 @@ def test_llama_model_inference(mesh_device, weights, layers, use_program_cache,
logger.info(f"All {generation_length} Llama decode iterations Passed!")
else:
logger.warning("One or more iterations of Llama decode had bad PCC")
assert final_tests_pass, f"PCC value is lower than {final_model_pcc} for final output. Check Warnings!"
    assert kv_cache_tests_pass, f"KV Cache PCC value is lower than expected for some of the outputs. Check Warnings!"
assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!"