Skip to content

Commit

Permalink
add get_worst_of_n()
Browse files Browse the repository at this point in the history
  • Loading branch information
jdf-prog committed Jan 21, 2024
1 parent d1c51d3 commit 24376fe
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 3 deletions.
73 changes: 72 additions & 1 deletion llm_blender/blender/blender.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,78 @@ def get_best_of_n(
else:
best_candidates = get_topk_candidates_from_ranks(ranks, candidates, top_k=None)
return best_candidates


def get_worst_of_n(
    self,
    inputs:List[str],
    candidates:List[List[str]],
    instructions:List[str]=None,
    pairrm_cmp_type:str="bubble",
    return_all:bool=False,
    batch_size:int=8,
):
    """Get the worst of n candidates for each input using the ranker

    Args:
        inputs List[str]: List of input texts
        candidates List[List[str]]: List of list of candidate texts, meaning each input can have multiple candidates
        instructions List[str]: List of instructions. if not None, will be prepended to the corresponding input
        pairrm_cmp_type str: one of ['bubble', 'full']
            - 'bubble': use a single run of bubble sort to get the worst of n for quicker speed. Time complexity: O(n)
            - 'full': use full pairwise comparison matrix to get the worst of n. Time complexity: O(n^2)
            The 'bubble' path is only taken when the ranker type is "pairranker";
            any other combination falls back to `self.rank`.
        return_all bool:
            If True, will return all candidates instead of the worst of n candidates
            The returned candidates will be ordered so that the first candidate is the worst.
            NOTE(review): with 'bubble' only a single pass is run, so presumably only the
            first element is guaranteed to be the worst and the rest follow the order in
            which they were compared -- confirm against `_bubble_predict`.
        batch_size int: batch size for ranking
    Returns:
        worst_candidates
            - List[str]: worst candidates against the ranker for each input
            - List[List[str]]: All candidates against the ranker for each input, when return_all is True
    """
    if all(len(c) == 1 for c in candidates):
        # A single candidate per input: nothing to rank.
        return candidates if return_all else [c[0] for c in candidates]

    if self.ranker_config.ranker_type == "pairranker" and pairrm_cmp_type == "bubble":
        # use bubble sort single run to get the worst of n for quicker speed
        collate_fn = copy.copy(self.ranker_collator)
        dataset = RankerDataset(inputs, candidates, instructions=instructions)
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        worst_idxs = []
        rest_idxs = []
        with torch.no_grad():
            for batch in tqdm(iter(dataloader), desc="Ranking candidates", disable=not self.blender_config.use_tqdm):
                batch = {k: v.to(self.blender_config.device) for k, v in batch.items() if v is not None}
                outputs = self.ranker._bubble_predict(**batch, best_or_worst="worst")
                select_process = outputs['select_process'].detach().cpu().numpy()
                # Row 2 of the second axis holds the index selected at each step;
                # its last entry is the candidate that survived the whole pass.
                worst_idxs.append(select_process[:, 2, -1])
                if return_all:
                    # At every step keep whichever of the two compared indices
                    # (rows 0 and 1) was NOT the selected one (row 2).
                    rest_idxs.append(np.where(
                        select_process[:, 0, :] == select_process[:, 2, :],
                        select_process[:, 1, :],
                        select_process[:, 0, :]
                    ))
        worst_idxs = np.concatenate(worst_idxs, axis=0)
        if not return_all:
            worst_candidates = np.array(candidates)[np.arange(len(candidates)), worst_idxs].tolist()
        else:
            rest_idxs = np.concatenate(rest_idxs, axis=0)
            all_idxes = np.concatenate([worst_idxs.reshape(-1, 1), rest_idxs], axis=1)
            worst_candidates = [
                [cands[x] for x in idxes]
                for cands, idxes in zip(candidates, all_idxes)
            ]
    else:
        ranks = self.rank(inputs, candidates, instructions=instructions, batch_size=batch_size)
        # get_topk_candidates_from_ranks orders candidates best-first (this is what
        # get_best_of_n relies on), so the worst candidate is the LAST one and the
        # worst-first ordering is the reverse. Taking top_k=1 here would return the best.
        best_first = get_topk_candidates_from_ranks(ranks, candidates, top_k=None)
        if not return_all:
            worst_candidates = [x[-1] for x in best_first]
        else:
            worst_candidates = [list(x)[::-1] for x in best_first]
    return worst_candidates


def compare(self,
inputs: List[str],
Expand Down
9 changes: 7 additions & 2 deletions llm_blender/pair_ranker/ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ def _bubble_predict(
candidate_attention_mask,
scores=None,
num_runs=1,
best_or_worst="best",
):
"""
bubble prediction
Expand Down Expand Up @@ -711,8 +712,12 @@ def _bubble_predict(
loss += _outputs['loss']
preds_inv = -_outputs['logits']

permu[:, j] = torch.where(preds + preds_inv <= 0, cur_idx, next_idx)
permu[:, j+1] = torch.where(preds + preds_inv > 0, cur_idx, next_idx)
if best_or_worst == "best":
permu[:, j] = torch.where(preds + preds_inv <= 0, cur_idx, next_idx)
permu[:, j+1] = torch.where(preds + preds_inv > 0, cur_idx, next_idx)
elif best_or_worst == "worst":
permu[:, j] = torch.where(preds + preds_inv >= 0, cur_idx, next_idx)
permu[:, j+1] = torch.where(preds + preds_inv < 0, cur_idx, next_idx)
assert torch.ne(permu[:, j], permu[:, j+1]).all()
better_idx = permu[:, j+1].clone()
better_idxs.append(better_idx)
Expand Down

0 comments on commit 24376fe

Please sign in to comment.