-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdeep_value_function.py
executable file
·39 lines (33 loc) · 1.37 KB
/
deep_value_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch.nn as nn
from value_function import ValueFunction
from torch.optim import Adam
from deep_agent import DeepAgent
class DeepValueFunction(ValueFunction, DeepAgent):
    """
    A neural network to represent the Value-function.

    This class uses PyTorch for the neural network framework (https://pytorch.org/).

    The network is a fully connected model with two hidden ReLU layers that
    maps a state vector of size ``state_space`` to a single value estimate.
    Its parameters are trained with the Adam optimiser.
    """

    def __init__(
        self, mdp, state_space: int, hidden_dim: int = 64, alpha: float = 0.001
    ) -> None:
        """
        Build the value network and its optimiser.

        Args:
            mdp: The problem definition; stored on the instance but not used
                by any method of this class.
            state_space: Number of input features of the network, i.e. the
                length of an encoded state vector.
            hidden_dim: Width of each of the two hidden layers.
            alpha: Learning rate handed to the Adam optimiser.
        """
        self.mdp = mdp
        self.state_space = state_space
        self.alpha = alpha

        # Create a sequential neural network to represent the value function:
        # state features -> hidden -> hidden -> one scalar value.
        self.value_network = nn.Sequential(
            nn.Linear(in_features=self.state_space, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=1),
        )
        self.optimiser = Adam(self.value_network.parameters(), lr=self.alpha)

    def update(self, state, delta):
        """
        Take one gradient step that minimises the squared error ``delta ** 2``.

        Args:
            state: Unused; kept so the method signature stays compatible with
                existing callers.
            delta: Tensor holding the error to minimise. Gradients only reach
                the network if ``delta`` was computed from the output of
                ``get_value`` with the autograd graph intact. It may hold a
                single element or a batch of errors; a batch is averaged.
        """
        self.optimiser.zero_grad()
        # backward() needs a scalar loss. Taking the mean is a no-op for a
        # single-element delta and yields the mean squared error for a batch,
        # which would otherwise make backward() raise.
        loss = (delta ** 2).mean()
        loss.backward()  # Back-propagate the loss through the network
        self.optimiser.step()  # Do a gradient descent step with the optimiser

    def get_value(self, state):
        """
        Return the network's value estimate for ``state``.

        The state is first converted with ``self.encode_state`` (not defined in
        this class; presumably inherited from ``DeepAgent`` -- verify) and then
        passed through the network. The raw network output is returned as-is,
        without detaching it or converting it to a Python number.
        """
        # pass through the network to get the value
        encoded_state = self.encode_state(state)
        return self.value_network(encoded_state)