SentiGAN_G.py
# -*- coding: utf-8 -*-
# @Author : William
# @Project : TextGAN-william
# @FileName : SentiGAN_G.py
# @Time : Created at 2019-07-26
# @Blog : http://zhiweil.ml/
# @Description :
# Copyright (C) 2018. All Rights Reserved.
import torch
import torch.nn.functional as F

from models.generator import LSTMGenerator


class SentiGAN_G(LSTMGenerator):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu=False):
        super(SentiGAN_G, self).__init__(embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, gpu)
        self.name = 'sentigan'
    def forward(self, inp, hidden, need_hidden=False, use_log=True):
        """
        Embeds input and applies LSTM.

        :param inp: batch_size * seq_len
        :param hidden: (h, c)
        :param need_hidden: if True, also return the hidden state (used for sampling)
        :param use_log: if True, return log-probabilities (log_softmax); otherwise return probabilities (softmax)
        """
        emb = self.embeddings(inp)  # batch_size * seq_len * embedding_dim

        if len(inp.size()) == 1:
            emb = emb.unsqueeze(1)  # batch_size * 1 * embedding_dim

        out, hidden = self.lstm(emb, hidden)  # out: batch_size * seq_len * hidden_dim
        out = out.contiguous().view(-1, self.hidden_dim)  # out: (batch_size * seq_len) * hidden_dim
        out = self.lstm2out(out)  # (batch_size * seq_len) * vocab_size
        # out = self.temperature * out  # temperature

        if use_log:
            pred = F.log_softmax(out, dim=-1)
        else:
            pred = F.softmax(out, dim=-1)

        if need_hidden:
            return pred, hidden
        else:
            return pred
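    # Note: unlike SeqGAN-style policy-gradient training, which works on log-probabilities,
    # batchPGLoss below calls forward() with use_log=False and weights the raw token
    # probabilities by (1 - reward), in line with SentiGAN's penalty-based objective.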
    def batchPGLoss(self, inp, target, reward):
        """
        Returns a policy gradient loss.

        :param inp: batch_size x seq_len, inp should be target with <s> (start letter) prepended
        :param target: batch_size x seq_len
        :param reward: discriminator reward for each sentence, applied to each token of the corresponding sentence
        :return loss: policy loss
        """
        batch_size, seq_len = inp.size()
        hidden = self.init_hidden(batch_size)

        out = self.forward(inp, hidden, use_log=False).view(batch_size, self.max_seq_len, self.vocab_size)
        target_onehot = F.one_hot(target, self.vocab_size).float()  # batch_size * seq_len * vocab_size
        pred = torch.sum(out * target_onehot, dim=-1)  # probability of each target token: batch_size * seq_len
        loss = -torch.sum(pred * (1 - reward))

        return loss
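

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how the
# generator's policy-gradient loss might be computed. It assumes this file is
# run inside the TextGAN-william repo (so models.generator.LSTMGenerator
# resolves). The dimensions, the start-token index 0, and the per-token
# `reward` tensor below are illustrative assumptions; in actual training the
# reward comes from discriminator rollouts.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    vocab_size, max_seq_len, padding_idx = 5000, 20, 0
    gen = SentiGAN_G(embedding_dim=32, hidden_dim=32, vocab_size=vocab_size,
                     max_seq_len=max_seq_len, padding_idx=padding_idx, gpu=False)

    batch_size = 4
    target = torch.randint(0, vocab_size, (batch_size, max_seq_len))  # dummy target sequences
    start = torch.zeros(batch_size, 1, dtype=torch.long)              # assumed <s> token index 0
    inp = torch.cat([start, target[:, :-1]], dim=1)                   # target shifted right by one

    reward = torch.rand(batch_size, max_seq_len)  # assumed per-token rewards in [0, 1]

    loss = gen.batchPGLoss(inp, target, reward)
    loss.backward()
    print('PG loss:', loss.item())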