distributions.py (forked from ikostrikov/pytorch-a2c-ppo-acktr-gail)
# NOTE: this file uses the pre-0.4 PyTorch API (Variable, volatile).
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from utils import AddBias


class Categorical(nn.Module):
    """Categorical (softmax) action head for discrete action spaces."""

    def __init__(self, num_inputs, num_outputs):
        super(Categorical, self).__init__()
        self.linear = nn.Linear(num_inputs, num_outputs)

    def forward(self, x):
        x = self.linear(x)
        return x

    def sample(self, x, deterministic):
        x = self(x)

        probs = F.softmax(x)
        if deterministic is False:
            # Sample one action id per row from the softmax distribution.
            action = probs.multinomial()
        else:
            # Greedy: pick the highest-probability action.
            action = probs.max(1, keepdim=True)[1]
        return action

    def logprobs_and_entropy(self, x, actions):
        x = self(x)

        log_probs = F.log_softmax(x)
        probs = F.softmax(x)

        # Log-probability of the taken actions, shape (batch, 1).
        action_log_probs = log_probs.gather(1, actions)

        # Mean entropy of the categorical distribution over the batch.
        dist_entropy = -(log_probs * probs).sum(-1).mean()
        return action_log_probs, dist_entropy
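
# Usage sketch for Categorical -- illustrative only, not part of the original
# file; the feature size (64), action count (6), and batch size (8) are
# hypothetical, and it assumes the same pre-0.4 PyTorch API used above:
#
#   dist = Categorical(num_inputs=64, num_outputs=6)
#   x = Variable(torch.randn(8, 64))                  # batch of feature vectors
#   action = dist.sample(x, deterministic=False)      # LongTensor, shape (8, 1)
#   logp, ent = dist.logprobs_and_entropy(x, action)  # logp: (8, 1); ent: scalar
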
class DiagGaussian(nn.Module):
    """Diagonal Gaussian action head for continuous action spaces.

    The mean is state-dependent; the log-std is a state-independent
    learned parameter, stored as a bias via AddBias.
    """

    def __init__(self, num_inputs, num_outputs):
        super(DiagGaussian, self).__init__()
        self.fc_mean = nn.Linear(num_inputs, num_outputs)
        self.logstd = AddBias(torch.zeros(num_outputs))

    def forward(self, x):
        action_mean = self.fc_mean(x)

        # An ugly hack for my KFAC implementation: feed zeros through
        # AddBias so the log-std bias appears as a module output.
        zeros = Variable(torch.zeros(action_mean.size()), volatile=x.volatile)
        if x.is_cuda:
            zeros = zeros.cuda()
        action_logstd = self.logstd(zeros)
        return action_mean, action_logstd

    def sample(self, x, deterministic):
        action_mean, action_logstd = self(x)

        action_std = action_logstd.exp()

        if deterministic is False:
            # Reparameterized sample: mean + std * N(0, I).
            noise = Variable(torch.randn(action_std.size()))
            if action_std.is_cuda:
                noise = noise.cuda()
            action = action_mean + action_std * noise
        else:
            action = action_mean
        return action

    def logprobs_and_entropy(self, x, actions):
        action_mean, action_logstd = self(x)

        action_std = action_logstd.exp()

        # Per-dimension Gaussian log-density, summed over action dims:
        # log N(a | mu, sigma) = -0.5*((a - mu)/sigma)^2 - 0.5*log(2*pi) - log(sigma)
        action_log_probs = -0.5 * ((actions - action_mean) / action_std).pow(2) \
            - 0.5 * math.log(2 * math.pi) - action_logstd
        action_log_probs = action_log_probs.sum(-1, keepdim=True)

        # Closed-form entropy of a diagonal Gaussian, per dimension:
        # H = 0.5 + 0.5*log(2*pi) + log(sigma).
        dist_entropy = 0.5 + 0.5 * math.log(2 * math.pi) + action_logstd
        dist_entropy = dist_entropy.sum(-1).mean()
        return action_log_probs, dist_entropy
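
# Usage sketch for DiagGaussian -- likewise illustrative, not part of the
# original file; sizes are hypothetical:
#
#   dist = DiagGaussian(num_inputs=64, num_outputs=2)
#   x = Variable(torch.randn(8, 64))
#   action = dist.sample(x, deterministic=False)      # shape (8, 2), continuous
#   logp, ent = dist.logprobs_and_entropy(x, action)  # logp: (8, 1); ent: scalar
#
# The entropy term is the closed form for a univariate Gaussian,
# H = 0.5 + 0.5*log(2*pi) + log(sigma), summed over action dimensions and
# averaged over the batch.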