Skip to content

Commit

Permalink
#5383: [Falcon7b] Add perf and output verification settings to demo a…
Browse files Browse the repository at this point in the history
…nd add demo to CI

Signed-off-by: Salar Hosseini <[email protected]>
  • Loading branch information
skhorasganiTT committed May 17, 2024
1 parent 607faa1 commit 2e3f221
Show file tree
Hide file tree
Showing 12 changed files with 183 additions and 32 deletions.
98 changes: 83 additions & 15 deletions models/demos/falcon7b/demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,19 @@ def run_falcon_demo_kv(
num_layers=32,
perf_mode=False, # Option to measure perf using max seq length (with invalid outputs)
greedy_sampling=False, # Option to use greedy decoding instead of top-k/p
expected_perf_prefill_decode=None, # Expected perf (t/s) for prefill and decode in perf mode
expected_greedy_output_path=None, # Path for expected outputs for greedy decoding
save_generated_text_path=None, # If provided, save generated text to this path (e.g. set to expected_greedy_output_path to update expected output)
):
assert not (expected_perf_prefill_decode and expected_greedy_output_path), "Cannot verify both perf and output!"
assert not (perf_mode and save_generated_text_path), "Cannot save generated text in perf mode!"
if expected_greedy_output_path is not None:
assert (
not perf_mode and greedy_sampling
), "Output verification only supported for greedy sampling in default mode!"
elif expected_perf_prefill_decode is not None:
assert perf_mode, "Performance verification is only supported for perf mode!"

disable_persistent_kernel_cache()
disable_compilation_reports()

Expand Down Expand Up @@ -402,7 +414,8 @@ def run_falcon_demo_kv(
logger.info("Finished inference prefill stage!")
num_users_generated_prefill = num_users if not perf_mode else (N - N_warmup) * num_devices

generated_ids = torch.concat((prefill_ids[..., :num_input_tokens], output_ids), dim=1)
if not perf_mode:
generated_ids = torch.concat((prefill_ids[..., :num_input_tokens], output_ids), dim=1)

### Inference run decode ###
logger.info("Running inference decode stage...")
Expand All @@ -422,9 +435,12 @@ def run_falcon_demo_kv(
N = max_seq_len - num_input_tokens
N_warmup = 0
else:
N = 15
N_warmup = 5
for output_token_index in range(N):
N = 30
N_warmup = 10
print_per_generated_token = (
expected_greedy_output_path is None
) # print per generated token if not verifying outputs
for output_token_index in range(N) if print_per_generated_token else tqdm(range(N), desc="Generating tokens"):
time_decode_inference_start = time.time()
(
tt_decode_input_ids,
Expand Down Expand Up @@ -473,22 +489,27 @@ def run_falcon_demo_kv(
generated_ids = torch.concat((generated_ids, decode_ids[:num_users]), dim=1)
kv_cache_len += 1

# TODO: Remove if we don't want to print per generated token
os.system("clear")
print_output_prompts(generated_ids, tokenizer, batch_size)
if print_per_generated_token:
os.system("clear")
print_output_prompts(generated_ids, tokenizer, batch_size)

logger.info("Finished inference decode stage!")
num_tokens_generated_decode = global_batch * (output_token_index - N_warmup + 1)
logger.info(f"Total number of tokens generated in decode: {num_tokens_generated_decode}")

if not perf_mode:
print_output_prompts(generated_ids, tokenizer, batch_size)
generated_text = tokenizer.batch_decode(generated_ids.tolist())

if save_generated_text_path is not None:
with open(save_generated_text_path, "w") as f:
json.dump(generated_text, f)
else:
generated_text = None

for device in devices:
device.disable_and_clear_program_cache()

generated_text = tokenizer.batch_decode(generated_ids.tolist())

measurements = {
"preprocessing": profiler.get("tokenizing_inputs"),
"loading_weights": profiler.get("loading_weights"),
Expand All @@ -500,8 +521,11 @@ def run_falcon_demo_kv(
"inference_prefill": time_prefill_inference,
"inference_decode": time_decode_inference,
"inference_total": time_prefill_inference + time_decode_inference,
"inference_throughput_prefill": num_users_generated_prefill / time_prefill_inference,
"inference_throughput_decode": num_tokens_generated_decode / time_decode_inference,
"inference_user_throughput_prefill": num_users_generated_prefill / time_prefill_inference, # users/s
"inference_token_throughput_prefill": num_users_generated_prefill
/ time_prefill_inference
* prefill_ids.shape[1], # tokens/s
"inference_token_throughput_decode": num_tokens_generated_decode / time_decode_inference, # tokens/s
}

logger.info(f"pre processing: {round(measurements['preprocessing'], 5)} s")
Expand All @@ -516,13 +540,57 @@ def run_falcon_demo_kv(
logger.info(f"prefill inference time: {round(measurements['inference_prefill'], 5)} s")
logger.info(f"decode inference time: {round(measurements['inference_decode'], 5)} s")
logger.info(f"total inference time: {round(measurements['inference_total'], 5)} s")
logger.info(f"inference throughput prefill: {round(measurements['inference_throughput_prefill'], 5)} users/s")
logger.info(f"inference throughput prefill: {round(measurements['inference_user_throughput_prefill'], 5)} users/s")
logger.info(
f"inference throughput prefill | seq_len={prefill_ids.shape[1]} : {round(measurements['inference_throughput_prefill']*prefill_ids.shape[1], 5)} tok/s"
f"inference throughput prefill | seq_len={prefill_ids.shape[1]} : {round(measurements['inference_token_throughput_prefill'], 5)} tok/s"
)
logger.info(f"inference throughput decode: {round(measurements['inference_throughput_decode'], 5)} tok/s")
logger.info(f"inference throughput decode: {round(measurements['inference_token_throughput_decode'], 5)} tok/s")
logger.info(
f"inference throughput decode (per user): {round(measurements['inference_throughput_decode']/global_batch, 5)} tok/s/user"
f"inference throughput decode (per user): {round(measurements['inference_token_throughput_decode']/global_batch, 5)} tok/s/user"
)

# Verify output or perf if expected values are provided
perf_prefill_decode = [
measurements["inference_token_throughput_prefill"],
measurements["inference_token_throughput_decode"],
]
verify_output_or_perf(
generated_text, perf_prefill_decode, expected_greedy_output_path, expected_perf_prefill_decode
)

return generated_text, measurements


def verify_output_or_perf(
generated_text, perf_prefill_decode, expected_greedy_output_path, expected_perf_prefill_decode
):
assert expected_perf_prefill_decode is None or expected_greedy_output_path is None
if expected_perf_prefill_decode is not None:
does_pass = True
if perf_prefill_decode[0] < expected_perf_prefill_decode[0]:
does_pass = False
logger.warning(
f"Prefill perf {perf_prefill_decode[0]} is lower than expected {expected_perf_prefill_decode[0]}"
)
if perf_prefill_decode[1] < expected_perf_prefill_decode[1]:
does_pass = False
logger.warning(
f"Decode perf {perf_prefill_decode[1]} is lower than expected {expected_perf_prefill_decode[1]}"
)
if does_pass:
logger.info("Perf Check Passed!")
else:
logger.warning("Perf Check Failed!")
assert (
does_pass
), f"Prefill or decode perf is lower than {expected_perf_prefill_decode}. See earlier warnings for more details."
elif expected_greedy_output_path is not None:
with open(expected_greedy_output_path, "r") as f:
expected_output = json.load(f)
does_pass = generated_text == expected_output
if does_pass:
logger.info("Output Check Passed!")
else:
assert (
does_pass
), f"Generated text does not match expected output! \n\n Generated text:\n {generated_text} \n\n Expected output:\n {expected_output}"
1 change: 0 additions & 1 deletion models/demos/falcon7b/tests/expected_output.json

This file was deleted.

Loading

0 comments on commit 2e3f221

Please sign in to comment.