language_model.py (forked from aioz-ai/MICCAI19-MedVQA)
"""
This code is from Jin-Hwa Kim, Jaehyun Jun, and Byoung-Tak Zhang's repository (Bilinear Attention Networks).
https://github.com/jnhwkim/ban-vqa
"""
import torch
import torch.nn as nn
from torch.autograd import Variable
import numpy as np


class WordEmbedding(nn.Module):
"""Word Embedding
    The ntoken-th index is used for padding_idx, which agrees *implicitly*
with the definition in Dictionary.
"""
def __init__(self, ntoken, emb_dim, dropout, op=""):
super(WordEmbedding, self).__init__()
self.op = op
self.emb = nn.Embedding(ntoken + 1, emb_dim, padding_idx=ntoken)
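        # op containing "c": keep a second embedding table, frozen by default,
        # whose output is concatenated with the trainable one in forward()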
if "c" in op:
self.emb_ = nn.Embedding(ntoken + 1, emb_dim, padding_idx=ntoken)
self.emb_.weight.requires_grad = False # fixed
self.dropout = nn.Dropout(dropout)
self.ntoken = ntoken
self.emb_dim = emb_dim

    def init_embedding(self, np_file, tfidf=None, tfidf_weights=None):
weight_init = torch.from_numpy(np.load(np_file))
assert weight_init.shape == (self.ntoken, self.emb_dim)
self.emb.weight.data[: self.ntoken] = weight_init
if tfidf is not None:
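            # tfidf is an (N x N') matrix: any extra vectors in tfidf_weights are
            # appended to the loaded weights first, and the stack is then projected
            # through tfidf (see the shape note below)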
if 0 < tfidf_weights.size:
weight_init = torch.cat(
[weight_init, torch.from_numpy(tfidf_weights)], 0
)
weight_init = tfidf.matmul(weight_init) # (N x N') x (N', F)
self.emb_.weight.requires_grad = True
if "c" in self.op:
self.emb_.weight.data[: self.ntoken] = weight_init.clone()

    def forward(self, x):
emb = self.emb(x)
if "c" in self.op:
emb = torch.cat((emb, self.emb_(x)), 2)
emb = self.dropout(emb)
return emb
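

# Usage sketch (added; not part of the original repository): a minimal example of
# how WordEmbedding is typically constructed and initialized. The vocabulary size,
# the GloVe file path "data/glove6b_init_300d.npy", and op="c" are illustrative
# assumptions rather than values fixed by this file.
#
#   w_emb = WordEmbedding(ntoken=1177, emb_dim=300, dropout=0.0, op="c")
#   w_emb.init_embedding("data/glove6b_init_300d.npy")
#   tokens = torch.zeros(8, 12, dtype=torch.long)  # [batch=8, seq_len=12]
#   emb = w_emb(tokens)                            # [8, 12, 600]; op="c" doubles emb_dim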


class QuestionEmbedding(nn.Module):
    def __init__(self, in_dim, num_hid, nlayers, bidirect, dropout, rnn_type="GRU"):
        """Module for question embedding
        """
        super(QuestionEmbedding, self).__init__()
        assert rnn_type == "LSTM" or rnn_type == "GRU"
        rnn_cls = nn.LSTM if rnn_type == "LSTM" else nn.GRU
self.rnn = rnn_cls(
in_dim,
num_hid,
nlayers,
bidirectional=bidirect,
dropout=dropout,
batch_first=True,
)
self.in_dim = in_dim
self.num_hid = num_hid
self.nlayers = nlayers
self.rnn_type = rnn_type
self.ndirections = 1 + int(bidirect)

    def init_hidden(self, batch):
        # borrow an existing parameter just to get a tensor with the right dtype/device
        weight = next(self.parameters()).data
hid_shape = (
self.nlayers * self.ndirections,
batch,
self.num_hid // self.ndirections,
)
if self.rnn_type == "LSTM":
return (
Variable(weight.new(*hid_shape).zero_()),
Variable(weight.new(*hid_shape).zero_()),
)
else:
return Variable(weight.new(*hid_shape).zero_())

    def forward(self, x):
        # x: [batch, sequence, in_dim]
        batch = x.size(0)
        hidden = self.init_hidden(batch)
        output, hidden = self.rnn(x, hidden)
        if self.ndirections == 1:
            # unidirectional: the output at the last time step
            return output[:, -1]
        # bidirectional: concatenate the forward direction's last step with the
        # backward direction's output at step 0 (its final state)
        forward_ = output[:, -1, : self.num_hid]
        backward = output[:, 0, self.num_hid :]
        return torch.cat((forward_, backward), dim=1)

    def forward_all(self, x):
# x: [batch, sequence, in_dim]
batch = x.size(0)
hidden = self.init_hidden(batch)
output, hidden = self.rnn(x, hidden)
return output
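

# Smoke test (added sketch, not part of the original repository): runs random
# tensors through both modules and checks the output shapes. The sizes below
# (a 1000-word vocabulary, 300-d embeddings, a 1024-d unidirectional GRU) are
# illustrative assumptions, not values required by this file.
if __name__ == "__main__":
    ntoken, emb_dim, num_hid = 1000, 300, 1024
    w_emb = WordEmbedding(ntoken, emb_dim, dropout=0.0, op="")
    q_emb = QuestionEmbedding(emb_dim, num_hid, nlayers=1, bidirect=False,
                              dropout=0.0, rnn_type="GRU")

    tokens = torch.randint(0, ntoken, (4, 12))  # [batch=4, seq_len=12]
    emb = w_emb(tokens)                         # [4, 12, 300]
    q_last = q_emb(emb)                         # [4, 1024], last time step only
    q_all = q_emb.forward_all(emb)              # [4, 12, 1024], every time step
    print(emb.shape, q_last.shape, q_all.shape)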