ops.py

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn.functional as F

import utils


def batch_zero_grad(modules):
    # Clear accumulated gradients on every module in the list.
    for module in modules:
        module.zero_grad()


def batch_update_optim(optimizers):
    # Apply one optimizer step for every optimizer in the list.
    for optimizer in optimizers:
        optimizer.step()
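

# Hedged usage sketch (added for illustration; not in the original file).
# Shows how the two batch helpers above bracket a manual backward pass.
# The `modules`, `optimizers`, and `loss` arguments are hypothetical
# placeholders for whatever the surrounding training loop provides.
def _example_joint_step(modules, optimizers, loss):
    # Clear stale gradients, backprop once, then step every optimizer.
    batch_zero_grad(modules)
    loss.backward()
    batch_update_optim(optimizers)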


def free_params(modules):
    # Enable gradients so the module(s) can be trained.
    # Note: the original set requires_grad = False here, which made
    # free_params identical to frozen_params; True matches the name.
    if type(modules) is not list:
        modules = [modules]
    for module in modules:
        for p in module.parameters():
            p.requires_grad = True


def frozen_params(modules):
    # Disable gradients so the module(s) are excluded from training.
    if type(modules) is not list:
        modules = [modules]
    for module in modules:
        for p in module.parameters():
            p.requires_grad = False
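

# Hedged usage sketch (added for illustration; not in the original file).
# Typical GAN pattern for the pair of helpers above: before a critic
# update, thaw the critic and freeze the generator so backward() only
# builds gradients where they are needed. `netD` and `netG` are
# hypothetical module names.
def _example_discriminator_step(netD, netG):
    free_params(netD)
    frozen_params(netG)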


def pretrain_loss(encoded, noise):
    # Match the first two batch moments (mean and covariance) of the
    # encoder output to those of the noise prior.
    n = noise.size(0)
    mean_z = torch.mean(noise, dim=0, keepdim=True)
    mean_e = torch.mean(encoded, dim=0, keepdim=True)
    mean_loss = F.mse_loss(mean_z, mean_e)
    # Unbiased sample covariance, dividing by n - 1; the original
    # divided by a hardcoded 999, which assumed a fixed batch of 1000.
    cov_z = torch.matmul((noise - mean_z).transpose(0, 1), noise - mean_z)
    cov_z /= n - 1
    cov_e = torch.matmul((encoded - mean_e).transpose(0, 1), encoded - mean_e)
    cov_e /= n - 1
    cov_loss = F.mse_loss(cov_z, cov_e)
    return mean_loss, cov_loss
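

# Hedged usage sketch (added for illustration; not in the original file).
# Calls pretrain_loss on random stand-in tensors; the batch size (1000)
# and latent width (64) are assumptions, not values from this repo.
def _example_pretrain_loss():
    noise = torch.randn(1000, 64)     # samples from the prior
    encoded = torch.randn(1000, 64)   # stand-in for encoder output
    mean_loss, cov_loss = pretrain_loss(encoded, noise)
    print(mean_loss.item(), cov_loss.item())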


def grad_penalty_1dim(args, netD, data, fake, device):
    # WGAN-GP gradient penalty for flat (batch, features) inputs.
    # torch.rand draws the interpolation weights uniformly from [0, 1),
    # as in the WGAN-GP paper; the original used torch.randn.
    alpha = torch.rand(args.batch_size, 1, requires_grad=True).to(device)
    alpha = alpha.expand(data.size())
    interpolates = (alpha * data + (1 - alpha) * fake).to(device)
    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    # args.l is the penalty coefficient (lambda in the WGAN-GP paper).
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * args.l
    return gradient_penalty


def grad_penalty_3dim(args, netD, data, fake, device):
    # WGAN-GP gradient penalty for (batch, 3, H, W) image inputs.
    out_size = int(np.sqrt(args.output // 3))
    # Uniform interpolation weights, as in grad_penalty_1dim above.
    alpha = torch.rand(args.batch_size, 1, requires_grad=True).to(device)
    # Integer division: expand() needs an int size, and
    # data.nelement() / args.batch_size is a float under Python 3.
    alpha = alpha.expand(args.batch_size, data.nelement() // args.batch_size)
    alpha = alpha.contiguous().view(args.batch_size, 3, out_size, out_size)
    interpolates = (alpha * data + (1 - alpha) * fake).to(device)
    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                              create_graph=True, retain_graph=True, only_inputs=True)[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * args.l
    return gradient_penalty
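

# Hedged usage sketch (added for illustration; not in the original file).
# Shows how grad_penalty_1dim would be called inside a WGAN-GP critic
# update. The tiny linear critic, the Namespace fields, and the data
# shapes are assumptions chosen to make the example self-contained.
def _example_grad_penalty():
    import argparse
    args = argparse.Namespace(batch_size=8, l=10.0)  # `l` = GP coefficient
    device = torch.device('cpu')
    netD = torch.nn.Linear(16, 1)        # stand-in critic
    real = torch.randn(args.batch_size, 16)
    fake = torch.randn(args.batch_size, 16)
    gp = grad_penalty_1dim(args, netD, real, fake, device)
    print('gradient penalty:', gp.item())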