-
Notifications
You must be signed in to change notification settings - Fork 0
/
layers.py
44 lines (35 loc) · 1.34 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# @Filename: layers.py
# @Author: Ashutosh Tiwari
# @Email: [email protected]
# @Time: 4/18/22 2:17 AM
from constants import *
import torch
import torch.nn as nn
from typing import Callable, Dict, Any, Optional, Type, Tuple
import gym
class ActorCriticLayer(nn.Module):
    """Two independent ``Linear -> ReLU`` heads for an actor-critic policy.

    The policy (actor) head maps a feature vector to ``last_layer_dim_pi``
    units and the value (critic) head maps the same feature vector to
    ``last_layer_dim_vf`` units. Both heads are moved to ``DEVICE`` when the
    module is constructed.

    Args:
        observation_space: Accepted as the first constructor argument but not
            read by this class; the input width comes from ``feature_dim``.
        last_layer_dim_pi: Output width of the policy head.
        last_layer_dim_vf: Output width of the value head.
        feature_dim: Size of the last dimension of the feature tensors fed to
            both heads. Defaults to 512, the width that used to be hard-coded.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Box,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
        feature_dim: int = 512,
    ):
        super().__init__()

        # These two attributes are required by baselines (per the original
        # author's note) -- presumably read by the policy class that wraps
        # this module; verify against the caller.
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # NOTE(review): the input width is deliberately not derived from
        # observation_space.shape; presumably the features come from an
        # upstream extractor with its own output size -- confirm.
        self.pi_network = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi),
            nn.ReLU(),
        ).to(DEVICE)
        self.vf_network = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf),
            nn.ReLU(),
        ).to(DEVICE)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(policy_latent, value_latent)`` computed from ``x``."""
        # Delegate so the combined and per-head entry points cannot drift.
        return self.forward_actor(x), self.forward_critic(x)

    # functions required by baselines
    def forward_actor(self, features: torch.Tensor) -> torch.Tensor:
        """Run only the policy head on ``features``."""
        return self.pi_network(features)

    # functions required by baselines
    def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
        """Run only the value head on ``features``."""
        return self.vf_network(features)