-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathactor_critic.py
38 lines (35 loc) · 1.33 KB
/
actor_critic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# @Filename: actor_critic.py
# @Author: Ashutosh Tiwari
# @Email: [email protected]
# @Time: 4/12/22 6:03 AM
from typing import Callable, Dict, Any, Optional, Type, Tuple
import gym
from torch import nn
from stable_baselines3.common.policies import ActorCriticCnnPolicy
class A2CCNNPolicy(ActorCriticCnnPolicy):
    """Actor-critic CNN policy with a pluggable actor/critic network and features extractor.

    The base policy's MLP extractor is replaced by an instance of
    ``actor_critic_class`` (see ``_build_mlp_extractor``), and the features
    extractor is an instance of ``features_extractor_class`` built by the base
    class. Orthogonal weight initialization is disabled by default.

    :param observation_space: Observation space, forwarded to the base policy.
    :param action_space: Action space, forwarded to the base policy.
    :param lr_schedule: Learning-rate schedule, forwarded to the base policy.
    :param actor_critic_class: Module class instantiated as
        ``actor_critic_class(observation_space)`` to become ``self.mlp_extractor``.
    :param features_extractor_class: Features-extractor class, instantiated by
        the base policy as ``features_extractor_class(observation_space, **features_extractor_kwargs)``.
    :param features_extractor_kwargs: Extra keyword arguments for
        ``features_extractor_class``; ``None`` means no extra arguments.
    :param args: Extra positional arguments forwarded to ``ActorCriticCnnPolicy``.
    :param kwargs: Extra keyword arguments forwarded to ``ActorCriticCnnPolicy``;
        ``ortho_init`` defaults to ``False`` unless given explicitly.
    """

    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable[[float], float],
        actor_critic_class: Type[nn.Module],
        features_extractor_class: Type[nn.Module],
        features_extractor_kwargs: Optional[Dict[str, Any]] = None,
        *args,
        **kwargs,
    ):
        # Must be set before the base constructor runs: the base constructor
        # calls _build_mlp_extractor(), which reads this attribute.
        self.actor_critic_layer = actor_critic_class
        # Disable orthogonal initialization unless the caller overrides it.
        # This must reach the base constructor, which applies (or skips)
        # orthogonal init while building the networks; assigning
        # self.ortho_init after super().__init__() would have no effect.
        kwargs.setdefault("ortho_init", False)
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            # Pass remaining positional arguments to the base class.
            *args,
            # Let the base class build the features extractor so it exists
            # before the heads and the optimizer are created; otherwise its
            # parameters would never be registered with the optimizer.
            features_extractor_class=features_extractor_class,
            # Copy so later mutation by the base class never leaks back to the caller.
            features_extractor_kwargs=dict(features_extractor_kwargs or {}),
            **kwargs,
        )

    def _build_mlp_extractor(self) -> None:
        """Build ``self.mlp_extractor`` from the user-supplied actor/critic class.

        Overrides the base-class hook. The class is instantiated with the
        observation space as its only argument.
        """
        # NOTE(review): the base policy presumably reads latent_dim_pi /
        # latent_dim_vf from this module when building its heads — confirm
        # actor_critic_class provides them.
        self.mlp_extractor = self.actor_critic_layer(self.observation_space)