Skip to content

Commit

Permalink
bugfix: mask_target_labels
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanLee97 committed Aug 22, 2024
1 parent 3c4b272 commit a9721ca
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions angle_emb/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,20 +636,20 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str
if cnt_common_tokens:
sample_size = max(1, int(len(cnt_common_tokens) * self.coword_random_mask_rate))
sampled_mask_tokens = random.sample(cnt_common_tokens, sample_size)
+ cnt_fea['mask_target_labels'] = cnt_fea['input_ids']
cnt_fea['input_ids'] = [self.tokenizer.mask_token_id
if idx in sampled_mask_tokens
else idx for idx in cnt_fea['input_ids']]
- cnt_fea['mask_target_labels'] = cnt_fea['input_ids']

# mask first text
common_tokens_with_first_text = list(common_tokens_with_first_text)
if common_tokens_with_first_text:
sample_size = max(1, int(len(common_tokens_with_first_text) * self.coword_random_mask_rate))
sampled_mask_tokens = random.sample(common_tokens_with_first_text, sample_size)
+ current_features[0]['mask_target_labels'] = current_features[0]['input_ids']
current_features[0]['input_ids'] = [self.tokenizer.mask_token_id
if idx in sampled_mask_tokens
else idx for idx in current_features[0]['input_ids']]
- current_features[0]['mask_target_labels'] = current_features[0]['input_ids']

new_features += current_features

Expand Down

0 comments on commit a9721ca

Please sign in to comment.