Commit d27bb91: update pretraining

lonePatient committed Dec 25, 2019
1 parent 3da2ebc commit d27bb91
Showing 4 changed files with 2 additions and 42 deletions.

albert_chinese_pytorch/model/modeling_albert.py (4 changes: 0 additions & 4 deletions)
@@ -719,8 +719,6 @@ def tie_weights(self, config):
         """
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)
-        self._tie_or_clone_data(self.cls.predictions.project_layer,
-                                self.bert.embeddings.word_embeddings_2)

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None, next_sentence_label=None):
@@ -794,8 +792,6 @@ def tie_weights(self):
         """
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)
-        self._tie_or_clone_data(self.cls.predictions.project_layer,
-                                self.bert.embeddings.word_embeddings_2)

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None):
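Note on the two hunks above: the commit keeps the standard tie between the MLM decoder and the input word embeddings and drops the extra call that tied `cls.predictions.project_layer` to a second embedding table (`word_embeddings_2`), which now simply keeps its own parameters. As a minimal sketch (not the repository's code; sizes are illustrative), the remaining tie amounts to sharing a single Parameter object:

import torch.nn as nn

# Illustrative sizes only; the real values come from the model config.
vocab_size, embedding_size = 30000, 128
word_embeddings = nn.Embedding(vocab_size, embedding_size)
decoder = nn.Linear(embedding_size, vocab_size, bias=False)

# What _tie_or_clone_weights does in the non-TorchScript branch: share the Parameter.
decoder.weight = word_embeddings.weight
assert decoder.weight is word_embeddings.weight  # one tensor, one gradient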

albert_chinese_pytorch/model/modeling_utils.py (18 changes: 1 addition & 17 deletions)
@@ -124,7 +124,7 @@ def _tie_or_clone_weights(self, first_module, second_module):
         if self.config.torchscript:
             first_module.weight = nn.Parameter(second_module.weight.clone())
         else:
-            first_module.weight.data = second_module.weight
+            first_module.weight = second_module.weight


         if hasattr(first_module, 'bias') and first_module.bias is not None:
@@ -135,22 +135,6 @@ def _tie_or_clone_weights(self, first_module, second_module):
                 0
             )

-    def _tie_or_clone_data(self, first_module, second_module):
-        """ Tie or clone module weights depending of weither we are using TorchScript or not
-        """
-
-        if self.config.torchscript:
-            first_module.weight.data = nn.Parameter(second_module.weight.data.t().clone())
-        else:
-            first_module.weight.data = second_module.weight.data.t()
-        if hasattr(first_module, 'bias') and first_module.bias is not None:
-            first_module.bias.data = torch.nn.functional.pad(
-                first_module.bias.data,
-                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
-                'constant',
-                0
-            )
-
     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
             Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
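The helper that performs this sharing, `_tie_or_clone_weights`, survives the commit: under TorchScript it clones the source weight into a fresh Parameter, otherwise it assigns the Parameter object itself, and it zero-pads the output bias if it is shorter than the output dimension of the tied weight. The one-line change in the first hunk replaces the `.data` assignment (shared storage, but two separate Parameters with separate gradients) with a true Parameter tie. A rough, self-contained sketch of the pattern, with hypothetical module names rather than the repository's classes:

import torch
import torch.nn as nn

def tie_or_clone(first_module, second_module, torchscript=False):
    # Sketch of the pattern in the hunk above, simplified.
    if torchscript:
        # TorchScript cannot handle shared Parameters, so take an independent copy.
        first_module.weight = nn.Parameter(second_module.weight.clone())
    else:
        # Share the Parameter object: both modules now train the same tensor.
        first_module.weight = second_module.weight
    # Zero-pad the bias so its length matches the output dimension of the tied weight.
    if hasattr(first_module, 'bias') and first_module.bias is not None:
        first_module.bias.data = torch.nn.functional.pad(
            first_module.bias.data,
            (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
            'constant',
            0)

embeddings = nn.Embedding(30000, 128)   # illustrative sizes
mlm_decoder = nn.Linear(128, 30000)
tie_or_clone(mlm_decoder, embeddings)   # mlm_decoder.weight is now embeddings.weight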

albert_english_pytorch/model/modeling_albert.py (4 changes: 0 additions & 4 deletions)
@@ -674,8 +674,6 @@ def tie_weights(self):
         """
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)
-        self._tie_or_clone_data(self.cls.predictions.transform.dense,
-                                self.bert.encoder.embedding_hidden_mapping_in)

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None, next_sentence_label=None):
@@ -743,8 +741,6 @@ def tie_weights(self):
         """
         self._tie_or_clone_weights(self.cls.predictions.decoder,
                                    self.bert.embeddings.word_embeddings)
-        self._tie_or_clone_data(self.cls.predictions.transform.dense,
-                                self.bert.encoder.embedding_hidden_mapping_in)

     def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                 masked_lm_labels=None):
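For reference, the deleted call here used the now-removed `_tie_or_clone_data` helper (see the modeling_utils.py hunks) to copy the transpose of `embedding_hidden_mapping_in`, ALBERT's factorized-embedding projection from the small embedding size E up to the hidden size H, into the MLM head's `transform.dense` layer, which presumably maps hidden states back down to E. A shape-only sketch of that transposed copy; E and H are assumed values, not taken from the repository:

import torch.nn as nn

E, H = 128, 768                                # assumed ALBERT-style sizes
embedding_hidden_mapping_in = nn.Linear(E, H)  # embedding_size -> hidden_size, weight shape [H, E]
mlm_transform_dense = nn.Linear(H, E)          # hidden_size -> embedding_size, weight shape [E, H]

# Non-TorchScript branch of the removed _tie_or_clone_data: copy the transposed tensor.
mlm_transform_dense.weight.data = embedding_hidden_mapping_in.weight.data.t()
print(mlm_transform_dense.weight.shape)        # torch.Size([128, 768])

After this commit `transform.dense` is left untied and keeps its own weights.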

albert_english_pytorch/model/modeling_utils.py (18 changes: 1 addition & 17 deletions)
@@ -120,7 +120,7 @@ def _tie_or_clone_weights(self, first_module, second_module):
         if self.config.torchscript:
             first_module.weight = nn.Parameter(second_module.weight.clone())
         else:
-            first_module.weight.data = second_module.weight
+            first_module.weight = second_module.weight


         if hasattr(first_module, 'bias') and first_module.bias is not None:
@@ -131,22 +131,6 @@ def _tie_or_clone_weights(self, first_module, second_module):
                 0
             )

-    def _tie_or_clone_data(self, first_module, second_module):
-        """ Tie or clone module weights depending of weither we are using TorchScript or not
-        """
-
-        if self.config.torchscript:
-            first_module.weight.data = nn.Parameter(second_module.weight.data.t().clone())
-        else:
-            first_module.weight.data = second_module.weight.data.t()
-        if hasattr(first_module, 'bias') and first_module.bias is not None:
-            first_module.bias.data = torch.nn.functional.pad(
-                first_module.bias.data,
-                (0, first_module.weight.shape[0] - first_module.bias.shape[0]),
-                'constant',
-                0
-            )
-
     def resize_token_embeddings(self, new_num_tokens=None):
         """ Resize input token embeddings matrix of the model if new_num_tokens != config.vocab_size.
             Take care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
