Skip to content

Commit

Permalink
#0: Add missing row for 11b t3k
Browse files: browse the repository at this point in the history
  • Loading branch information
yieldthought committed Nov 28, 2024
1 parent 1885691 commit 9bdd24e
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
2 changes: 2 additions & 0 deletions models/demos/llama3/PERF.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ This configuration uses bfp4 MLP FF1+FF3 for all models.
| 8b | N300 | 84 | 98 | 38.6 |
| 8b | T3K | 84 | 98 | 52.6 |
| 11b | N300 | 86 | 97 | 38.6 |
| 11b | T3K | 84 | 98 | 52.6 |
| 70b | T3K | 95 | 100 | 14.3 |

## LlamaOptimizations.accuracy
Expand All @@ -38,4 +39,5 @@ This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model.
| 8b | N300 | 90 | 98 | 34.1 |
| 8b | T3K | 88 | 97 | 49.9 |
| 11b | N300 | 90 | 97 | 33.8 |
| 11b | T3K | 88 | 97 | 52.6 |
| 70b | T3K | 95 | 100 | 14.5 |
16 changes: 7 additions & 9 deletions models/demos/llama3/tests/test_llama_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: Ll
sections = content.split("## ")
target_section = next(s for s in sections if s.startswith(f"LlamaOptimizations.{optimizations.__name__}\n"))

print(target_section)

# Parse the table and find the row for our model and device
rows = [
line.split("|")[1:] # Each row starts with a separator
Expand Down Expand Up @@ -87,13 +85,6 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
model_args = TtModelArgs(mesh_device, optimizations=optimizations)
tokenizer = Tokenizer(model_args.tokenizer_path)

# Get accuracy thresholds from PERF.md
min_top1_acc, min_top5_acc = get_accuracy_thresholds(
model_args.model_name,
model_args.device_name,
optimizations,
)

# Load state_dict for TT model
logger.info("Loading weights...")
state_dict = model_args.load_state_dict()
Expand Down Expand Up @@ -320,6 +311,13 @@ def test_tt_model_accuracy(mesh_device, prefill_len, decode_len, use_program_cac
true_word = sanitize(tokenizer.decode([true_token]))
logger.info(f"{error['position']}: {context}[{incorrect}] != [{expected}], true: [{true_word}]")

# Get accuracy thresholds from PERF.md
min_top1_acc, min_top5_acc = get_accuracy_thresholds(
model_args.model_name,
model_args.device_name,
optimizations,
)

logger.info(f"Top-1: {total_top1_acc:.0f}% | Top-5: {total_top5_acc:.0f}%")
assert total_top1_acc > min_top1_acc, f"Top-1 accuracy {total_top1_acc:.1f}% is too low (expected >{min_top1_acc}%)"
assert total_top5_acc > min_top5_acc, f"Top-5 accuracy {total_top5_acc:.1f}% is too low (expected >{min_top5_acc}%)"

0 comments on commit 9bdd24e

Please sign in to comment.