Following the basic sketch in composite_actor.py, I tried to implement a simple but complete example of a probabilistic actor with a composite distribution. The application is a custom environment representing a simple 2D navigation task.
Unfortunately, despite the task being rather simple, training the actor does not give satisfying results.
When evaluating the trained actor, it is generally not able to make the moves needed to reach a target 2D point from a given position.
Before giving up, I want to ask this community: does anyone have an idea why the training does not succeed?
Remark: I did hyperparameter optimisation with Ray Tune, but still without success.
Is there a problem with the implementation?
My code (see also TorchRL experiments):
Environment
import torch
from tensordict import TensorDict
from torchrl.data import UnboundedContinuous, Bounded, Composite, Categorical, OneHot
from torchrl.envs import Compose, StepCounter, TransformedEnv
from torchrl.envs import EnvBase
from torchrl.envs.utils import check_env_specs
class ToyNavigation(EnvBase):
def __init__(self, max_step_width=1., grid_extent=5, device="cuda"):
self.device = device
super().__init__(device=device, batch_size=[])
self.max_step_width = max_step_width
self.grid_extent = grid_extent
self.target = torch.zeros(2, device=self.device)
self.position = torch.zeros(2, device=self.device)
self.initial_position = torch.zeros(2, device=self.device)
self.steps = 0
self._make_spec()
def _reset(self, tensordict=None, **kwargs):
"""
Constructs and returns a TensorDict containing the current observation.
The target and position values are randomly generated within the range of
[-grid_extent, grid_extent].
"""
self.target = torch.rand(2, device=self.device) * 2 * self.grid_extent - self.grid_extent  # random target in [-grid_extent, grid_extent]
self.position = torch.rand(2, device=self.device) * 2 * self.grid_extent - self.grid_extent  # random position in [-grid_extent, grid_extent]
self.initial_position = self.position
self.steps = 0
position = self.position.clone() # do we have to clone to get distinct positions in the trajectory?
target = self.target.clone()
td = TensorDict({
"observation": {"position": position, "target": target},
}, batch_size=self.batch_size)
return td
def _step(self, tensordict):
"""
Perform a single step in the environment.
This function takes an input `tensordict` containing the action details and
updates the agent's position accordingly. The reward is calculated based
on the distance to the target, and whether the task is done is determined
by whether the agent is close enough to the target.
:param tensordict: A dictionary containing the action information.
It should have "axis" and "magnitude" keys to
specify the movement direction and (normalized) step size. The step_size is multiplied with max_step_width
:type tensordict: TensorDict
:return: A dictionary with updated observation including current position
and target, computed reward, and done flag.
:rtype: TensorDict
"""
self.steps += 1
direction = tensordict["action"]["axis"] # 0 for x, 1 for y
magnitude = tensordict["action"]["magnitude"]
move = torch.zeros(2, device=self.device)
move[direction] = magnitude * self.max_step_width
self.position += move
distance = torch.norm(self.position - self.target)
reward = -distance * 10 - self.steps
done = distance < 0.1
# Remark: gets truncated when steps > max_steps, see step counter transform
position = self.position.clone() # do we have to clone to get distinct positions in the trajectory?
target = self.target.clone()
next = TensorDict({
"observation": {"position": position, "target": target},
"reward": reward,
"done": done,
}, tensordict.shape)
return next
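# For illustration (not part of the original sketch): _step consumes actions of the form
# TensorDict({"action": {"axis": <axis selector>, "magnitude": <float in [-1, 1]>}}, [])
# and moves the agent by magnitude * max_step_width along the selected axis.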
def _set_seed(self, seed):
torch.manual_seed(seed)
def _make_spec(self, **kwargs):
self.observation_spec = Composite(observation=Composite(
position=UnboundedContinuous(shape=torch.Size([2]), device=self.device),
target=UnboundedContinuous(shape=torch.Size([2]), device=self.device),
shape=torch.Size([])
), shape=torch.Size([]))
self.action_spec = action_spec()
self.reward_spec = UnboundedContinuous(1)
self.done_spec = Categorical(n=2, shape=torch.Size([1]), dtype=torch.bool)
self.terminated_spec = Categorical(n=2, shape=torch.Size([1]), dtype=torch.bool)
def action_spec():
return Composite(
action=Composite(
magnitude=Bounded(low=-1., high=1., shape=torch.Size([]), dtype=torch.float), # normalized width of step
axis=OneHot(n=2, dtype=torch.int), # x or y axis
shape=torch.Size([])
), shape=torch.Size([]))
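# For reference, a random draw from this spec is a nested TensorDict, e.g. (sketch):
#   sample = action_spec().rand()
#   sample["action", "axis"]       # one-hot vector of length 2 (OneHot spec)
#   sample["action", "magnitude"]  # scalar in [-1, 1] (Bounded spec)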
def make_toy_env(max_step_width=1, grid_extent=5, max_steps=10, device='cuda'):
"""
This function initializes a ToyNavigation environment with the specified maximum step width,
grid extent, device, and maximum steps. It then applies the StepCounter transformation to the environment
which introduces step_count in the tensordict and limits the number of steps to `max_steps`.
"""
env = ToyNavigation(max_step_width=max_step_width, grid_extent=grid_extent, device=device)
env = TransformedEnv(env, Compose(
#ObservationNorm(in_keys=[("observation", "position"), ("observation", "target")]),
StepCounter(max_steps=max_steps), ))
# env.transform[0].init_stats(key=("observation", "position"), num_iter=2000)
return env
def custom_tensor_print(tensor):
# Remove tensor brackets and device info, then wrap in parentheses
tensor_str = str(tensor.cpu().numpy()).strip('[]')
return f"({tensor_str})"
if __name__ == '__main__':
env = make_toy_env()
check_env_specs(env, check_dtype=True)
eval_rollout = env.rollout(max_steps=5)
print(eval_rollout)
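As an extra sanity check of the reward shaping, the raw environment can also be stepped manually with a random action (a minimal sketch, assuming a recent TorchRL API where EnvBase.rand_step draws the action from the spec):
base_env = ToyNavigation(device="cpu")
td = base_env.reset()
td = base_env.rand_step(td)  # writes a random composite action and performs one step
print(td["next", "reward"], td["next", "observation", "position"])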
Modules:
import torch
import torch.nn.functional as F
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, CompositeDistribution, InteractionType
from torch import nn
from torch.distributions import Categorical, Normal
from torchrl.modules import ProbabilisticActor, ValueOperator
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE
from toy_navigation.env import action_spec
def normalize(tensor):
"""
Normalize the input tensor to a range of [-1, 1].
"""
min_val = tensor.min(dim=-1, keepdim=True)[0]
max_val = tensor.max(dim=-1, keepdim=True)[0]
return 2 * (tensor - min_val) / (max_val - min_val + 1e-8) - 1
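# Note (illustration only): the rescaling is relative to the tensor's own min/max along the
# last dimension, independent of grid_extent, e.g.
#   normalize(torch.tensor([3., -1.]))  ->  approx. tensor([ 1., -1.])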
# actor
class Policy(nn.Module):
def __init__(self, n_hidden_layers=2, hidden_features=5, device='cuda'):
super().__init__()
self.device = device
self.hidden_layers = nn.ModuleList()
for i in range(n_hidden_layers):
in_channels = 4 if i == 0 else hidden_features
hidden_layer = nn.Linear(in_features=in_channels, out_features=hidden_features)
nn.init.xavier_uniform_(hidden_layer.weight)
self.hidden_layers.append(hidden_layer)
self.axis_output = nn.Linear(in_features=hidden_features, out_features=2)
nn.init.xavier_uniform_(self.axis_output.weight)
self.magnitude_output = nn.Linear(in_features=hidden_features, out_features=2)
nn.init.xavier_uniform_(self.magnitude_output.weight)
self.to(device)
def forward(self, position, target):
"""
Forward pass for the actor network. This method concatenates the `position` and `target` tensors,
processes them through a series of hidden layers, and computes the final output, which includes
axis logits for the direction of the movement and magnitude for the (normalized) step_width.
"""
# Normalize position and target
position = normalize(position)
target = normalize(target)
batched = position.ndim > 1 # called with batched params?
# Concatenate position and target
tensor = torch.cat([position, target], dim=-1).to(self.device)
# Process through hidden layers
for layer in self.hidden_layers:
tensor = F.leaky_relu(layer(tensor))
# Separate loc and scale from magnitude output
axis_logits = self.axis_output(tensor)
magnitude = self.magnitude_output(tensor)
if batched:
loc = magnitude[...,0]
scale = magnitude[...,1]
else:
loc = magnitude[0]
scale = magnitude[1]
loc = F.tanh(loc)
scale = torch.clamp(F.softplus(scale) + 1e-5,min=1e-5,max=2.)
result = TensorDict({
'params': {'axis': {'logits': axis_logits},
'magnitude': {'loc': loc,
'scale': scale}}
})
return result
def get_actor(n_layers=3, hidden_features=10, device='cuda'):
"""
Create an actor module composed of a Policy neural network, a TensorDictModule that maps the input
tensordict entries to the distribution parameters, and the ProbabilisticActor configuration.
"""
actor_net = Policy(n_layers, hidden_features, device)
td_module = TensorDictModule(
actor_net,
in_keys={("observation", "position"): 'position', ("observation", "target"): 'target'},
out_keys = [('params', 'axis', 'logits'), ('params', 'magnitude','loc'), ('params', 'magnitude','scale')]
)
policy_module = ProbabilisticActor(
module=td_module,
in_keys=["params"],
distribution_class=CompositeDistribution,
spec = action_spec(),
distribution_kwargs={
"aggregate_probabilities": True,
"distribution_map": {
"axis": Categorical,
"magnitude": Normal,
},
"name_map": {
"axis": ("action", "axis"), # see action_spec and nested structure of action
"magnitude": ("action", "magnitude"),
},
},
return_log_prob=True,
default_interaction_type=InteractionType.MODE  # see the note on runtime behaviour with DETERMINISTIC
)
policy_module.td_module = td_module
policy_module.actor_net = actor_net
return policy_module
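# Usage sketch (illustrative, not from the original experiments): calling the actor on a reset
# tensordict adds the sampled ("action", "axis") and ("action", "magnitude") entries plus the
# aggregated log-probability, e.g.
#   actor = get_actor(device="cpu")
#   td = actor(make_toy_env(device="cpu").reset())  # make_toy_env would need to be imported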
class ValueNetwork(nn.Module):
def __init__(self, n_hidden_layers=1, hidden_features=2, device='cuda'):
super().__init__()
self.device = device
self.hidden_layers = nn.ModuleList()
for i in range(n_hidden_layers):
in_channels = 4 if i == 0 else hidden_features
hidden_layer = nn.Linear(in_features=in_channels, out_features=hidden_features)
nn.init.xavier_uniform_(hidden_layer.weight)
self.hidden_layers.append(hidden_layer)
self.value_output = nn.Linear(in_features=hidden_features, out_features=1)
nn.init.xavier_uniform_(self.value_output.weight)
self.to(device)
def forward(self, position, target):
# Normalize position and target
position = normalize(position)
target = normalize(target)
tensor = torch.cat([position, target], dim=-1).to(self.device)
for layer in self.hidden_layers:
tensor = F.leaky_relu(layer(tensor))
value = F.tanh(self.value_output(tensor))
return value
# value net
def get_critic( n_layers=3, hidden_features=10,device='cuda'):
"""
Creates a ValueOperator critic by initializing the ValueNetwork with specified
number of layers, and hidden features.
The critic is responsible for evaluating the value of given states in the environment using a neural
network configured with the provided parameters.
"""
critic_net = ValueNetwork(n_hidden_layers=n_layers, hidden_features=hidden_features, device=device)
critic = ValueOperator(
module=critic_net,
in_keys={("observation", "position"): 'position', ("observation", "target"): 'target'},
)
return critic
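# Usage sketch (illustrative): the ValueOperator writes its estimate under the default
# "state_value" key, e.g.
#   critic = get_critic(device="cpu")
#   critic(make_toy_env(device="cpu").reset())["state_value"]  # shape [1]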
def get_advantage(critic, gamma=0.9995, lmbda=0.95, average_gae=True):
"""
Computes the generalized advantage estimation (GAE) for a given critic network.
This function initializes the GAE module with the specified parameters, which are used to
calculate the advantage function for the reinforcement learning scenario. Generalized advantage
estimation helps to reduce variance while maintaining an admissible level of bias.
"""
module = GAE(
gamma=gamma, lmbda=lmbda, value_network=critic, average_gae=average_gae
)
return module
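# GAE in brief: advantage_t = sum_l (gamma * lmbda)**l * delta_{t+l}, with
# delta_t = r_t + gamma * V(s_{t+1}) - V(s_t). Besides "advantage", the module also writes
# the "value_target" entry consumed by the critic loss.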
def get_loss_module(actor, critic, clip_epsilon=0.2, entropy_eps=1e-4, normalize_advantage=False, clip_value=True, separate_losses=False, critic_coef=1.0):
"""
Returns the loss module for the Clip Proximal Policy Optimization (PPO)
algorithm using the given actor and critic networks.
"""
loss_module = ClipPPOLoss(
actor_network=actor,
critic_network=critic,
clip_epsilon=clip_epsilon,
entropy_bonus=bool(entropy_eps),
entropy_coef=entropy_eps,
critic_coef=critic_coef,
loss_critic_type="smooth_l1",
normalize_advantage=normalize_advantage,
reduction="mean",
clip_value=clip_value,
separate_losses=separate_losses,
)
return loss_module
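Before launching a full training run, the pieces can be wired together on a single rollout to check that the keys line up (a minimal sketch assuming the helpers above and a CPU device; not part of the original experiments):
from toy_navigation.env import make_toy_env
env = make_toy_env(device="cpu")
actor = get_actor(device="cpu")
critic = get_critic(device="cpu")
advantage = get_advantage(critic)
loss = get_loss_module(actor, critic)
data = env.rollout(10, actor)   # one trajectory of at most 10 steps
advantage(data)                 # adds "advantage" and "value_target"
print(loss(data.reshape(-1)))   # loss_objective / loss_critic / loss_entropy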
Training
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
from torchrl.collectors import SyncDataCollector
from torchrl.envs.utils import ExplorationType, set_exploration_type
from tqdm import tqdm
from toy_navigation.env import make_toy_env
from toy_navigation.modules import get_actor, get_critic, get_advantage, get_loss_module
def do_train(silent=True, device=None, **kwargs):
"""
Trains a reinforcement learning model for the custom "Toy Navigation" Environment using given or default configuration parameters.
This function initializes the necessary training modules, updates configuration settings, and
executes the training loop for a specified number of epochs and frames. The function either
reports metrics to Ray Tune (silent mode) or produces visual logs for monitoring
the training progress.
:param silent: If True, the function suppresses detailed output and reports metrics to Ray Tune.
If False, the function shows progress and logging details using `tqdm`.
:param device: device for training
:param kwargs: Additional keyword arguments to override the default configuration parameters.
:return: None
"""
# Default configuration parameters
default_config = {
'frames_per_batch': 4500, 'total_frames': 76500,
'learn_rate': 2e-5,'lr_scheduling': True, 'num_epochs': 500,
'actor_n_layers': 2, 'actor_hidden_features': 40,
'critic_n_layers': 2,'critic_hidden_features': 40,
'split_trajs': False,
'gamma': 0.85,'lmbda': 0.96,'average_gae': True,
'clip_epsilon' : 0.2, 'entropy_eps' :3e-4, 'normalize_advantage' : True, 'clip_value' : True,
'separate_losses' : True, 'critic_coef' : 0.75
}
# Update default config with any provided kwargs
config = {**default_config, **kwargs}
frames_per_batch = config['frames_per_batch']
learn_rate = config['learn_rate']
num_epochs = config['num_epochs']
total_frames = config['total_frames']
lr_scheduling = config['lr_scheduling']
if silent:
from ray import train
if not device:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Toy Navigation parameters
grid_extent = 5
max_step_width = 1
max_steps = 30
# instantiate modules
actor = get_actor(device=device, n_layers=config['actor_n_layers'],
hidden_features=config['actor_hidden_features'])
critic = get_critic(device=device, n_layers=config['critic_n_layers'],
hidden_features=config['critic_hidden_features'])
advantage_module = get_advantage(critic, gamma=config['gamma'], lmbda=config['lmbda'],
average_gae=config['average_gae'])
loss_module = get_loss_module(actor, critic, clip_epsilon=config['clip_epsilon'],
entropy_eps=config['entropy_eps'], normalize_advantage=config['normalize_advantage'],
clip_value=config['clip_value'], separate_losses=config['separate_losses'],
critic_coef=config['critic_coef'])
# instantiate optimiser and scheduler
optim = torch.optim.Adam(loss_module.parameters(), learn_rate)
if lr_scheduling:
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optim, total_frames // frames_per_batch, 0.0)
# define collector
collector = SyncDataCollector(
lambda: make_toy_env(max_step_width=max_step_width, grid_extent=grid_extent, max_steps=max_steps),
device=device,
policy=actor,
frames_per_batch=frames_per_batch,
total_frames=total_frames,
split_trajs=config['split_trajs'],
exploration_type=ExplorationType.RANDOM,
set_truncated=True,
)
if not silent:
env = make_toy_env(max_step_width=max_step_width, grid_extent=grid_extent, max_steps=max_steps)
logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""
for batch_idx, data in enumerate(collector):
for _ in range(num_epochs):
advantage_module(data) # do we need torch_no_grad context?
data_view = data.reshape(-1)  # dimension: frames per batch
loss_vals = loss_module(data_view)
loss_value = (
loss_vals["loss_objective"]
+ loss_vals["loss_critic"]
+ loss_vals["loss_entropy"]
)
loss_value = loss_value.mean()
if silent:
# Report metrics to Tune
reported_loss = loss_value.detach().cpu().item()
train.report({'loss': reported_loss})
loss_value.backward()
torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_norm=1)
optim.step()
optim.zero_grad()
if lr_scheduling:
scheduler.step()
if not silent:
logs["reward"].append(data["next", "reward"].mean().item())
pbar.update(data.numel())
cum_reward_str = f"avg. reward={logs['reward'][-1]: 4.2f} (init={logs['reward'][0]: 4.2f})"
logs["learn_rate"].append(optim.param_groups[0]["lr"])
lr_str = f"learn_rate policy: {logs['learn_rate'][-1]: 4.6f}"
if batch_idx % 10 == 0:
# We evaluate the policy once every 10 batches of data.
# Evaluation is rather simple: execute the policy without exploration
# (take the mode of the action distribution) for a given
# number of steps (max_steps, which is our ``env`` horizon).
# The ``rollout`` method of the ``env`` can take a policy as argument:
# it will then execute this policy at each step.
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
# execute a rollout with the trained policy
eval_rollout = env.rollout(max_steps, actor)
logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
eval_str = (
f"eval avg. reward: {logs['eval reward'][-1]: 4.2f} "
f"(init: {logs['eval reward'][0]: 4.2f}) "
)
del eval_rollout
pbar.set_description(", ".join([eval_str, cum_reward_str, lr_str]))
collector.shutdown()
if not silent:
pbar.close()
plt.figure(figsize=(10, 10))
plt.subplot(2, 2, 1)
plt.plot(logs["reward"])
plt.title("training rewards (average)")
plt.subplot(2, 2, 2)
plt.plot(logs["step_count"])
plt.title("Max step count (training)")
plt.subplot(2, 2, 3)
plt.plot(logs["eval reward (sum)"])
plt.title("Reward (test)")
plt.subplot(2, 2, 4)
plt.plot(logs["eval step_count"])
plt.title("Max step count (test)")
plt.show()
return actor
def eval_actor(actor, env, max_steps):
with set_exploration_type(ExplorationType.MODE), torch.no_grad():
for batch_idx in range(3):
print(f"\n Eval Experiment {batch_idx}")
env.reset()
print(f"Initial Pos: {env.initial_position}, Target: {env.target}")
eval_rollout = env.rollout(max_steps, actor)
for j in range(eval_rollout.shape[0]):
target = eval_rollout[j]['observation', 'target']
position = eval_rollout[j]['observation', 'position']
distance = torch.norm(position - target)
print(f"Step {j}: Position: {position} in distance {distance.item()} \n ")
if __name__ == '__main__':
device = 'cuda' if torch.cuda.is_available() else 'cpu'
trained_actor = do_train(silent=False, actor_n_layers=2, actor_hidden_features=20,
frames_per_batch=2000, total_frames=50000)
eval_env = make_toy_env(max_step_width=1., grid_extent=5, max_steps=20)
eval_actor(trained_actor, eval_env, max_steps=20)
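For reference, the Ray Tune search mentioned in the remark above can be wired up roughly as follows (a hypothetical sketch; the actual search space is not shown in this post):
from ray import tune
def _trainable(config):
    # do_train reports the loss to Tune when silent=True
    do_train(silent=True, **config)
tuner = tune.Tuner(
    _trainable,
    param_space={
        "learn_rate": tune.loguniform(1e-5, 1e-3),
        "entropy_eps": tune.loguniform(1e-5, 1e-2),
    },
)
results = tuner.fit()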