Skip to content

Commit

Permalink
Fix entropy regularization to be more readable
Browse files Browse the repository at this point in the history
  • Loading branch information
greydanus authored Jan 18, 2019
1 parent f967681 commit 85899d7
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions baby-a3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def cost_func(args, values, logps, actions, rewards):
discounted_r = torch.tensor(discounted_r.copy(), dtype=torch.float32)
value_loss = .5 * (discounted_r - values[:-1,0]).pow(2).sum()

entropy_loss = (-logps * torch.exp(logps)).sum() # encourage lower entropy
return policy_loss + 0.5 * value_loss + 0.01 * entropy_loss
entropy_loss = (-logps * torch.exp(logps)).sum() # entropy definition, for entropy regularization
return policy_loss + 0.5 * value_loss - 0.01 * entropy_loss

def train(shared_model, shared_optimizer, rank, args, info):
env = gym.make(args.env) # make a local (unshared) environment
Expand Down

0 comments on commit 85899d7

Please sign in to comment.