-
Notifications
You must be signed in to change notification settings - Fork 220
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
sarkar/Add support for max_length in run_generation #476
Conversation
The documentation is not available anymore as the PR was closed or merged. |
@regisss @libinta @ankurneog could you please review this PR |
@@ -211,6 +221,8 @@ def main(): | |||
) | |||
|
|||
args = parser.parse_args() | |||
if args.max_length is None and args.max_new_tokens is None: | |||
args.max_new_tokens = 100 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To keep back compatibility. In current case we can not specify --max_new_tokens
and it would ge a default of 100
Since it was optional I have: length_group = parser.add_mutually_exclusive_group(required=False)
However to maintain back-compat if its not specified, I put default=100
Note I cannot put the default value in " length_group.add_argument(", because in the case where --max_length
is specified and --max_new_tokens
isnt, arg parser will assign it default of 100, which will result in both max_length
and max_new_tokens
getting values
Hence teh default is set here, outside argparser
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wouldn't modify the text-generation example and just keep the change in generation.py
. This example doesn't have to enable us to reproduce all use cases.
Then, after that, the test leads to an error with eos_token_id
being None. I don't have a good understanding why yet as it doesn't happen in Transformers and I don't have the bandwidth to look more into it. Let's skip it for the moment.
closing for now |
The change to |
What does this PR do?
semi fix for this issue
Add support for max_length (as an alternative for max_new_tokens). The change in
utils.py
is needed if someone specifiesmax_length
instead ofmax_new_tokens
. The change inrun_generation.py
is to activate and test themax_length
pathThe issue itself is not fixed fully, but it moves past the error mentioned in the issue.
Tests:
runs:
only max_new_tokens is specified
python run_generation.py --model_name_or_path facebook/opt-350m --use_hpu_graphs --use_kv_cache --bf16 --max_new_tokens 100 --batch 8
only max_length is specified
python run_generation.py --model_name_or_path facebook/opt-350m --use_hpu_graphs --use_kv_cache --bf16 --batch 8 --max_length 100
only max_length is specified, with bucket
python run_generation.py --model_name_or_path facebook/opt-350m --use_hpu_graphs --use_kv_cache --bf16 --batch 8 --max_length 100 --bucket_size 30
only max_new_tokens is specified, with bucket
python run_generation.py --model_name_or_path facebook/opt-350m --use_hpu_graphs --use_kv_cache --bf16 --max_new_tokens 100 --batch 8 --bucket_size 30
dataset run with max_length
python run_generation.py --model_name_or_path facebook/opt-350m --use_hpu_graphs --use_kv_cache --bf16 --batch 8 --max_length 100 --bucket_size 30 --dataset_name squad --column context --dataset_max_samples 32
both are not specified (max_new_tokens defaults to 100, old back compat behaviour)
python run_generation.py --model_name_or_path facebook/opt-350m --use_hpu_graphs --use_kv_cache --bf16 --batch 8
(expected) fails
both are specified
python run_generation.py --model_name_or_path facebook/opt-350m --use_hpu_graphs --use_kv_cache --bf16 --max_new_tokens 100 --batch 8 --max_length 108
pytest mentioned in the issue:
cd optimum-habana/tests/transformers/tests/models/gpt2
python -m pytest -vs test_modeling_gpt2.py::GPT2ModelTest::test_beam_search_generate
The original error is gone now. There is new error
"""
if sent_lengths[i] < sent_max_len:
# inserting only the first eos_token_id
E TypeError: 'NoneType' object is not subscriptable
"""
However not sure if beam search is tested/supported/prioritized for GPT2
FWIW, beam search generations with
run_generation.py
with opt/gpt2 using max_length/max_new_tokens pass:python run_generation.py --model_name_or_path facebook/opt-350m --use_hpu_graphs --use_kv_cache --bf16 --max_new_tokens 100 --batch 8 --num_beams 2
python run_generation.py --model_name_or_path gpt2 --use_hpu_graphs --use_kv_cache --bf16 --max_new_tokens 100 --batch 8 --num_beams 2
python run_generation.py --model_name_or_path facebook/opt-350m --use_hpu_graphs --use_kv_cache --bf16 --max_length 100 --batch 8 --num_beams 2
python run_generation.py --model_name_or_path gpt2 --use_hpu_graphs --use_kv_cache --bf16 --max_length 100 --batch 8 --num_beams 2
Fixes # (472)
Before submitting