From 24376fee26d091e9debeda96c0f424be81b847a6 Mon Sep 17 00:00:00 2001 From: Dongfu Date: Sun, 21 Jan 2024 01:07:01 -0500 Subject: [PATCH] add get_worst_of_n() --- llm_blender/blender/blender.py | 73 ++++++++++++++++++++++++++++++- llm_blender/pair_ranker/ranker.py | 9 +++- 2 files changed, 79 insertions(+), 3 deletions(-) diff --git a/llm_blender/blender/blender.py b/llm_blender/blender/blender.py index 8bbe299..0845d9e 100755 --- a/llm_blender/blender/blender.py +++ b/llm_blender/blender/blender.py @@ -366,7 +366,78 @@ def get_best_of_n( else: best_candidates = get_topk_candidates_from_ranks(ranks, candidates, top_k=None) return best_candidates - + + def get_worst_of_n( + self, + inputs:List[str], + candidates:List[List[str]], + instructions:List[str]=None, + pairrm_cmp_type:str="bubble", + return_all:bool=False, + batch_size:int=8, + ): + """Get the worst of n candidates for each input using the ranker + Args: + inputs List[str]: List of input texts + candidates List[List[str]]: List of list of candidate texts, meaning each input can have multiple candidates + instructions List[str]: List of instructions. if not None, will be prepended to the corresponding input + pairrm_cmp_type str: one of ['bubble', 'full'] + - 'bubble': use a single run of bubble sort to get the worst of n for quicker speed. Time complexity: O(n) + - 'full': use full pairwise comparison matrix to get the worst of n. Time complexity: O(n^2) + return_all bool: + If True, will return all candidates instead of the worst of n candidates + The returned candidates will be sorted by the ranker, where the first candidate is the worst + batch_size int: batch size for ranking + Returns: + worst_candidates + - List[str]: worst candidates against the ranker for each input + - List[List[str]]: All candidates against the ranker for each input, when return_all is True + """ + if all([len(c) == 1 for c in candidates]): + # no need to rank + if not return_all: + worst_candidates = [x[0] for x in candidates] + else: + worst_candidates = candidates + return worst_candidates + if self.ranker_config.ranker_type == "pairranker" and pairrm_cmp_type == "bubble": + # use bubble sort single run to get the worst of n for quicker speed + collate_fn = copy.copy(self.ranker_collator) + dataset = RankerDataset(inputs, candidates, instructions=instructions) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn) + worst_idxs = [] + rest_idxs = [] + with torch.no_grad(): + for batch in tqdm(iter(dataloader), desc="Ranking candidates", disable=not self.blender_config.use_tqdm): + batch = {k: v.to(self.blender_config.device) for k, v in batch.items() if v is not None} + outputs = self.ranker._bubble_predict(**batch, best_or_worst="worst") + select_process = outputs['select_process'].detach().cpu().numpy() + worst_idx = select_process[:, 2, -1] + rest_idx = np.where( + select_process[:, 0, :] == select_process[:, 2, :], + select_process[:, 1, :], + select_process[:, 0, :] + ) + rest_idxs.append(rest_idx) + worst_idxs.append(worst_idx) + worst_idxs = np.concatenate(worst_idxs, axis=0) + if not return_all: + worst_candidates = np.array(candidates)[np.arange(len(candidates)), worst_idxs].tolist() + else: + rest_idxs = np.concatenate(rest_idxs, axis=0) + all_idxes = np.concatenate([worst_idxs.reshape(-1, 1), rest_idxs], axis=1) + worst_candidates = [] + for i in range(len(candidates)): + worst_candidates.append([candidates[i][x] for x in all_idxes[i]]) + else: + ranks = self.rank(inputs, candidates, instructions=instructions, batch_size=batch_size) + if not return_all: + worst_candidates = get_topk_candidates_from_ranks(ranks, candidates, top_k=1) + worst_candidates = [x[0] for x in worst_candidates] + else: + worst_candidates = get_topk_candidates_from_ranks(ranks, candidates, top_k=None) + return worst_candidates + def compare(self, inputs: List[str], diff --git a/llm_blender/pair_ranker/ranker.py b/llm_blender/pair_ranker/ranker.py index 20fe6d1..14a3a25 100755 --- a/llm_blender/pair_ranker/ranker.py +++ b/llm_blender/pair_ranker/ranker.py @@ -657,6 +657,7 @@ def _bubble_predict( candidate_attention_mask, scores=None, num_runs=1, + best_or_worst="best", ): """ bubble prediction @@ -711,8 +712,12 @@ def _bubble_predict( loss += _outputs['loss'] preds_inv = -_outputs['logits'] - permu[:, j] = torch.where(preds + preds_inv <= 0, cur_idx, next_idx) - permu[:, j+1] = torch.where(preds + preds_inv > 0, cur_idx, next_idx) + if best_or_worst == "best": + permu[:, j] = torch.where(preds + preds_inv <= 0, cur_idx, next_idx) + permu[:, j+1] = torch.where(preds + preds_inv > 0, cur_idx, next_idx) + elif best_or_worst == "worst": + permu[:, j] = torch.where(preds + preds_inv >= 0, cur_idx, next_idx) + permu[:, j+1] = torch.where(preds + preds_inv < 0, cur_idx, next_idx) assert torch.ne(permu[:, j], permu[:, j+1]).all() better_idx = permu[:, j+1].clone() better_idxs.append(better_idx)