diff --git a/angle_emb/angle.py b/angle_emb/angle.py index 81da649..ac49edb 100644 --- a/angle_emb/angle.py +++ b/angle_emb/angle.py @@ -636,20 +636,20 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str if cnt_common_tokens: sample_size = max(1, int(len(cnt_common_tokens) * self.coword_random_mask_rate)) sampled_mask_tokens = random.sample(cnt_common_tokens, sample_size) - cnt_fea['mask_target_labels'] = cnt_fea['input_ids'] cnt_fea['input_ids'] = [self.tokenizer.mask_token_id if idx in sampled_mask_tokens else idx for idx in cnt_fea['input_ids']] + cnt_fea['mask_target_labels'] = cnt_fea['input_ids'] # mask first text common_tokens_with_first_text = list(common_tokens_with_first_text) if common_tokens_with_first_text: sample_size = max(1, int(len(common_tokens_with_first_text) * self.coword_random_mask_rate)) sampled_mask_tokens = random.sample(common_tokens_with_first_text, sample_size) - current_features[0]['mask_target_labels'] = current_features[0]['input_ids'] current_features[0]['input_ids'] = [self.tokenizer.mask_token_id if idx in sampled_mask_tokens else idx for idx in current_features[0]['input_ids']] + current_features[0]['mask_target_labels'] = current_features[0]['input_ids'] new_features += current_features