model.py
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers import AutoModel, AlbertPreTrainedModel, AlbertModel
from transformers.models.albert.modeling_albert import AlbertMLMHead

import config_lm

class MLPHead(nn.Module):
    """Two-layer projection head (Linear -> BatchNorm -> ReLU -> Linear) used to project token embeddings."""

    def __init__(self, in_channels, mlp_hidden_size, projection_size):
        super(MLPHead, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_channels, mlp_hidden_size),
            nn.BatchNorm1d(mlp_hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(mlp_hidden_size, projection_size),
        )

    def forward(self, x):
        return self.net(x)


class ByolLanguegeModel(AlbertPreTrainedModel):
    """
    A PyTorch nn.Module that applies the BYOL approach to transformer-based models.
    The model outputs the masked-token embeddings after passing them through an MLP head,
    plus the logits of the model predictions and the hidden states.
    """

    def __init__(self, config):
        super().__init__(config)
        self.albert = AlbertModel(config)
        self.cls = AlbertMLMHead(config)
        self.config = config
        self.mlp = MLPHead(
            in_channels=config.hidden_size,
            mlp_hidden_size=config.hidden_size * 10,
            projection_size=config.hidden_size,
        )
        self.init_weights()
    def get_output_embeddings(self):
        return self.cls.decoder

    def batched_index_select(self, input, dim, index):
        # Gather entries of `input` along `dim` using a per-example `index`,
        # broadcasting the index over every other dimension.
        for ii in range(1, len(input.shape)):
            if ii != dim:
                index = index.unsqueeze(ii)
        expanse = list(input.shape)
        expanse[0] = -1
        expanse[dim] = -1
        index = index.expand(expanse)
        return torch.gather(input, dim, index)
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        masked_index=None,
    ):
        outputs = self.albert(
            input_ids=input_ids,
            return_dict=True,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            output_hidden_states=True,
        )
        sequence_outputs = outputs[0]
        prediction_scores = self.cls(sequence_outputs)
        # outputs.hidden_states[0] is the embedding-layer output of the encoder,
        # shaped (batch_size, seq_length, hidden_size=768);
        # masked_embeddings ends up shaped (batch_size, 768).
        masked_embeddings = outputs.hidden_states[0]
        masked_index = masked_index.unsqueeze(1)
        # Select the embedding at the masked position of each sequence in the batch.
        masked_embeddings = torch.cat(
            [
                torch.index_select(a, 0, i).unsqueeze(0)
                for a, i in zip(masked_embeddings, masked_index)
            ]
        )
        masked_embeddings = self.mlp(masked_embeddings.squeeze())
        return (
            prediction_scores,
            masked_embeddings,
            outputs.hidden_states,
            outputs.attentions,
        )
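

# --- Usage sketch (not part of the original file) ----------------------------
# A minimal smoke test showing how ByolLanguegeModel might be driven.
# The small AlbertConfig values below are placeholders chosen for speed; the
# repo's actual hyperparameters presumably come from config_lm, which is not
# shown here. Shapes in the comments follow the forward() implementation above.
if __name__ == "__main__":
    from transformers import AlbertConfig

    config = AlbertConfig(
        vocab_size=30000,
        embedding_size=128,
        hidden_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=512,
    )
    model = ByolLanguegeModel(config)
    model.eval()  # keep the MLP head's BatchNorm1d on running statistics

    batch_size, seq_len = 2, 16
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
    # One masked position per sequence (illustrative indices).
    masked_index = torch.tensor([3, 7])

    with torch.no_grad():
        prediction_scores, masked_embeddings, hidden_states, attentions = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            masked_index=masked_index,
        )

    print(prediction_scores.shape)  # (batch_size, seq_len, vocab_size)
    print(masked_embeddings.shape)  # (batch_size, hidden_size)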