Skip to content

Commit

Permalink
Remove "generation_length" in favor of "generation_kwargs" (#3014)
Browse files Browse the repository at this point in the history
* kill generation_length

* fix tests

* fix test

* add deprecation warning

* fix test

* add gen_len back into static_keys

* simplify setting variable in forward and add test

* simply test

* trailing comma

* trailing comma

* linting

---------

Co-authored-by: Daniel King <[email protected]>
  • Loading branch information
2 people authored and Chuck Tang committed May 16, 2024
1 parent 2c4dfb9 commit 5be0eed
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 17 deletions.
5 changes: 2 additions & 3 deletions composer/datasets/in_context_learning_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,10 +715,10 @@ def __init__(
'mode': 'generate',
'labels': [],
'cot_delimiter': self.cot_delimiter,
'generation_length': self.max_answer_length,
'stopping_criteria': early_stopping_criteria,
'do_normalization': do_normalization,
'generation_kwargs': {
'max_new_tokens': self.max_answer_length,
'pad_token_id': self.pad_tok_id,
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id,
Expand Down Expand Up @@ -1260,7 +1260,6 @@ class InContextLearningCodeEvalDataset(InContextLearningDataset):
- test_outputs: List of test outputs
- languages: List of languages
- pass_at_k: Passed value for pass_at_k
- generation_length: Derived maximum generation length
- generation_kwargs: Dictionary of kwargs needed for generation. Includes the following, which will be individually overwritten
by keys in generation_kwargs if set (see https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
for more details):
Expand Down Expand Up @@ -1349,14 +1348,14 @@ def __init__(
'test_outputs': [],
'languages': [],
'pass_at_k': pass_at_k,
'generation_length': min(self.max_answer_length, self.max_seq_len - self.max_prompt_length),
'generation_kwargs': {
'pad_token_id': self.pad_tok_id,
'num_beams': 1, # single beam
'do_sample': True,
'temperature': 0.2, # good default for code
'use_cache': True,
'eos_token_id': self.tokenizer.eos_token_id,
'max_new_tokens': min(self.max_answer_length, self.max_seq_len - self.max_prompt_length),
},
'sample_id': [],
'pass_at_k': list(pass_at_k),
Expand Down
12 changes: 11 additions & 1 deletion composer/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,20 @@ def eval_forward(self, batch, outputs: Optional[Any] = None):
raise ValueError(
'Generation eval cannot be used without providing a tokenizer to the model constructor.')

# Backward-compatibility shim: a batch may still carry the deprecated
# top-level `generation_length` key. Warn, then forward its value as
# `max_new_tokens` inside `generation_kwargs` (overwriting any value already
# there) so that generation settings are read from `generation_kwargs`.
if 'generation_length' in batch:
    warnings.warn(
        # Adjacent literals are concatenated verbatim, so the first one must
        # end with a space to keep the two sentences separated.
        ('`generation_length` has been deprecated in favor of passing `max_new_tokens` directly into `generation_kwargs`. '
         'It will be removed in v0.21'),
        DeprecationWarning,
    )
    if 'generation_kwargs' in batch:
        batch['generation_kwargs']['max_new_tokens'] = batch['generation_length']
    else:
        batch['generation_kwargs'] = {'max_new_tokens': batch['generation_length']}

self.labels = batch.pop('labels')
generation = self.generate(batch['input_ids'],
attention_mask=batch['attention_mask'],
max_new_tokens=batch['generation_length'],
synced_gpus=dist.get_world_size() > 1,
**batch.get('generation_kwargs', {}))

Expand Down
25 changes: 14 additions & 11 deletions tests/datasets/test_in_context_learning_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ def test_update_generation_kwargs_no_kwargs_qa_dataset(tmp_path):
continuation_delimiter=': ',
destination_path=str(Path(gathered_paths[0]) / 'icl.jsonl'),
generation_kwargs=None)
assert len(dl.base_batch['generation_kwargs']) == 3
assert len(dl.base_batch['generation_kwargs']) == 4


def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path):
Expand All @@ -321,7 +321,7 @@ def test_update_generation_kwargs_with_kwargs_qa_dataset(tmp_path):
generation_kwargs={'temperature': 0.9})
assert 'generation_kwargs' in dl.base_batch
assert dl.base_batch['generation_kwargs']['temperature'] == 0.9
assert len(dl.base_batch['generation_kwargs']) == 4
assert len(dl.base_batch['generation_kwargs']) == 5


@pytest.mark.filterwarnings(
Expand Down Expand Up @@ -1255,8 +1255,8 @@ def test_qa_split_batch(tiny_opt_tokenizer, dataset_uri, tmp_path):
assert len(split2['labels']) == 1
assert all(isinstance(v, list) for v in split1['labels'] + split2['labels'])

assert isinstance(split1['generation_length'], int)
assert isinstance(split2['generation_length'], int)
assert isinstance(split1['generation_kwargs']['max_new_tokens'], int)
assert isinstance(split2['generation_kwargs']['max_new_tokens'], int)

assert isinstance(split1['generation_kwargs'], dict)
assert isinstance(split2['generation_kwargs'], dict)
Expand Down Expand Up @@ -1326,7 +1326,7 @@ def test_qa_task_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path, num_fews
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data

assert batch['generation_length'] == maximum_answer_length
assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length
assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])

decoded_batch = tokenizer.batch_decode(batch['input_ids'])
Expand Down Expand Up @@ -1376,7 +1376,7 @@ def test_qa_task_with_cot_dataloader(dataset_uri, tiny_gpt2_tokenizer, tmp_path,
assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == maximum_answer_length
assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length
assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])
decoded_batch = tokenizer.batch_decode(batch['input_ids'])
assert all(item.count('Q: ') == num_fewshot + 1 for item in decoded_batch)
Expand Down Expand Up @@ -1491,8 +1491,11 @@ def test_code_eval_split_batch(dataset_uri, tmp_path):
assert len(batch[field]) == size
assert all(isinstance(val, type_) for val in batch[field])

static_keys = {'pass_at_k': (int, list), 'generation_length': int, 'generation_kwargs': dict}
static_keys = {'pass_at_k': (int, list), 'generation_kwargs': dict}
for batch in batches:
assert 'generation_kwargs' in batch
assert 'max_new_tokens' in batch['generation_kwargs']
assert isinstance(batch['generation_kwargs']['max_new_tokens'], int)
for field, type_ in static_keys.items():
assert isinstance(batch[field], type_)

Expand Down Expand Up @@ -1544,7 +1547,7 @@ def test_code_eval_sentpiece_dataloader(dataset_uri, tmp_path, num_fewshot, prom
assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == 129
assert batch['generation_kwargs']['max_new_tokens'] == 129
has_left_padding.extend([item[0] == tokenizer.eos_token_id for item in batch['input_ids']])
assert not all(has_left_padding) # longest should be pushed left

Expand Down Expand Up @@ -1613,7 +1616,7 @@ def test_code_eval_test_cases(dataset_uri, tmp_path, tiny_llama_tokenizer):
assert tuple(batch['attention_mask'].shape) == (batch_size, max_prompt_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == 129
assert batch['generation_kwargs']['max_new_tokens'] == 129
assert any(item[0] != tokenizer.eos_token_id for item in batch['input_ids']) # longest should be pushed left

mod = types.ModuleType('test_module')
Expand Down Expand Up @@ -1703,7 +1706,7 @@ def test_code_eval_task_dataloader(dataset_uri, tmp_path, num_fewshot, prompt_st
assert tuple(batch['attention_mask'].shape) == (bs, max_prompt_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == 122
assert batch['generation_kwargs']['max_new_tokens'] == 122
has_left_padding.extend([item[0] == tokenizer.eos_token_id for item in batch['input_ids']])
assert not all(has_left_padding) # longest should be pushed left

Expand Down Expand Up @@ -2459,7 +2462,7 @@ def test_hf_dataloading_custom_parsing(dataset_uri, tiny_gpt2_tokenizer, tmp_pat
assert tuple(batch['attention_mask'].shape) == (batch_size, seqlen - maximum_answer_length)
assert batch['mode'] == 'generate'
# the maximum generation length from the small test data
assert batch['generation_length'] == maximum_answer_length
assert batch['generation_kwargs']['max_new_tokens'] == maximum_answer_length
assert all(item[0] == tokenizer.eos_token_id for item in batch['input_ids'])

decoded_batch = tokenizer.batch_decode(batch['input_ids'])
Expand Down
25 changes: 23 additions & 2 deletions tests/models/test_hf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,11 +1195,12 @@ def test_eval_forward_generate(device, world_size, hf_model, hf_tokenizer, use_f
for k, v in input_dict.items():
input_dict[k] = device.tensor_to_device(v)
input_dict['mode'] = 'generate'
input_dict['generation_kwargs'] = {}

input_dict['generation_length'] = 5
input_dict['generation_kwargs']['max_new_tokens'] = 5
input_dict['labels'] = [['answer1'], ['answer2']]
generation1 = model.eval_forward(input_dict, None)
input_dict['generation_length'] = 3
input_dict['generation_kwargs']['max_new_tokens'] = 3
input_dict['labels'] = [['answer1'], ['answer2']]
generation2 = model.eval_forward(input_dict, None)

Expand All @@ -1208,6 +1209,26 @@ def test_eval_forward_generate(device, world_size, hf_model, hf_tokenizer, use_f
assert all(isinstance(decoded_generation, str) for decoded_generation in generation2)


def test_eval_forward_generate_adjust_generation_length(tiny_gpt2_model, tiny_gpt2_tokenizer):
    """Check that the deprecated ``generation_length`` batch key is still honored.

    ``eval_forward`` is called twice in ``'generate'`` mode with different
    ``generation_length`` values. Each call must emit a ``DeprecationWarning``
    and copy the value into ``batch['generation_kwargs']['max_new_tokens']``.
    """
    model = HuggingFaceModel(tiny_gpt2_model, tokenizer=tiny_gpt2_tokenizer, use_logits=True)
    input_dict = tiny_gpt2_tokenizer(['hello', 'goodbyes'], return_tensors='pt', padding=True)

    input_dict['mode'] = 'generate'
    input_dict['generation_kwargs'] = {}
    input_dict['generation_length'] = 5
    # eval_forward pops 'labels' from the batch, so they are set before each call.
    input_dict['labels'] = [['answer1'], ['answer2']]
    with pytest.warns(DeprecationWarning):
        generation1 = model.eval_forward(input_dict, None)
    # The deprecated key must have been forwarded into generation_kwargs.
    assert input_dict['generation_kwargs']['max_new_tokens'] == 5

    input_dict['generation_length'] = 3
    input_dict['labels'] = [['answer1'], ['answer2']]
    # The second call uses the deprecated key too, so capture its warning as
    # well instead of letting it escape the test.
    with pytest.warns(DeprecationWarning):
        generation2 = model.eval_forward(input_dict, None)
    # A changed generation_length must overwrite the previously forwarded value.
    assert input_dict['generation_kwargs']['max_new_tokens'] == 3

    assert len(generation1) == len(generation2) == 2
    assert all(isinstance(decoded_generation, str) for decoded_generation in generation1)
    assert all(isinstance(decoded_generation, str) for decoded_generation in generation2)


@pytest.mark.parametrize('peft_type', ['LORA', 'loRa'])
@pytest.mark.parametrize('task_type', ['CAUSAL_LM', 'causal_lm'])
def test_peft_init(peft_type: str, task_type: str, tiny_gpt2_model, gpt2_peft_config):
Expand Down

0 comments on commit 5be0eed

Please sign in to comment.