
Commit 3c4b272

only mask co-word
SeanLee97 committed Aug 21, 2024
1 parent 6fce33e commit 3c4b272
Showing 1 changed file with 8 additions and 17 deletions.
angle_emb/angle.py: 25 changes (8 additions & 17 deletions)
@@ -537,8 +537,6 @@ class AngleDataCollator:
     :param filter_duplicate: bool. Whether to filter duplicate data.
     :param coword_random_mask_rate: float. Default 0.0.
         If set greater than 0, the random masked token prediction will be added to the training loss.
-    :param default_random_mask_rate: float. Default 0.1.
-        It works only when coword_random_mask_rate is greater than 0.
     """
 
     tokenizer: PreTrainedTokenizerBase
@@ -547,7 +545,6 @@ class AngleDataCollator:
     return_tensors: str = "pt"
     filter_duplicate: bool = True
     coword_random_mask_rate: float = 0.0
-    default_random_mask_rate: float = 0.05
     special_token_id_names: List[str] = field(default_factory=lambda: [
         'bos_token_id', 'eos_token_id', 'unk_token_id', 'sep_token_id',
         'pad_token_id', 'cls_token_id', 'mask_token_id'])
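
After this commit, the only masking knob left on the collator is coword_random_mask_rate. Below is a rough usage sketch, assuming the dataclass fields shown in this diff and that any fields not listed here keep their defaults; the tokenizer checkpoint and the rate value are illustrative choices, not values taken from the commit.

```python
from transformers import AutoTokenizer
from angle_emb.angle import AngleDataCollator

# Any tokenizer that defines a mask token works for co-word masking.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

collator = AngleDataCollator(
    tokenizer=tokenizer,
    return_tensors="pt",
    filter_duplicate=True,
    coword_random_mask_rate=0.3,  # > 0 enables co-word masking; default_random_mask_rate no longer exists
)
```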
@@ -639,26 +636,20 @@ def __call__(self, features: List[Dict], return_tensors: str = "pt") -> Dict[str
             if cnt_common_tokens:
                 sample_size = max(1, int(len(cnt_common_tokens) * self.coword_random_mask_rate))
                 sampled_mask_tokens = random.sample(cnt_common_tokens, sample_size)
-            else:
-                sample_size = max(1, int(len(cnt_tokens) * self.default_random_mask_rate))
-                sampled_mask_tokens = random.sample(list(cnt_tokens), sample_size)
-            cnt_fea['mask_target_labels'] = cnt_fea['input_ids']
-            cnt_fea['input_ids'] = [self.tokenizer.mask_token_id
-                                    if idx in sampled_mask_tokens
-                                    else idx for idx in cnt_fea['input_ids']]
+                cnt_fea['mask_target_labels'] = cnt_fea['input_ids']
+                cnt_fea['input_ids'] = [self.tokenizer.mask_token_id
+                                        if idx in sampled_mask_tokens
+                                        else idx for idx in cnt_fea['input_ids']]

             # mask first text
             common_tokens_with_first_text = list(common_tokens_with_first_text)
             if common_tokens_with_first_text:
                 sample_size = max(1, int(len(common_tokens_with_first_text) * self.coword_random_mask_rate))
                 sampled_mask_tokens = random.sample(common_tokens_with_first_text, sample_size)
-            else:
-                sample_size = max(1, int(len(first_text_tokens) * self.default_random_mask_rate))
-                sampled_mask_tokens = random.sample(list(first_text_tokens), sample_size)
-            current_features[0]['mask_target_labels'] = current_features[0]['input_ids']
-            current_features[0]['input_ids'] = [self.tokenizer.mask_token_id
-                                                if idx in sampled_mask_tokens
-                                                else idx for idx in current_features[0]['input_ids']]
+                current_features[0]['mask_target_labels'] = current_features[0]['input_ids']
+                current_features[0]['input_ids'] = [self.tokenizer.mask_token_id
+                                                    if idx in sampled_mask_tokens
+                                                    else idx for idx in current_features[0]['input_ids']]

         new_features += current_features
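
To make the behavioral change concrete: after this commit, masked-token targets are produced only when the paired texts actually share tokens; there is no longer a fallback that randomly masks tokens at default_random_mask_rate when they do not. The sketch below is a standalone, hypothetical illustration of that rule; the coword_mask helper and the literal mask id are not part of angle_emb.

```python
import random
from typing import List, Optional, Tuple


def coword_mask(ids_a: List[int], ids_b: List[int], mask_token_id: int,
                rate: float) -> Tuple[List[int], Optional[List[int]]]:
    """Hypothetical helper: mask a sampled fraction of the token ids shared by two texts."""
    common = list(set(ids_a) & set(ids_b))
    if not common:
        # No co-occurring tokens: leave the input untouched (the post-commit behavior).
        return ids_a, None
    sample_size = max(1, int(len(common) * rate))
    sampled = set(random.sample(common, sample_size))
    labels = list(ids_a)  # original ids serve as prediction targets for the masked positions
    masked = [mask_token_id if t in sampled else t for t in ids_a]
    return masked, labels


# Shared ids 7 and 9 are masking candidates; 999 stands in for the tokenizer's mask token id.
masked_ids, target_labels = coword_mask([5, 7, 9, 11], [7, 9, 20], mask_token_id=999, rate=0.3)
```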

