From 347bbb6a9adf8cce293a41ebe462c5ca09b3812d Mon Sep 17 00:00:00 2001 From: Sean Lee Date: Thu, 29 Feb 2024 19:25:24 +0800 Subject: [PATCH] improve 2dmse; filter duplicate --- angle_emb/angle.py | 106 ++++++++++++++++++++++++++++----------------- 1 file changed, 66 insertions(+), 40 deletions(-) diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 18b6250..7ea538c 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -4,6 +4,7 @@ import re import sys import json +import math import random from functools import partial from typing import Any, Dict, Optional, List, Union, Tuple, Callable @@ -581,12 +582,13 @@ class AngleDataCollator: :param padding: Union[bool, str, PaddingStrategy], padding strategy :param max_length: Optional[int], max length :param return_tensors: str - + :param filter_duplicate: bool. Whether filter duplicate data """ tokenizer: PreTrainedTokenizerBase padding: Union[bool, str, PaddingStrategy] = 'longest' max_length: Optional[int] = None return_tensors: str = "pt" + filter_duplicate: bool = True def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str, torch.Tensor]: if return_tensors is None: @@ -595,6 +597,7 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str end_with_eos = features[0]['extra']['end_with_eos'] new_features = [] + duplicate_set = set() for feature in features: seperate_ids = feature['seperate_ids'] input_ids = feature['input_ids'] @@ -609,26 +612,41 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str max_seperate_id = max(seperate_ids) prev_start_idx = 0 + current_features = [] + is_duplicate = False for seperate_id in range(1, max_seperate_id + 1): start_idx = seperate_ids.index(seperate_id) - new_feature = {} - new_feature['input_ids'] = input_ids[prev_start_idx:start_idx] + new_input_ids = input_ids[prev_start_idx:start_idx] + if tuple(new_input_ids) in duplicate_set: + is_duplicate = True + if self.filter_duplicate: + break + duplicate_set.add(tuple(new_input_ids)) + new_feature['input_ids'] = new_input_ids new_feature['attention_mask'] = attention_mask[prev_start_idx:start_idx] if has_token_type_ids: new_feature['token_type_ids'] = token_type_ids[prev_start_idx:start_idx] new_feature['labels'] = feature['labels'] - new_features.append(new_feature) + current_features.append(new_feature) prev_start_idx = start_idx # last new_feature = {} - new_feature['input_ids'] = input_ids[prev_start_idx:] + new_input_ids = input_ids[prev_start_idx:] + if tuple(new_input_ids) in duplicate_set: + is_duplicate = True + duplicate_set.add(tuple(new_input_ids)) + new_feature['input_ids'] = new_input_ids new_feature['attention_mask'] = attention_mask[prev_start_idx:] if has_token_type_ids: new_feature['token_type_ids'] = token_type_ids[prev_start_idx:] new_feature['labels'] = feature['labels'] - new_features.append(new_feature) + current_features.append(new_feature) + + if self.filter_duplicate and is_duplicate: + continue + new_features += current_features # remove features del features @@ -685,13 +703,17 @@ def __init__(self, self.padding_strategy = padding_strategy self.is_llm = is_llm - def __call__(self, inputs: Dict, layer_index: int = -1, embedding_size: Optional[int] = None) -> torch.Tensor: + def __call__(self, inputs: Dict, layer_index: int = -1, embedding_size: Optional[int] = None, + return_all_layer_outputs: bool = False) -> torch.Tensor: """ :param inputs: Dict. Model inputs. :param layer_index: int. Get embeddings from specific layer. :param embedding_size: int. Set embedding size for sentence embeddings for 2DMSE models. """ - outputs = self.model(output_hidden_states=True, return_dict=True, **inputs).hidden_states[layer_index] + all_layer_outputs = self.model(output_hidden_states=True, return_dict=True, **inputs).hidden_states + if return_all_layer_outputs: + return all_layer_outputs + outputs = all_layer_outputs[layer_index] if self.is_llm: batch_size = inputs['input_ids'].shape[0] sequence_lengths = -1 if self.padding_strategy == 'left' else inputs["attention_mask"].sum(dim=1) - 1 @@ -802,46 +824,48 @@ def __init__(self, self.tdmse_student_lambda = tdmse_student_lambda self.apply_tdmse_kl = apply_tdmse_kl self.n_layers = self.pooler.model.config.num_hidden_layers - self.tdmse_hidden_sizes = get_geometric_hidden_sizes(base=8, max_hidden=self.pooler.model.config.hidden_size) + self.hidden_size = self.pooler.model.config.hidden_size + self.tdmse_hidden_sizes = get_geometric_hidden_sizes(base=8, max_hidden=self.hidden_size) self.kl_loss_fct = nn.KLDivLoss(reduction='batchmean') - logger.info('Train 2DMSE!') + logger.info('Train with 2DMSE!') def compute_loss(self, model, inputs, return_outputs=False): labels = inputs.pop("labels", None) # layer sample_layer = random.randint(1, self.n_layers - 1) - if self.fixed_teacher_name_or_path is not None: - all_teacher_outputs = self.pooler(inputs, layer_index=-1) - teacher_outputs = get_pooling(all_teacher_outputs, inputs, - self.alignment_pooling_strategy, - self.pooler.padding_strategy) - all_student_outputs = self.pooler(inputs, layer_index=sample_layer) - student_outputs = get_pooling(all_student_outputs, inputs, - self.alignment_pooling_strategy, - self.pooler.padding_strategy) - else: - teacher_outputs = self.pooler(inputs, layer_index=-1) - student_outputs = self.pooler(inputs, layer_index=sample_layer) - - kl_outputs = teacher_outputs + pooling_strategy = (self.alignment_pooling_strategy + if self.pooler.pooling_strategy == 'all' + else self.pooler.pooling_strategy) + all_layer_outputs = self.pooler(inputs, layer_index=-1, return_all_layer_outputs=True) + all_teacher_outputs = all_layer_outputs[-1] + teacher_outputs = get_pooling(all_teacher_outputs, inputs, + pooling_strategy, + self.pooler.padding_strategy) + all_student_outputs = all_layer_outputs[sample_layer] + student_outputs = get_pooling(all_student_outputs, + inputs, + pooling_strategy, + self.pooler.padding_strategy) + + teacher_kl_outputs = teacher_outputs if self.fixed_teacher_name_or_path is not None: with torch.no_grad(): self.fixed_teacher_pooler.model = self.fixed_teacher_pooler.model.to(self.pooler.model.device) all_fixed_outputs = self.fixed_teacher_pooler(inputs) - kl_outputs = get_pooling(all_fixed_outputs, inputs, - self.alignment_pooling_strategy, - self.pooler.padding_strategy) + teacher_kl_outputs = get_pooling(all_fixed_outputs, + inputs, + self.alignment_pooling_strategy, + self.pooler.padding_strategy) teacher_loss = self.loss_fct(labels, teacher_outputs) - loss1 = self.tdmse_teacher_lambda * teacher_loss - if self.tdmse_student_lambda > 0: - student_loss = self.loss_fct(labels, student_outputs) - loss1 += self.tdmse_student_lambda * student_loss + loss1 = teacher_loss + student_loss = self.loss_fct(labels, student_outputs) + loss1 += student_loss / sample_layer if self.apply_tdmse_kl and self.tdmse_student_lambda > 0: kl_loss = self.kl_loss_fct( - F.log_softmax(student_outputs[:, None, :] / self.tdmse_kl_temperature, dim=-1), - F.softmax(kl_outputs[:, None, :] / self.tdmse_kl_temperature, dim=-1) - ) * self.tdmse_kl_temperature**2 + F.log_softmax(student_outputs / self.tdmse_kl_temperature, dim=-1), + F.softmax(teacher_kl_outputs / self.tdmse_kl_temperature, dim=-1) + ) * self.tdmse_kl_temperature * math.log(2 + sample_layer) loss1 += kl_loss # feature @@ -850,10 +874,10 @@ def compute_loss(self, model, inputs, return_outputs=False): slimmed_student_outputs = student_outputs[:, :hidden_size] slimmed_teacher_loss = self.loss_fct(labels, slimmed_teacher_outputs) - loss2 = self.tdmse_teacher_lambda * slimmed_teacher_loss - if self.tdmse_student_lambda > 0: - slimmed_student_loss = self.loss_fct(labels, slimmed_student_outputs) - loss2 += self.tdmse_student_lambda * slimmed_student_loss + loss2 = slimmed_teacher_loss + slimmed_student_loss = self.loss_fct(labels, slimmed_student_outputs) + loss2 += slimmed_student_loss / sample_layer + loss = loss1 + loss2 if self.fixed_teacher_name_or_path is not None: @@ -1334,7 +1358,8 @@ def fit(self, argument_kwargs: Optional[Dict] = None, trainer_kwargs: Optional[Dict] = None, loss_kwargs: Optional[Dict] = None, - apply_tdmse: bool = False): + apply_tdmse: bool = False, + filter_duplicate: bool = True): """ Fit using AnglE. @@ -1412,7 +1437,7 @@ def fit(self, ), callbacks=callbacks, data_collator=AngleDataCollator( - self.tokenizer, return_tensors="pt", max_length=self.max_length + self.tokenizer, return_tensors="pt", max_length=self.max_length, filter_duplicate=filter_duplicate ), **trainer_kwargs ) @@ -1428,6 +1453,7 @@ def evaluate(self, data: Dataset, batch_size: int = 32, threshold: Optional[floa self.tokenizer, return_tensors="pt", max_length=self.max_length, + filter_duplicate=False, ) y_trues, y_preds = [], [] # for X, y in data.make_iter(random=False):