Optimize reference output generation (#15695)
Modify the accuracy reference generation code to process in chunks of
1024.
yieldthought authored Dec 4, 2024
1 parent 1f7eccf commit 56bce0a
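
In outline, the optimization replaces the old token-by-token reference loop with prefill-style forward passes over whole chunks, so the reference model is invoked once per 1024 tokens instead of once per token. The sketch below is a minimal illustration of that idea only, not the repository code: `model`, `embed`, and `tokens` stand in for the reference Llama model, its embedding layer, and the encoded prompt, and the chunk-boundary handling is simplified relative to the actual diff further down.

```python
import torch

def reference_top5_chunked(model, embed, tokens, chunk_size=1024):
    """Teacher-forced reference pass in chunks rather than one token at a time.

    tokens: [1, seq_len] tensor of token ids. Positions 0..seq_len-2 predict
    targets 1..seq_len-1, so the loop stops one position short of the end.
    """
    top5_per_position = []
    with torch.no_grad():
        for start in range(0, tokens.shape[1] - 1, chunk_size):
            end = min(start + chunk_size, tokens.shape[1] - 1)
            chunk = tokens[:, start:end]                   # [1, n]
            logits = model(embed(chunk), start_pos=start)  # [1, n, vocab]
            probs = torch.softmax(logits, dim=-1)
            _, top5 = torch.topk(probs, k=5, dim=-1)       # [1, n, 5]
            top5_per_position.append(top5.squeeze(0))
    return torch.cat(top5_per_position, dim=0)             # [seq_len - 1, 5]
```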
Showing 3 changed files with 94 additions and 71 deletions.
14 changes: 7 additions & 7 deletions models/demos/llama3/README.md
@@ -18,13 +18,13 @@ Below is an updated table with max prefill context-length support for our demo.

The main reason a long context length does not fit on device is lack of memory. Any exceptions are marked in the table.

| | N150 | N300 | T3K | TG
---------------|---------------|-----------------|----------------|-------------|
| Llama3.2-1B | 64k tokens | 64k tokens | 64k tokens [1] | TBD |
| Llama3.2-3B | 32k tokens | 32k tokens [1] | 64k tokens [1] | TBD |
| Llama3.1-8B | 16k tokens | 64k tokens | 128k tokens | TBD |
| Llama3.2-11B | 16k tokens | 64k tokens | 128k tokens | TBD |
| Llama3.1-70B | Not supported | Not supported | 32k tokens [2] | 128k tokens |
| | N150 | N300 | T3K | TG |
|--------------|---------------|-----------------|---------------|-------------|
| Llama3.2-1B | 64k tokens | 64k tokens | 64k tokens [1] | TBD |
| Llama3.2-3B | 32k tokens | 64k tokens | 64k tokens [1] | TBD |
| Llama3.1-8B | 16k tokens | 64k tokens | 128k tokens | TBD |
| Llama3.2-11B | 16k tokens | 64k tokens | 128k tokens | TBD |
| Llama3.1-70B | Not supported | Not supported | 64k tokens [2] | 128k tokens |

[1] For these configurations, running context lengths greater than those specified in the table will produce degraded, repetitive output.

115 changes: 56 additions & 59 deletions models/demos/llama3/tests/generate_reference_outputs.py
@@ -56,73 +56,70 @@ def generate_reference_outputs(total_length, output_file):
encoded_tokens = tokenizer.encode(text, bos=True, eos=False)[:total_length]
encoded_tokens_tensor = torch.tensor(encoded_tokens).unsqueeze(0) # Shape [1, seq_len]

print(f"{'Progress':<15}{'Correct':<8}{'Actual':<15}{'Top 5 Predictions':<75}")
print("-" * 113)

# Initialize lists to store results
all_top1_correct = []
all_top5_correct = []
all_top5_tokens = []
segment_top1_correct = []
segment_top5_correct = []
segment_accuracies = []
segment_summaries = []

print(f"{'ETA':<8}{'Progress':<15}{'Correct':<8}{'Actual':<15}{'Top 5 Predictions':<75}")
print("-" * 121)
chunk_size = 1024

start_time = None
with torch.no_grad():
for i in range(total_length):
pt_decode_input = embd(encoded_tokens_tensor[:, i]).view(1, 1, -1)
for chunk_start in range(0, total_length - 1, chunk_size):
chunk_end = min(chunk_start + chunk_size, total_length)
# Get input and target chunks, ensuring they have matching lengths
chunk_tokens = encoded_tokens_tensor[:, chunk_start:chunk_end]
chunk_next_tokens = encoded_tokens[chunk_start + 1 : chunk_end + 1]
actual_chunk_size = min(len(chunk_tokens[0]), len(chunk_next_tokens))

ref_output = reference_model(pt_decode_input, start_pos=i)
# Trim input chunk if needed
chunk_tokens = chunk_tokens[:, :actual_chunk_size]

if i < len(encoded_tokens) - 1:
next_token = encoded_tokens[i + 1]
else:
next_token = torch.argmax(ref_output, dim=-1).item()
# Process chunk
pt_decode_input = embd(chunk_tokens).view(1, actual_chunk_size, -1)
ref_output = reference_model(pt_decode_input, start_pos=chunk_start)

# Compute top-5 predictions for the current token
# Compute top-5 predictions
probs = torch.softmax(ref_output, dim=-1)
top5_probs, top5_indices = torch.topk(probs, k=5, dim=-1)
top5_indices = top5_indices.squeeze()
all_top5_tokens.append(top5_indices)

# Record top1 and top5 correctness
segment_top1_correct.append(top5_indices[0] == next_token)
segment_top5_correct.append(next_token in top5_indices)

sanitize = lambda x: x.replace("\n", "").replace("\r", "").replace("\x0c", "")
actual_token = tokenizer.decode([next_token])
top5_tokens = [tokenizer.decode([t.item()]) for t in top5_indices]
correct = "x" if segment_top1_correct[-1] else ("-" if segment_top5_correct[-1] else " ")
top5_str = " ".join(f"{t:<14}" for t in top5_tokens)
actual_token = sanitize(actual_token)
top5_str = sanitize(top5_str)

# Calculate ETA and progress
if start_time:
elapsed_time = time.time() - start_time
tokens_per_second = i / elapsed_time
remaining_tokens = total_length - 1 - i
eta_seconds = remaining_tokens / tokens_per_second
eta_str = f"{int(eta_seconds // 60):02d}:{int(eta_seconds % 60):02d}"
else:
eta_str = ""
start_time = time.time()

progress_str = f"{i+1}/{total_length}"

print(f"{eta_str:<8}{progress_str:<15}{correct:<8}{actual_token:<15}{top5_str}")

# Calculate and store segment accuracies every 100 tokens or at the end
if (i + 1) % 100 == 0 or i == total_length - 1:
top1_acc = sum(segment_top1_correct) / len(segment_top1_correct) * 100
top5_acc = sum(segment_top5_correct) / len(segment_top5_correct) * 100
segment_accuracies.append((top1_acc, top5_acc))
segment_summaries.append(
f"Tokens {i-len(segment_top1_correct)+1}-{i+1}: Top-1 Accuracy: {top1_acc:.0f} %, Top-5 Accuracy: {top5_acc:.0f} %"
)
segment_top1_correct = []
segment_top5_correct = []

# Convert list to tensor
all_top5_tokens = torch.stack(all_top5_tokens)
_, chunk_top5_tokens = torch.topk(probs, k=5, dim=-1) # Shape: [1, chunk_size, 5]
chunk_top5_tokens = chunk_top5_tokens.squeeze(0) # Shape: [chunk_size, 5]

# Get next tokens tensor, ensuring same length as predictions
chunk_next_tokens_tensor = torch.tensor(chunk_next_tokens[:actual_chunk_size])

# Calculate correctness
chunk_top1_correct = chunk_top5_tokens[:, 0] == chunk_next_tokens_tensor
chunk_top5_correct = torch.any(chunk_top5_tokens == chunk_next_tokens_tensor.unsqueeze(1), dim=1)

# Store results
all_top1_correct.extend(chunk_top1_correct.tolist())
all_top5_correct.extend(chunk_top5_correct.tolist())
all_top5_tokens.append(chunk_top5_tokens)

# Print predictions for this chunk
for i in range(len(chunk_next_tokens)):
global_pos = chunk_start + i
next_token = chunk_next_tokens[i]

sanitize = lambda x: x.replace("\n", "").replace("\r", "").replace("\x0c", "")
actual_token = sanitize(tokenizer.decode([next_token]))
top5_tokens = [sanitize(tokenizer.decode([t.item()])) for t in chunk_top5_tokens[i]]
correct = "x" if chunk_top1_correct[i] else ("-" if chunk_top5_correct[i] else " ")
top5_str = " ".join(f"{t:<14}" for t in top5_tokens)

progress_str = f"{global_pos+1}/{total_length-1}"
print(f"{progress_str:<15}{correct:<8}{actual_token:<15}{top5_str}")

# Calculate and store segment accuracies every 100 tokens
if (global_pos + 1) % 100 == 0 or global_pos == total_length - 2:
start_idx = (global_pos // 100) * 100
end_idx = min(start_idx + 100, len(all_top1_correct))
segment_top1_acc = sum(all_top1_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100
segment_top5_acc = sum(all_top5_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100
if len(segment_accuracies) <= global_pos // 100:
segment_accuracies.append((segment_top1_acc, segment_top5_acc))

# Save the data
data = {
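The per-chunk correctness bookkeeping above compares each position's top-5 predictions against the shifted ground-truth tokens in one vectorized step rather than per token. A toy example with made-up token ids, mirroring the `torch.topk` / `torch.any` pattern in the new loop:

```python
import torch

# Hypothetical values: top-5 predictions for three positions of a chunk,
# shape [chunk_len, 5], and the actual next token at each position.
chunk_top5_tokens = torch.tensor(
    [
        [7, 3, 9, 1, 4],   # position 0: top-1 prediction is 7
        [2, 8, 5, 0, 6],   # position 1: top-1 prediction is 2
        [4, 7, 1, 9, 3],   # position 2: top-1 prediction is 4
    ]
)
chunk_next_tokens_tensor = torch.tensor([7, 5, 0])  # ground-truth next tokens

# Top-1 hit: the highest-probability prediction equals the actual token.
chunk_top1_correct = chunk_top5_tokens[:, 0] == chunk_next_tokens_tensor
# Top-5 hit: the actual token appears anywhere among the five predictions.
chunk_top5_correct = torch.any(
    chunk_top5_tokens == chunk_next_tokens_tensor.unsqueeze(1), dim=1
)

print(chunk_top1_correct.tolist())  # [True, False, False]
print(chunk_top5_correct.tolist())  # [True, True, False]
```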
36 changes: 31 additions & 5 deletions models/demos/llama3/tests/generate_reference_outputs.sh
@@ -1,12 +1,38 @@
#!/bin/bash

# Parse command line arguments
TOTAL_LENGTH=1024 # Default value
while [[ $# -gt 0 ]]; do
case $1 in
--total-length)
TOTAL_LENGTH="$2"
shift 2
;;
--help|-h)
echo "Usage: $0 [OPTIONS]"
echo
echo "Generate reference outputs for Llama models"
echo
echo "Options:"
echo " --total-length N Set the total sequence length (default: 1024)"
echo " --help, -h Show this help message"
exit 0
;;
*)
echo "Unknown option: $1"
echo "Use --help to see available options"
exit 1
;;
esac
done

# Define model directories from environment variables with fallbacks
LLAMA_DIRS=(
# "${LLAMA_32_1B_DIR:-/proj_sw/user_dev/llama32-data/Llama3.2-1B-Instruct}"
# "${LLAMA_32_3B_DIR:-/proj_sw/user_dev/llama32-data/Llama3.2-3B-Instruct}"
# "${LLAMA_31_8B_DIR:-/proj_sw/user_dev/llama31-8b-data/Meta-Llama-3.1-8B-Instruct}"
"${LLAMA_32_1B_DIR:-/proj_sw/user_dev/llama32-data/Llama3.2-1B-Instruct}"
"${LLAMA_32_3B_DIR:-/proj_sw/user_dev/llama32-data/Llama3.2-3B-Instruct}"
"${LLAMA_31_8B_DIR:-/proj_sw/user_dev/llama31-8b-data/Meta-Llama-3.1-8B-Instruct}"
"${LLAMA_32_11B_DIR:-/proj_sw/user_dev/llama32-data/Llama3.2-11B-Vision-Instruct}"
# "${LLAMA_31_70B_DIR:-/proj_sw/llama3_1-weights/Meta-Llama-3.1-70B-Instruct/repacked}"
"${LLAMA_31_70B_DIR:-/proj_sw/llama3_1-weights/Meta-Llama-3.1-70B-Instruct/repacked}"
)

# Create reference_outputs directory if it doesn't exist
@@ -48,7 +74,7 @@ for DIR in "${LLAMA_DIRS[@]}"; do

# Set LLAMA_DIR environment variable and run the Python script
LLAMA_DIR="$DIR" python3 "${SCRIPT_DIR}/generate_reference_outputs.py" \
--total_length 1024 \
--total_length "$TOTAL_LENGTH" \
--output_file "$OUTPUT_FILE"
done

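With this change, the wrapper accepts the sequence length on the command line, e.g. `./generate_reference_outputs.sh --total-length 2048`, and forwards it to the Python script's `--total_length` argument instead of the previously hard-coded 1024.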
