diff --git a/examples/text-generation/run_generation.py b/examples/text-generation/run_generation.py
index 4f8dd885d..e57ea2fb1 100644
--- a/examples/text-generation/run_generation.py
+++ b/examples/text-generation/run_generation.py
@@ -64,7 +64,17 @@ def main():
         action="store_true",
         help="Whether to perform generation in bf16 precision.",
     )
-    parser.add_argument("--max_new_tokens", type=int, default=100, help="Number of tokens to generate.")
+    length_group = parser.add_mutually_exclusive_group(required=False)
+    length_group.add_argument(
+        "--max_new_tokens",
+        type=int,
+        help="Number of tokens to generate.",
+    )
+    length_group.add_argument(
+        "--max_length",
+        type=int,
+        help="Max number of tokens (prompt + generation).",
+    )
     parser.add_argument(
         "--max_input_tokens",
         type=int,
@@ -211,6 +221,8 @@ def main():
     )
     args = parser.parse_args()
+    if args.max_length is None and args.max_new_tokens is None:
+        args.max_new_tokens = 100

     # If the DeepSpeed launcher is used, the env variable _ will be equal to /usr/local/bin/deepspeed
     # For multi node, the value of the env variable WORLD_SIZE should be larger than 8
@@ -381,7 +393,12 @@ def check_optimum_habana_min_version(*a, **b):

     # Generation configuration
     generation_config = copy.deepcopy(model.generation_config)
-    generation_config.max_new_tokens = args.max_new_tokens
+    if args.max_new_tokens is not None:
+        assert args.max_new_tokens > 0, "max_length is not set, expect a positive number for max_new_tokens"
+        generation_config.max_new_tokens = args.max_new_tokens
+    else:
+        assert args.max_length > 0, "max_new_tokens is not set, expect a positive number for max_length"
+        generation_config.max_length = args.max_length
     generation_config.use_cache = args.use_kv_cache
     generation_config.static_shapes = is_optimized
     generation_config.bucket_size = args.bucket_size if is_optimized else -1
@@ -449,7 +466,7 @@ def generate():
                 profiling_steps=args.profiling_steps,
                 profiling_warmup_steps=args.profiling_warmup_steps,
             ).cpu()
-            return tokenizer.batch_decode(outputs, skip_special_tokens=True)
+            return tokenizer.batch_decode(outputs, skip_special_tokens=True), input_tokens["input_ids"].shape[-1]

         from optimum.habana.utils import HabanaProfile

@@ -471,9 +488,10 @@ def generate():
         t0 = time.perf_counter()
         # Benchmark over n_iterations iterations
         for i in range(args.n_iterations):
-            generated = generate()
+            generated, inp_shape = generate()
+            max_new_tokens = args.max_length - inp_shape if args.max_new_tokens is None else args.max_new_tokens
         duration = time.perf_counter() - t0
-        total_new_tokens_generated = args.n_iterations * args.batch_size * args.max_new_tokens
+        total_new_tokens_generated = args.n_iterations * args.batch_size * max_new_tokens
         throughput = total_new_tokens_generated / duration

         if rank in [-1, 0]:
@@ -602,7 +620,7 @@ def generate_dataset(batch):
                 profiling_steps=args.profiling_steps,
                 profiling_warmup_steps=args.profiling_warmup_steps,
             ).cpu()
-            return prompt, outputs
+            return prompt, outputs, batch["input_ids"].shape[-1]

         # warmup
         if prompt_length > 0:
@@ -630,9 +648,11 @@ def generate_dataset(batch):
         t_start = time.time()
         for i, batch in enumerate(dataloader):
             t0 = time.perf_counter()
-            prompt, outputs = generate_dataset(batch)
+            prompt, outputs, inp_len = generate_dataset(batch)
             duration += time.perf_counter() - t0
-            total_new_tokens_generated += args.batch_size * args.max_new_tokens
+            total_new_tokens_generated += (
+                (args.max_length - inp_len) if args.max_new_tokens is None else args.max_new_tokens
+            )
             if rank in [-1, 0]:
                 print(separator)
print(f"Batch n°{i+1}") diff --git a/optimum/habana/transformers/generation/utils.py b/optimum/habana/transformers/generation/utils.py index d8bab89e7..ba81ad345 100755 --- a/optimum/habana/transformers/generation/utils.py +++ b/optimum/habana/transformers/generation/utils.py @@ -499,6 +499,9 @@ def generate( if generation_config.static_shapes: # Pad inputs to have static shapes during generation, this gives better performance than dynamic shapes on HPUs # In encoder_decoder models, Inputs are already padded + if generation_config.max_new_tokens is None or generation_config.max_new_tokens < 0: + assert generation_config.max_length > 0 + generation_config.max_new_tokens = generation_config.max_length - inputs_tensor.shape[-1] if not self.config.is_encoder_decoder: # only pad if bucket_size < -1. If we are bucketing (bucket_size > 0), then that is taken care in greedy_search()