
Commit

update
jdf-prog committed Dec 7, 2023
1 parent 2d156f4 commit 14050c6
Showing 1 changed file with 63 additions and 17 deletions.
llm_blender/blender/blender.py: 80 changes (63 additions & 17 deletions)
@@ -490,8 +490,8 @@ def fuse(
         _generations = self.fuser_tokenizer.batch_decode(output_ids, skip_special_tokens=True)
         generations.extend(_generations)
         return generations
 
-    def best_of_n_generate(
+    def n_generate(
         self,
         model, # Union[transformers.PreTrainedModel, vllm.LLM]
         model_tokenizer:transformers.PreTrainedTokenizer,
@@ -500,14 +500,9 @@ def best_of_n_generate(
         n:int=5,
         sampling_mode:str="top_p_sampling",
         batch_size:int=4,
-        pairrm_cmp_type:str="bubble",
-        return_all:bool=False,
         **generate_kwargs:dict,
     ):
-        """Decoding-enhanced generation.
-        In this process, we generate multiple candidates for each input,
-        then rank them and return only the top candidates,
-        thus enhancing the quality of generations.
+        """Generate n candidate generations for each input.
         Args:
             model: Union[transformers.PreTrainedModel, vllm.LLM]
@@ -525,12 +520,6 @@ def best_of_n_generate(
                 if None, will use the custom sampling strategy specified by generate_kwargs
             batch_size int:
                 batch size for generation
-            pairrm_cmp_type str: one of ['bubble', 'full']
-                - 'bubble': use a single pass of bubble sort to get the best of n, for quicker speed. Time complexity: O(n)
-                - 'full': use the full pairwise comparison matrix to get the best of n. Time complexity: O(n^2)
-            return_all bool:
-                If True, will return all candidates instead of only the best of n;
-                the returned candidates are sorted by the ranker, with the best candidate first
             generate_kwargs:
                 kwargs for model.generate()
                 recommended kwargs:
@@ -541,9 +530,8 @@ def best_of_n_generate(
                 Note that num_return_sequences will be set to n, so you don't need to specify it
         Returns:
-            best_candidates
-                - List[str]: the best candidate according to the ranker, for each input
-                - List[List[str]]: all candidates sorted by the ranker for each input, when return_all is True
+            sampled_candidates
+                - List[List[str]]: the n sampled candidates for each input
         """
         assert len(inputs) == len(instructions) if instructions is not None else True, "Number of inputs and instructions must be the same if instructions is not None"
         if sampling_mode == "top_k_sampling":
@@ -595,6 +583,64 @@
             bz_outputs = model_tokenizer.batch_decode(bz_output_ids, skip_special_tokens=True)
             bz_sampled_candidates = [bz_outputs[i: i+n] for i in range(0, len(bz_outputs), n)]
             sampled_candidates.extend(bz_sampled_candidates)
+        return sampled_candidates
+
+    def best_of_n_generate(
+        self,
+        model, # Union[transformers.PreTrainedModel, vllm.LLM]
+        model_tokenizer:transformers.PreTrainedTokenizer,
+        inputs:List[str],
+        instructions:List[str]=None,
+        n:int=5,
+        sampling_mode:str="top_p_sampling",
+        batch_size:int=4,
+        pairrm_cmp_type:str="bubble",
+        return_all:bool=False,
+        **generate_kwargs:dict,
+    ):
+        """Decoding-enhanced generation.
+        In this process, we generate multiple candidates for each input,
+        then rank them and return only the top candidates,
+        thus enhancing the quality of generations.
+        Args:
+            model: Union[transformers.PreTrainedModel, vllm.LLM]
+                Huggingface model that can generate with .generate(**generate_kwargs)
+            model_tokenizer:
+                Huggingface tokenizer that can tokenize with .__call__(**generate_kwargs)
+            inputs List[str]:
+                list of input texts
+            instructions List[str]:
+                list of instructions; if not None, each is prepended to the corresponding input
+            n int:
+                the n parameter in best-of-n, i.e. how many samples to generate for ranking for each input
+            sampling_mode:
+                "top_k_sampling" or "top_p_sampling"
+                if None, will use the custom sampling strategy specified by generate_kwargs
+            batch_size int:
+                batch size for generation
+            pairrm_cmp_type str: one of ['bubble', 'full']
+                - 'bubble': use a single pass of bubble sort to get the best of n, for quicker speed. Time complexity: O(n)
+                - 'full': use the full pairwise comparison matrix to get the best of n. Time complexity: O(n^2)
+            return_all bool:
+                If True, will return all candidates instead of only the best of n;
+                the returned candidates are sorted by the ranker, with the best candidate first
+            generate_kwargs:
+                kwargs for model.generate()
+                recommended kwargs:
+                    - max_new_tokens: max length of the generation. If not specified, will use model_tokenizer.model_max_length
+                    - top_k: if mode is "top_k_sampling", will use this top_k. If not specified, will use 50
+                    - top_p: if mode is "top_p_sampling", will use this top_p. If not specified, will use 1.0
+                    - temperature: temperature for sampling. If not specified, will use 0.7
+                Note that num_return_sequences will be set to n, so you don't need to specify it
+        Returns:
+            best_candidates
+                - List[str]: the best candidate according to the ranker, for each input
+                - List[List[str]]: all candidates sorted by the ranker for each input, when return_all is True
+        """
+        sampled_candidates = self.n_generate(model, model_tokenizer, inputs,
+            instructions=instructions, n=n, sampling_mode=sampling_mode, batch_size=batch_size, **generate_kwargs)
+
+        best_of_n_outputs = self.get_best_of_n(inputs, sampled_candidates,
+            instructions=instructions, batch_size=min(batch_size, 32),
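Note: a minimal usage sketch of the split introduced by this commit, assuming a Blender instance with a PairRM ranker loaded as in the LLM-Blender README; the base model, tokenizer, and checkpoint names below are placeholder assumptions, not part of this commit:

    import llm_blender
    from transformers import AutoModelForCausalLM, AutoTokenizer

    blender = llm_blender.Blender()
    blender.loadranker("llm-blender/PairRM")  # assumed ranker checkpoint

    # placeholder base model, used only for illustration
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    inputs = ["What is the capital of France?"]

    # n_generate only samples: returns List[List[str]] with n candidates per input
    candidates = blender.n_generate(model, tokenizer, inputs, n=5,
                                    sampling_mode="top_p_sampling", top_p=0.9, temperature=0.7)

    # best_of_n_generate samples via n_generate, then ranks the candidates
    best = blender.best_of_n_generate(model, tokenizer, inputs, n=5,
                                      pairrm_cmp_type="bubble")
    print(best[0])  # best candidate for the first input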

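For intuition on the two pairrm_cmp_type options, here is a self-contained sketch of the selection strategies over a generic pairwise comparator; compare_fn is a stand-in for a PairRM comparison, not this repo's API:

    from typing import Callable, List

    def best_of_n_bubble(cands: List[str], compare_fn: Callable[[str, str], bool]) -> str:
        # single bubble pass: carry the current winner through n-1 comparisons -> O(n)
        best = cands[0]
        for cand in cands[1:]:
            if compare_fn(cand, best):  # True if cand beats best
                best = cand
        return best

    def best_of_n_full(cands: List[str], compare_fn: Callable[[str, str], bool]) -> str:
        # full pairwise matrix: count wins over all n*(n-1) ordered pairs -> O(n^2)
        wins = [0] * len(cands)
        for i, a in enumerate(cands):
            for j, b in enumerate(cands):
                if i != j and compare_fn(a, b):
                    wins[i] += 1
        return cands[max(range(len(cands)), key=wins.__getitem__)]

The 'full' mode pays a quadratic number of comparisons for a more robust winner; a single bubble pass already surfaces the best candidate under a consistent comparator.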