From d5023a987893d8cd8ef85e54ee82415989c895fb Mon Sep 17 00:00:00 2001 From: chamorajg Date: Thu, 12 Sep 2024 13:07:30 -0700 Subject: [PATCH 1/4] update tdmpc into kscale sim module. --- sim/play_tdmpc.py | 196 ++++++++ sim/tdmpc/src/algorithm/__init__.py | 0 sim/tdmpc/src/algorithm/helper.py | 741 ++++++++++++++++++++++++++++ sim/tdmpc/src/algorithm/tdmpc.py | 248 ++++++++++ sim/tdmpc/src/logger.py | 173 +++++++ sim/train_tdmpc.py | 213 ++++++++ 6 files changed, 1571 insertions(+) create mode 100644 sim/play_tdmpc.py create mode 100644 sim/tdmpc/src/algorithm/__init__.py create mode 100644 sim/tdmpc/src/algorithm/helper.py create mode 100644 sim/tdmpc/src/algorithm/tdmpc.py create mode 100644 sim/tdmpc/src/logger.py create mode 100644 sim/train_tdmpc.py diff --git a/sim/play_tdmpc.py b/sim/play_tdmpc.py new file mode 100644 index 00000000..b1939554 --- /dev/null +++ b/sim/play_tdmpc.py @@ -0,0 +1,196 @@ +"""Trains a humanoid to stand up.""" + +import argparse +import isaacgym +import os +import cv2 +import torch +from sim.envs import task_registry +from sim.utils.helpers import get_args +from sim.tdmpc.src import logger +from sim.tdmpc.src.algorithm.helper import Episode, ReplayBuffer +from sim.tdmpc.src.algorithm.tdmpc import TDMPC +from dataclasses import dataclass, field +from isaacgym import gymapi +from typing import List +import time +import numpy as np +from pathlib import Path +import random +from datetime import datetime +torch.backends.cudnn.benchmark = True +__LOGS__ = "logs" + +@dataclass +class TDMPC_DoraConfigs: + seed: int = 42 + task : str = "walk" + exp_name : str = "dora" + device : str = "cuda:0" + num_envs : int = 10 + + lr : float = 1e-3 + modality : str = "state" + enc_dim: int = 512 # 256 + mlp_dim = [512, 256] # [256, 256] + latent_dim: int = 100 + + iterations : int = 12 + num_samples : int = 512 + num_elites : int = 50 + mixture_coef : float = 0.05 + min_std : float = 0.05 + temperature : float = 0.5 + momentum : float = 0.1 + horizon : int = 5 + std_schedule: str = f"linear(0.5, {min_std}, 3000)" + horizon_schedule: str = f"linear(1, {horizon}, 2500)" + + batch_size: int = 1024 + max_buffer_size : int = 1000000 + reward_coef : float = 1 + value_coef : float = 0.5 + consistency_coef : float = 2 + rho : float = 0.5 + kappa : float = 0.1 + per_alpha: float = 0.6 + per_beta : float = 0.4 + grad_clip_norm : float = 10 + seed_steps: int = 750 + update_freq: int = 2 + tau: int = 0.01 + + discount : float = 0.99 + buffer_device : str = "cpu" + train_steps : int = int(1e6) + num_q : int = 3 + + action_repeat : int = 2 + eval_freq: int = 15000 + eval_freq_episode : int = 10 + eval_episodes : int = 1 + + save_model : bool = True + save_video : bool = False + + use_wandb : bool = False + wandb_entity : str = "crajagopalan" + wandb_project : str = "xbot" + + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + +class VideoRecorder: + """Utility class for logging evaluation videos.""" + def __init__(self, root_dir, render_size=384, fps=25): + self.save_dir = (root_dir / 'eval_video') if root_dir else None + self.render_size = render_size + self.fps = fps + self.frames = [] + self.enabled = False + fourcc = cv2.VideoWriter_fourcc(*"MP4V") # type: ignore[attr-defined] + logger.make_dir(self.save_dir) + now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + # Creates a directory to store videos. 
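# Note: the writer below uses a fixed 1920x1080 frame size; record() reshapes the raw camera
# image to (1080, 1920, 4), so the camera sensor created in play() must use the same resolution.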
+ dir = os.path.join(self.save_dir, now + ".mp4") + self.video = cv2.VideoWriter(dir, fourcc, float(fps), (1920, 1080)) + + def init(self, env, h1, enabled=True): + self.frames = [] + self.enabled = self.save_dir and enabled + self.record(env, h1) + + def record(self, env, h1): + env.gym.fetch_results(env.sim, True) + env.gym.step_graphics(env.sim) + env.gym.render_all_camera_sensors(env.sim) + img = env.gym.get_camera_image(env.sim, env.envs[0], h1, gymapi.IMAGE_COLOR) + img = np.reshape(img, (1080, 1920, 4)) + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) + self.video.write(img) + + def save(self,): + self.video.release() + +def evaluate(test_env, agent, h1, step, video, action_repeat=1): + """Evaluate a trained agent and optionally save a video.""" + episode_rewards = [] + obs, privileged_obs = test_env.reset() + critic_obs = privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs + dones, ep_reward, t = torch.tensor([False] * test_env.num_envs), torch.tensor([0.] * test_env.num_envs), 0 + if video: video.init(test_env, h1, enabled=True) + for i in range(int(1000 // action_repeat)): + actions = agent.plan(state, eval_mode=True, step=step, t0=t==0) + for _ in range(action_repeat): + obs, privileged_obs, rewards, dones, infos = test_env.step(actions) + critic_obs = privileged_obs if privileged_obs is not None else obs + ep_reward += rewards.cpu() + t += 1 + if video: video.record(test_env, h1) + state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs + episode_rewards.append(ep_reward) + if video: video.save() + print(f"Timestep : {t} Episode Rewards - {torch.cat(episode_rewards).mean().item()}") + return torch.nanmean(torch.cat(episode_rewards)).item() + +def play(args: argparse.Namespace) -> None: + """Training script for TD-MPC. 
Requires a CUDA-enabled device.""" + assert torch.cuda.is_available() + env_cfg, _ = task_registry.get_cfgs(name=args.task) + env, _ = task_registry.make_env(name=args.task, args=args) + + fp = "/home/guest/sim/logs/2024-09-11_08-08-01_walk_state_dora/models/tdmpc_policy_2350.pt" + config = torch.load(fp)["config"] + tdmpc_cfg = TDMPC_DoraConfigs(**config) + env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat) + + camera_properties = gymapi.CameraProperties() + camera_properties.width = 1920 + camera_properties.height = 1080 + h1 = env.gym.create_camera_sensor(env.envs[0], camera_properties) + camera_offset = gymapi.Vec3(3, -3, 1) + camera_rotation = gymapi.Quat.from_axis_angle(gymapi.Vec3(-0.3, 0.2, 1), np.deg2rad(135)) + actor_handle = env.gym.get_actor_handle(env.envs[0], 0) + body_handle = env.gym.get_actor_rigid_body_handle(env.envs[0], actor_handle, 0) + env.gym.attach_camera_to_body( + h1, env.envs[0], body_handle, gymapi.Transform(camera_offset, camera_rotation), gymapi.FOLLOW_POSITION + ) + + set_seed(tdmpc_cfg.seed) + work_dir = Path().cwd() / __LOGS__ / f"{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}_{str(tdmpc_cfg.seed)}" + + obs, privileged_obs = env.reset() + critic_obs = privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0] + + tdmpc_cfg.obs_shape = [state.shape[0]] + tdmpc_cfg.action_shape = (env.num_actions) + tdmpc_cfg.action_dim = env.num_actions + tdmpc_cfg.episode_length = 100 # int(env.max_episode_length // tdmpc_cfg.action_repeat) + tdmpc_cfg.num_envs = env.num_envs + + L = logger.Logger(work_dir, tdmpc_cfg) + log_dir = logger.make_dir(work_dir) + video = VideoRecorder(log_dir) + + agent = TDMPC(tdmpc_cfg) + + agent.load(fp) + step = 0 + episode_idx, start_time = 0, time.time() + if fp is not None: + episode_idx = int(fp.split(".")[0].split("_")[-1]) + step = episode_idx * tdmpc_cfg.episode_length + + # Log training episode + evaluate(env, agent, h1, step, video, tdmpc_cfg.action_repeat) + print('Testing completed successfully') + +if __name__ == "__main__": + play(get_args()) \ No newline at end of file diff --git a/sim/tdmpc/src/algorithm/__init__.py b/sim/tdmpc/src/algorithm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sim/tdmpc/src/algorithm/helper.py b/sim/tdmpc/src/algorithm/helper.py new file mode 100644 index 00000000..87cc908a --- /dev/null +++ b/sim/tdmpc/src/algorithm/helper.py @@ -0,0 +1,741 @@ +import glob +import os +import pickle +import re +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from numpy.linalg import norm +from scipy.spatial import distance +from torch import distributions as pyd +from torch.distributions.utils import _standard_normal + +__REDUCE__ = lambda b: "mean" if b else "none" + + + +def l1(pred, target, reduce=False): + """Computes the L1-loss between predictions and targets.""" + return F.l1_loss(pred, target, reduction=__REDUCE__(reduce)) + + +def mse(pred, target, reduce=False): + """Computes the MSE loss between predictions and targets.""" + return F.mse_loss(pred, target, reduction=__REDUCE__(reduce)) + + +def bce(pred, target, logits=True, reduce=False): + """Computes the BCE loss between predictions and targets.""" + if logits: + return F.binary_cross_entropy_with_logits( + pred, target, reduction=__REDUCE__(reduce) + ) + return F.binary_cross_entropy(pred, target, reduction=__REDUCE__(reduce)) + + +def l1_quantile(y_true, 
y_pred, quantile=0.3): + """ + Compute the quantile loss. + + Args: + y_true (torch.Tensor): Ground truth values + y_pred (torch.Tensor): Predicted values + quantile (float): Quantile to compute, must be between 0 and 1 + + Returns: + torch.Tensor: Quantile loss + """ + errors = y_true - y_pred + loss = torch.max((quantile - 1) * errors, quantile * errors) + return torch.mean(loss) + + +def threshold_l2_expectile(diff, threshold=1e-2, expectile=0.99, reduce=False): + weight = torch.where(torch.abs(diff) > threshold, expectile, (1 - expectile)) + loss = weight * (diff**2) + reduction = __REDUCE__(reduce) + if reduction == "mean": + return torch.mean(loss) + elif reduction == "sum": + return torch.sum(loss) + return loss + + +def l2_expectile(diff, expectile=0.7, reduce=False): + weight = torch.where(diff > 0, expectile, (1 - expectile)) + loss = weight * (diff**2) + reduction = __REDUCE__(reduce) + if reduction == "mean": + return torch.mean(loss) + elif reduction == "sum": + return torch.sum(loss) + return loss + + +def mse_expectile(pred, target, expectile=0.7, reduce=False): + diff = pred - target + weight = torch.where(diff > 0, expectile, (1 - expectile)) + loss = weight * (diff**2) + reduction = __REDUCE__(reduce) + if reduction == "mean": + return torch.mean(loss) + elif reduction == "sum": + return torch.sum(loss) + return loss + + +def _get_out_shape(in_shape, layers): + """Utility function. Returns the output shape of a network for a given input shape.""" + x = torch.randn(*in_shape).unsqueeze(0) + return ( + (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x) + .squeeze(0) + .shape + ) + + +def gaussian_logprob(eps, log_std): + """Compute Gaussian log probability.""" + residual = (-0.5 * eps.pow(2) - log_std).sum(-1, keepdim=True) + return residual - 0.5 * np.log(2 * np.pi) * eps.size(-1) + + +def squash(mu, pi, log_pi): + """Apply squashing function.""" + mu = torch.tanh(mu) + pi = torch.tanh(pi) + log_pi -= torch.log(F.relu(1 - pi.pow(2)) + 1e-6).sum(-1, keepdim=True) + return mu, pi, log_pi + + +def orthogonal_init(m): + """Orthogonal layer initialization.""" + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + gain = nn.init.calculate_gain("relu") + nn.init.orthogonal_(m.weight.data, gain) + if m.bias is not None: + nn.init.zeros_(m.bias) + + +def ema(m, m_target, tau): + """Update slow-moving average of online network (target network) at rate tau.""" + with torch.no_grad(): + for p, p_target in zip(m.parameters(), m_target.parameters()): + p_target.data.lerp_(p.data, tau) + + +def set_requires_grad(net, value): + """Enable/disable gradients for a given (sub)network.""" + for param in net.parameters(): + param.requires_grad_(value) + +def linear_schedule(schdl, step): + """ + Outputs values following a linear decay schedule. 
+ Adapted from https://github.com/facebookresearch/drqv2 + """ + try: + return float(schdl) + except ValueError: + match = re.match(r"linear\((.+),(.+),(.+)\)", schdl) + if match: + init, final, duration = [float(g) for g in match.groups()] + mix = np.clip(step / duration, 0.0, 1.0) + return (1.0 - mix) * init + mix * final + raise NotImplementedError(schdl) + +class TruncatedNormal(pyd.Normal): + """Utility class implementing the truncated normal distribution.""" + + def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): + super().__init__(loc, scale, validate_args=False) + self.low = low + self.high = high + self.eps = eps + + def _clamp(self, x): + clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) + x = x - x.detach() + clamped_x.detach() + return x + + def sample(self, clip=None, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + eps *= self.scale + if clip is not None: + eps = torch.clamp(eps, -clip, clip) + x = self.loc + eps + return self._clamp(x) + + +class NormalizeImg(nn.Module): + """Normalizes pixel observations to [0,1) range.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.div(255.0) + + +class Flatten(nn.Module): + """Flattens its input to a (batched) vector.""" + + def __init__(self): + super().__init__() + + def forward(self, x): + return x.view(x.size(0), -1) + + +def enc(cfg): + """Returns a TOLD encoder.""" + pixels_enc_layers, state_enc_layers = None, None + if cfg.modality in {"pixels", "all"}: + C = int(3 * cfg.frame_stack) + pixels_enc_layers = [ + NormalizeImg(), + nn.Conv2d(C, cfg.num_channels, 7, stride=2), + nn.ReLU(), + nn.Conv2d(cfg.num_channels, cfg.num_channels, 5, stride=2), + nn.ReLU(), + nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), + nn.ReLU(), + nn.Conv2d(cfg.num_channels, cfg.num_channels, 3, stride=2), + nn.ReLU(), + ] + out_shape = _get_out_shape((C, cfg.img_size, cfg.img_size), pixels_enc_layers) + pixels_enc_layers.extend( + [ + Flatten(), + nn.Linear(np.prod(out_shape), cfg.latent_dim), + nn.LayerNorm(cfg.latent_dim), + nn.Sigmoid(), + ] + ) + if cfg.modality == "pixels": + return ConvExt(nn.Sequential(*pixels_enc_layers)) + if cfg.modality in {"state", "all"}: + state_dim = ( + cfg.obs_shape[0] if cfg.modality == "state" else cfg.obs_shape["state"][0] + ) + state_enc_layers = [ + nn.Linear(state_dim, cfg.enc_dim), + nn.LayerNorm(cfg.enc_dim), + nn.ELU(), + nn.Linear(cfg.enc_dim, cfg.enc_dim), + nn.LayerNorm(cfg.enc_dim), + nn.ELU(), + nn.Linear(cfg.enc_dim, cfg.latent_dim), + nn.LayerNorm(cfg.latent_dim), + nn.Tanh(), + ] + if cfg.modality == "state": + return nn.Sequential(*state_enc_layers) + else: + raise NotImplementedError + + encoders = {} + for k in cfg.obs_shape: + if k == "state": + encoders[k] = nn.Sequential(*state_enc_layers) + elif k.endswith("rgb"): + encoders[k] = ConvExt(nn.Sequential(*pixels_enc_layers)) + else: + raise NotImplementedError + return Multiplexer(nn.ModuleDict(encoders)) + + +def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.ReLU(), layer_norm=True): + """Returns an MLP.""" + if isinstance(mlp_dim, int): + mlp_dim = [mlp_dim, mlp_dim] + layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn] + for i in range(len(mlp_dim) - 1): + layers += [nn.Linear(mlp_dim[i], mlp_dim[i + 1]), nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), act_fn] + layers += [nn.Linear(mlp_dim[-1], out_dim)] 
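# The final projection is left without a trailing activation or normalization; wrappers such as
# dynamics() below append their own LayerNorm and activation on top of this output.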
+ return nn.Sequential(*layers) + + +def dynamics(in_dim, mlp_dim, out_dim, act_fn=nn.Mish()): + """Returns a dynamics network.""" + return nn.Sequential( + mlp(in_dim, mlp_dim, out_dim, act_fn), + nn.LayerNorm(out_dim), + act_fn, + ) + + +def q(in_dim, mlp_dim, act_fn=nn.ReLU(), layer_norm=True): + """Returns a Q-function that uses Layer Normalization.""" + if isinstance(mlp_dim, int): + mlp_dim = [mlp_dim, mlp_dim] + layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn] + for i in range(len(mlp_dim) - 1): + layers += [nn.Linear(mlp_dim[i], mlp_dim[i + 1]), nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), act_fn] + layers += [nn.Linear(mlp_dim[-1], 1)] + return nn.Sequential(*layers) + + +def v(in_dim, mlp_dim, act_fn=nn.ReLU(), layer_norm=True): + """Returns a Q-function that uses Layer Normalization.""" + if isinstance(mlp_dim, int): + mlp_dim = [mlp_dim, mlp_dim] + layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn] + for i in range(len(mlp_dim) - 1): + layers += [nn.Linear(mlp_dim[i], mlp_dim[i + 1]), nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), act_fn] + layers += [[nn.Linear(mlp_dim[-1], 1)]] + return nn.Sequential(*layers) + + +def aug(cfg): + if cfg.modality == "state": + return nn.Identity() + else: + augs = {} + for k in cfg.obs_shape: + if k == "state": + augs[k] = nn.Identity() + else: + raise NotImplementedError + return Multiplexer(nn.ModuleDict(augs)) + + +class ConvExt(nn.Module): + def __init__(self, conv): + super().__init__() + self.conv = conv + + def forward(self, x): + if x.ndim > 4: + batch_shape = x.shape[:-3] + out = self.conv(x.view(-1, *x.shape[-3:])) + out = out.view(*batch_shape, *out.shape[1:]) + else: + out = self.conv(x) + return out + + +class Multiplexer(nn.Module): + + def __init__(self, choices): + super().__init__() + self.choices = choices + + def forward(self, x, key=None): + if isinstance(x, dict): + if key is not None: + return self.choices[key](x) + return {k: self.choices[k](_x) for k, _x in x.items()} + return self.choices(x) + +class Episode(object): + """Storage object for a single episode.""" + + def __init__(self, cfg, init_obs): + self.cfg = cfg + self.device = torch.device(cfg.buffer_device) + self.capacity = int(cfg.max_episode_length // cfg.action_repeat) + if cfg.modality in {"pixels", "state"}: + dtype = torch.float32 if cfg.modality == "state" else torch.uint8 + self.next_obses = torch.empty( + (cfg.num_envs, self.capacity, *init_obs.shape[1:]), + dtype=dtype, + device=self.device, + ) + self.obses = torch.empty( + (cfg.num_envs, self.capacity, *init_obs.shape[1:]), + dtype=dtype, + device=self.device, + ) + self.obses[:, 0] = init_obs.clone().to(self.device, dtype=dtype) + elif cfg.modality == "all": + self.obses = {} + for k, v in init_obs.items(): + assert k in {"rgb", "state"} + dtype = torch.float32 if k == "state" else torch.uint8 + self.next_obses[k] = torch.empty( + (cfg.num_envs, self.capacity, *v.shape[1:]), dtype=dtype, device=self.device + ) + self.obses[k] = torch.empty( + (cfg.num_envs, self.capacity, *v.shape[1:]), dtype=dtype, device=self.device + ) + self.obses[k][:, 0] = v.clone().to(self.device, dtype=dtype) + else: + raise ValueError + self.actions = torch.empty( + (cfg.num_envs, self.capacity, cfg.action_dim), + dtype=torch.float32, + device=self.device, + ) + self.rewards = torch.empty( + (cfg.num_envs, self.capacity,), dtype=torch.float32, device=self.device + ) + 
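# Per-environment done/success flags below mirror the (num_envs, capacity) layout used for rewards.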
self.dones = torch.empty( + (cfg.num_envs, self.capacity,), dtype=torch.bool, device=self.device + ) + self.successes = torch.empty( + (cfg.num_envs, self.capacity,), dtype=torch.bool, device=self.device + ) + self.masks = torch.zeros( + (cfg.num_envs, self.capacity,), dtype=torch.float32, device=self.device + ) + self.cumulative_reward = torch.tensor([0.] * cfg.num_envs) + self.done = torch.tensor([False] * cfg.num_envs) + self.success = torch.tensor([False] * cfg.num_envs) + self._idx = 0 + + def __len__(self): + return self._idx + + @property + def episode_length(self): + num_dones = self.dones[:, :self._idx].sum().item() + if num_dones > 0: + return float(self._idx) * self.cfg.num_envs / num_dones + return float(self._idx) + + @classmethod + def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None): + """Constructs an episode from a trajectory.""" + + if cfg.modality in {"pixels", "state"}: + episode = cls(cfg, obses[0]) + episode.obses[1:] = torch.tensor( + obses[1:], dtype=episode.obses.dtype, device=episode.device + ) + elif cfg.modality == "all": + episode = cls(cfg, {k: v[0] for k, v in obses.items()}) + for k, v in obses.items(): + episode.obses[k][1:] = torch.tensor( + obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device + ) + else: + raise NotImplementedError + episode.actions = torch.tensor( + actions, dtype=episode.actions.dtype, device=episode.device + ) + episode.rewards = torch.tensor( + rewards, dtype=episode.rewards.dtype, device=episode.device + ) + episode.dones = ( + torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device) + if dones is not None + else torch.zeros_like(episode.dones) + ) + episode.masks = ( + torch.tensor(masks, dtype=episode.masks.dtype, device=episode.device) + if masks is not None + else torch.ones_like(episode.masks) + ) + episode.cumulative_reward = torch.sum(episode.rewards) + episode.done = True + episode._idx = cfg.episode_length + return episode + + @property + def first(self): + return len(self) == 0 + + @property + def full(self): + return len(self) == self.capacity + + @property + def buffer_capacity(self): + return self.capacity + + def __add__(self, transition): + self.add(*transition) + return self + + def add(self, obs, action, reward, done, timeouts, success=False): + if isinstance(obs, dict): + for k, v in obs.items(): + if self._idx == self.capacity - 1: + self.next_obses[k][:, self._idx] = v.clone().to(self.obses[k].device, dtype=self.obses[k].dtype) + elif self._idx < self.capacity - 1: + self.obses[k][:, self._idx + 1] = v.clone().to(self.obses[k].device, dtype=self.obses[k].dtype) + self.next_obses[k][:, self._idx] = self.obses[k][:, self._idx + 1].clone() + else: + if self._idx == self.capacity - 1: + self.next_obses[:, self._idx] = obs.clone().to(self.obses.device, dtype=self.obses.dtype) + elif self._idx < self.capacity - 1: + self.obses[:, self._idx + 1] = obs.clone().to(self.obses.device, dtype=self.obses.dtype) + self.next_obses[:, self._idx] = self.obses[:, self._idx + 1].clone() + self.actions[:, self._idx] = action.detach().cpu() + self.rewards[:, self._idx] = reward.detach().cpu() + self.dones[:, self._idx] = done.detach().cpu() + self.masks[:, self._idx] = 1.0 - timeouts.detach().cpu().float() # TODO + self.cumulative_reward += reward.detach().cpu() + self.done = done.detach().cpu() + self.success = torch.logical_or(self.success, success.detach().cpu()) + self.successes[:, self._idx] = torch.tensor(self.success).to(self.device) + self._idx += 1 + +class ReplayBuffer: + 
""" + Storage and sampling functionality for training TD-MPC / TOLD. + The replay buffer is stored in GPU memory when training from state. + Uses prioritized experience replay by default.""" + + def __init__(self, cfg, dataset=None): + self.cfg = cfg + self.buffer_device = torch.device(cfg.buffer_device) + self.device = torch.device(cfg.device) + self.batch_size = self.cfg.batch_size + + print("Replay buffer device: ", self.buffer_device) + print("Replay buffer sample device: ", self.device) + + if dataset is not None: + self.capacity = max( + dataset["rewards"].shape[0], cfg.max_offline_buffer_size + ) + print("Offline dataset size: ", dataset["rewards"].shape[0]) + else: + self.capacity = max(cfg.train_steps, cfg.max_buffer_size) + + print("Maximum capacity of the buffer is: ", self.capacity) + + if cfg.modality in {"pixels", "state"}: + dtype = torch.float32 if cfg.modality == "state" else torch.uint8 + # Note self.obs_shape always has single frame, which is different from cfg.obs_shape + self.obs_shape = ( + cfg.obs_shape if cfg.modality == "state" else (3, *cfg.obs_shape[-2:]) + ) + self._obs = torch.empty( + (self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device + ) + self._next_obs = torch.empty( + (self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device + ) + elif cfg.modality == "all": + self.obs_shape = {} + self._obs, self._next_obs = {}, {} + for k, v in cfg.obs_shape.items(): + assert k in {"rgb", "state"} + dtype = torch.float32 if k == "state" else torch.uint8 + self.obs_shape[k] = v if k == "state" else (3, *v[-2:]) + self._obs[k] = torch.empty( + (self.capacity, *self.obs_shape[k]), + dtype=dtype, + device=self.buffer_device, + ) + self._next_obs[k] = self._obs[k].clone() + else: + raise ValueError + + self._action = torch.empty( + (self.capacity, cfg.action_dim), + dtype=torch.float32, + device=self.buffer_device, + ) + self._reward = torch.empty( + (self.capacity,), dtype=torch.float32, device=self.buffer_device + ) + self._mask = torch.empty( + (self.capacity,), dtype=torch.float32, device=self.buffer_device + ) + self._done = torch.empty( + (self.capacity,), dtype=torch.bool, device=self.buffer_device + ) + self._success = torch.empty( + (self.capacity,), dtype=torch.bool, device=self.buffer_device + ) + self._priorities = torch.ones( + (self.capacity,), dtype=torch.float32, device=self.buffer_device + ) + self.ep_len = int(self.cfg.max_episode_length // self.cfg.action_repeat) + self._eps = 1e-6 + self._full = False + self.idx = 0 + self._sampling_idx = 0 + if dataset is not None: + self.init_from_offline_dataset(dataset) + + self._aug = aug(cfg) + + def init_from_offline_dataset(self, dataset): + assert self.idx == 0 and not self._full + n_transitions = int(len(dataset["rewards"]) * self.cfg.data_first_percent) + + def copy_data(dst, src, n): + assert isinstance(dst, dict) == isinstance(src, dict) + if isinstance(dst, dict): + for k in dst: + copy_data(dst[k], src[k], n) + else: + dst[:n] = torch.from_numpy(src[:n]) + + copy_data(self._obs, dataset["observations"], n_transitions) + copy_data(self._next_obs, dataset["next_observations"], n_transitions) + copy_data(self._action, dataset["actions"], n_transitions) + if self.cfg.task.startswith("xarm"): + # success = self._calc_sparse_success(dataset['success']) + # copy_data(self._reward, success.astype(np.float32), n_transitions) + if self.cfg.sparse_reward: + copy_data( + self._reward, dataset["success"].astype(np.float32), n_transitions + ) + copy_data(self._success, 
dataset["success"], n_transitions) + else: + copy_data(self._reward, dataset["rewards"], n_transitions) + copy_data(self._mask, dataset["masks"], n_transitions) + copy_data(self._done, dataset["dones"], n_transitions) + self.idx = (self.idx + n_transitions) % self.capacity + self._full = n_transitions >= self.capacity + mask_idxs = np.array([n_transitions - i for i in range(1, self.cfg.horizon)]) + # _, episode_ends, _ = get_trajectory_boundaries_and_returns(dataset) + # mask_idxs = np.array([np.array(episode_ends) - i for i in range(1, self.cfg.horizon)]).T.flatten() + mask_idxs = np.clip(mask_idxs, 0, n_transitions - 1) + self._priorities[mask_idxs] = 0 + + def __add__(self, episode: Episode): + self.add(episode) + return self + + def add(self, episode: Episode): + idxs = torch.arange(self.idx, self.idx + self.cfg.num_envs * self.ep_len) % self.capacity + self._sampling_idx = (self.idx + self.cfg.num_envs * self.ep_len) % self.capacity + mask_copy = episode.masks.clone() + mask_copy[:, episode._idx - self.cfg.horizon:] = 0. + if self.cfg.modality in {"pixels", "state"}: + self._obs[idxs] = ( + episode.obses.flatten(0, 1) + if self.cfg.modality == "state" + else episode.obses[:, -3:].flatten(0, 1) + ) + self._next_obs[idxs] = ( + episode.next_obses.flatten(0, 1) + if self.cfg.modality == "state" + else episode.next_obses[:, -3:].flatten(0, 1) + ) + elif self.cfg.modality == "all": + for k, v in episode.obses.items(): + assert k in {"rgb", "state"} + assert k in self._obs + assert k in self._next_obs + if k == "rgb": + self._obs[k][idxs] = episode.obses[k][:-1, -3:].flatten(0, 1) + self._next_obs[k][idxs] = episode.obses[k][1:, -3:].flatten(0, 1) + else: + self._obs[k][idxs] = episode.obses[k][:-1].flatten(0, 1) + self._next_obs[k][idxs] = episode.obses[k][1:].flatten(0, 1) + self._action[idxs] = episode.actions.flatten(0, 1) + self._reward[idxs] = episode.rewards.flatten(0, 1) + self._mask[idxs] = mask_copy.flatten(0, 1) + self._done[idxs] = episode.dones.flatten(0, 1) + self._success[idxs] = episode.successes.flatten(0, 1) + if self._full: + max_priority = self._priorities.max().to(self.device).item() + else: + max_priority = ( + 1.0 + if self.idx == 0 + else self._priorities[: self.idx].max().to(self.device).item() + ) + mask = torch.arange(self.ep_len) > self.ep_len - self.cfg.horizon + mask = torch.cat([mask] * self.cfg.num_envs) + new_priorities = torch.full((self.ep_len * self.cfg.num_envs,), max_priority, device=self.buffer_device) + new_priorities[mask] = 0 + new_priorities = new_priorities * self._mask[idxs] + self._priorities[idxs] = new_priorities + self._full = self._full or (self.idx + self.ep_len * self.cfg.num_envs > self.capacity) + if episode.full: + self.idx = (self.idx + self.ep_len * self.cfg.num_envs) % self.capacity + + def _set_bs(self, bs): + self.batch_size = bs + + def update_priorities(self, idxs, priorities): + self._priorities[idxs.to(self.buffer_device)] = ( + priorities.squeeze(1).to(self.buffer_device) + self._eps + ) + + def _get_obs(self, arr, idxs, bs=None, frame_stack=None): + if isinstance(arr, dict): + return { + k: self._get_obs(v, idxs, bs=bs, frame_stack=frame_stack) + for k, v in arr.items() + } + if arr.ndim <= 2: # if self.cfg.modality == 'state': + return arr[idxs].cuda(self.device) + obs = torch.empty( + ( + self.cfg.batch_size if bs is None else bs, + 3 * self.cfg.frame_stack if frame_stack is None else 3 * frame_stack, + *arr.shape[-2:], + ), + dtype=arr.dtype, + device=torch.device(self.device), + ) + obs[:, -3:] = 
arr[idxs].cuda(self.device) + _idxs = idxs.clone() + mask = torch.ones_like(_idxs, dtype=torch.bool) + for i in range(1, self.cfg.frame_stack if frame_stack is None else frame_stack): + mask[_idxs % self.cfg.episode_length == 0] = False + _idxs[mask] -= 1 + obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda(self.device) + return obs.float() + + def sample(self, bs=None): + probs = ( + self._priorities if self._full else self._priorities[:self._sampling_idx] + ) ** self.cfg.per_alpha + probs /= probs.sum() + total = len(probs) + if torch.isnan(self._priorities).any(): + print(torch.isnan(self._priorities).any()) + print(torch.where(torch.isnan(self._priorities))) + idxs = torch.from_numpy( + np.random.choice( + total, + self.cfg.batch_size if bs is None else bs, + p=probs.cpu().numpy(), + replace=((not self._full) or (self.cfg.batch_size > self.capacity)), + ) + ).to(self.buffer_device) % self.capacity + weights = (total * probs[idxs]) ** (-self.cfg.per_beta) + weights /= weights.max() + + idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)]) % self.capacity + + obs = self._aug(self._get_obs(self._obs, idxs, bs=bs)) + next_obs = [ + self._aug(self._get_obs(self._next_obs, _idxs, bs=bs)) + for _idxs in idxs_in_horizon + ] + if isinstance(next_obs[0], dict): + next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]} + else: + next_obs = torch.stack(next_obs) + action = self._action[idxs_in_horizon] + reward = self._reward[idxs_in_horizon] + mask = self._mask[idxs_in_horizon] + done = torch.logical_not(self._done[idxs_in_horizon]).float() + + if not action.is_cuda: + action, reward, done, idxs, weights = ( + action.cuda(self.device), + reward.cuda(self.device), + done.cuda(self.device), + idxs.cuda(self.device), + weights.cuda(self.device), + ) + return ( + obs, + next_obs, + action, + reward.unsqueeze(2), + done.unsqueeze(2), + idxs, + weights, + ) \ No newline at end of file diff --git a/sim/tdmpc/src/algorithm/tdmpc.py b/sim/tdmpc/src/algorithm/tdmpc.py new file mode 100644 index 00000000..b77e6b8f --- /dev/null +++ b/sim/tdmpc/src/algorithm/tdmpc.py @@ -0,0 +1,248 @@ +import numpy as np +import torch +import torch.nn as nn +from copy import deepcopy +from dataclasses import asdict +import sim.tdmpc.src.algorithm.helper as h + + +class TOLD(nn.Module): + """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self._encoder = h.enc(cfg) + self._dynamics = h.mlp(cfg.latent_dim+cfg.action_dim, cfg.mlp_dim, cfg.latent_dim) + self._reward = h.mlp(cfg.latent_dim+cfg.action_dim, cfg.mlp_dim, 1) + self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, cfg.action_dim) + self._Qs = nn.ModuleList( + [h.q(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim) for _ in range(cfg.num_q)]) + self.apply(h.orthogonal_init) + for m in [self._reward, *self._Qs]: + m[-1].weight.data.fill_(0) + m[-1].bias.data.fill_(0) + + def track_q_grad(self, enable=True): + """Utility function. 
Enables/disables gradient tracking of Q-networks.""" + for m in self._Qs: + h.set_requires_grad(m, enable) + + def h(self, obs): + """Encodes an observation into its latent representation (h).""" + return self._encoder(obs) + + def next(self, z, a): + """Predicts next latent state (d) and single-step reward (R).""" + x = torch.cat([z, a], dim=-1) + return self._dynamics(x), self._reward(x) + + def pi(self, z, std=0): + """Samples an action from the learned policy (pi).""" + mu = torch.tanh(self._pi(z)) + if std > 0: + std = torch.ones_like(mu) * std + return h.TruncatedNormal(mu, std).sample(clip=0.3) + return mu + + def Q(self, z, a): + """Predict state-action value (Q).""" + x = torch.cat([z, a], dim=-1) + Qs = torch.stack([self._Qs[i](x) for i in range(self.cfg.num_q)], dim=0) + return Qs + + +class TDMPC(): + """Implementation of TD-MPC learning + inference.""" + def __init__(self, cfg): + self.cfg = cfg + self.device = torch.device(cfg.device) + self.std = h.linear_schedule(cfg.std_schedule, 0) + self.model = TOLD(cfg).to(self.device) + self.model_target = deepcopy(self.model) + self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) + self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=3 * self.cfg.lr) + self.aug = nn.Identity() + self.model.eval() + self.model_target.eval() + + def state_dict(self): + """Retrieve state dict of TOLD model, including slow-moving target network.""" + return { + 'model': self.model.state_dict(), + 'model_target': self.model_target.state_dict(), + 'config': asdict(self.cfg), + } + + def save(self, fp): + """Save state dict of TOLD model to filepath.""" + torch.save(self.save_dict(), fp) + + def load(self, fp): + """Load a saved state dict from filepath into current agent.""" + d = torch.load(fp) + self.model.load_state_dict(d['model']) + self.model_target.load_state_dict(d['model_target']) + + @torch.no_grad() + def estimate_value(self, z, actions, horizon): + """Estimate value of a trajectory starting at latent state z and executing given actions.""" + G, discount = 0, 1 + for t in range(horizon): + z, reward = self.model.next(z, actions[t]) + G += discount * reward + discount *= self.cfg.discount + G += discount * torch.min(self.model.Q(z, self.model.pi(z, self.cfg.min_std)), dim=0)[0] + return G + + @torch.no_grad() + def plan(self, obs, eval_mode=False, step=None, t0=True): + """ + Plan next action using TD-MPC inference. + obs: raw input observation. + eval_mode: uniform sampling and action noise is disabled during evaluation. + step: current time step. determines e.g. planning horizon. + t0: whether current step is the first step of an episode. 
+ """ + # Seed steps + if step < self.cfg.seed_steps and not eval_mode: + return torch.empty(self.cfg.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1) + + # Sample policy trajectories + obs = obs.clone().to(self.device, dtype=torch.float32).unsqueeze(1) + horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))) + num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples) + if num_pi_trajs > 0: + pi_actions = torch.empty(horizon, self.cfg.num_envs, num_pi_trajs, self.cfg.action_dim, device=self.device) + z = self.model.h(obs).repeat(1, num_pi_trajs, 1) + for t in range(horizon): + pi_actions[t] = self.model.pi(z, self.cfg.min_std) + z, _ = self.model.next(z, pi_actions[t]) + + # Initialize state and parameters + z = self.model.h(obs).repeat(1, self.cfg.num_samples + num_pi_trajs, 1) + mean = torch.zeros(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device) + std = 2 * torch.ones(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device) + + if isinstance(t0, bool) and t0 and hasattr(self, '_prev_mean') and self._prev_mean.shape[0] > 1: + _prev_h = self._prev_mean.shape[0] - 1 + mean[:_prev_h] = self._prev_mean[1:] + elif torch.is_tensor(t0) and t0.any() and hasattr(self, '_prev_mean') and self._prev_mean.shape[0] > 1: + _prev_h = self._prev_mean.shape[0] - 1 + mean[:_prev_h] = self._prev_mean[1:] + + # Iterate CEM + for i in range(self.cfg.iterations): + actions = torch.clamp(mean.unsqueeze(2) + std.unsqueeze(2) * \ + torch.randn(horizon, self.cfg.num_envs, self.cfg.num_samples, self.cfg.action_dim, device=std.device), -1, 1) + if num_pi_trajs > 0: + actions = torch.cat([actions, pi_actions], dim=-2) + + # Compute elite actions + value = self.estimate_value(z, actions, horizon).nan_to_num_(0) + elite_idxs = torch.topk(value.squeeze(-1), self.cfg.num_elites, dim=-1).indices + elite_value, elite_actions = value.squeeze(-1).gather(-1, elite_idxs), actions.gather(-2, elite_idxs.unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim)) + + # Update parameters + max_value = elite_value.max(1, keepdim=True)[0] + score = torch.exp(self.cfg.temperature*(elite_value - max_value)) + score /= score.sum(1, keepdim=True) + _mean = torch.sum(score.unsqueeze(0).unsqueeze(-1) * elite_actions, dim=-2) / (score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9) + _std = torch.sqrt(torch.sum(score.unsqueeze(0).unsqueeze(-1) * (elite_actions - _mean.unsqueeze(2)) ** 2, dim=-2) / (score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9)) + _std = _std.clamp_(self.std, 2) + mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std + + # Outputs + select_indices = torch.multinomial(score, 1) + actions = elite_actions.gather(-2, select_indices.unsqueeze(0).unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim)).squeeze(-2) + self._prev_mean = mean + mean, std = actions[0], _std[0] + a = mean + if not eval_mode: + a += std * torch.randn(self.cfg.action_dim, device=std.device) + return a + + def update_pi(self, zs): + """Update policy using a sequence of latent states.""" + self.pi_optim.zero_grad(set_to_none=True) + self.model.track_q_grad(False) + + # Loss is a weighted sum of Q-values + pi_loss = 0 + for t,z in enumerate(zs): + a = self.model.pi(z, self.cfg.min_std) + Q = torch.min(self.model.Q(z, a), dim=0)[0] + pi_loss += -Q.mean() * (self.cfg.rho ** t) + + pi_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False) + self.pi_optim.step() + 
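# Re-enable Q-network gradients now that the policy step is done; they were frozen at the top of
# update_pi() so the policy loss would not backpropagate into the critics.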
self.model.track_q_grad(True) + return pi_loss.item() + + @torch.no_grad() + def _td_target(self, next_obs, reward, mask=1.0): + """Compute the TD-target from a reward and the observation at the following time step.""" + next_z = self.model.h(next_obs) + td_target = reward + self.cfg.discount * mask * \ + torch.min(self.model_target.Q(next_z, self.model.pi(next_z, self.cfg.min_std)), dim=0)[0] + return td_target + + def update(self, replay_buffer, step): + """Main update function. Corresponds to one iteration of the TOLD model learning.""" + obs, next_obses, action, reward, mask, idxs, weights = replay_buffer.sample() + self.optim.zero_grad(set_to_none=True) + self.std = h.linear_schedule(self.cfg.std_schedule, step) + self.model.train() + + # Representation + z = self.model.h(self.aug(obs)) + zs = [z.detach()] + + loss_mask = torch.ones_like(mask[0], device=self.device) + + + consistency_loss, reward_loss, value_loss, priority_loss = 0, 0, 0, 0 + for t in range(self.cfg.horizon): + if t > 0: + loss_mask = loss_mask * mask[t - 1] + # Predictions + Qs = self.model.Q(z, action[t]) + z, reward_pred = self.model.next(z, action[t]) + with torch.no_grad(): + next_obs = self.aug(next_obses[t]) + next_z = self.model_target.h(next_obs) + td_target = self._td_target(next_obs, mask[t], reward[t]) + zs.append(z.detach()) + + # Losses + rho = (self.cfg.rho ** t) + consistency_loss += loss_mask[t] * rho * torch.mean(h.mse(z, next_z), dim=1, keepdim=True) + reward_loss += loss_mask[t] * rho * h.mse(reward_pred, reward[t]) + for i in range(self.cfg.num_q): + value_loss += loss_mask[t] * rho * h.mse(Qs[i], td_target) + priority_loss += loss_mask[t] * rho * h.l1(Qs[i], td_target) + + # Optimize model + total_loss = self.cfg.consistency_coef * consistency_loss.clamp(max=1e4) + \ + self.cfg.reward_coef * reward_loss.clamp(max=1e4) + \ + self.cfg.value_coef * value_loss.clamp(max=1e4) + weighted_loss = (total_loss.squeeze(1) * weights).mean() + weighted_loss.register_hook(lambda grad: grad * (1/self.cfg.horizon)) + weighted_loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False) + self.optim.step() + replay_buffer.update_priorities(idxs, priority_loss.clamp(max=1e4).detach()) + + if step % self.cfg.update_freq == 0: + # Update policy + target network + pi_loss = self.update_pi(zs) + h.ema(self.model, self.model_target, self.cfg.tau) + + self.model.eval() + return {'consistency_loss': float(consistency_loss.mean().item()), + 'reward_loss': float(reward_loss.mean().item()), + 'value_loss': float(value_loss.mean().item()), + 'pi_loss': pi_loss if step % self.cfg.update_freq == 0 else 0., + 'total_loss': float(total_loss.mean().item()), + 'weighted_loss': float(weighted_loss.mean().item()), + 'grad_norm': float(grad_norm)} diff --git a/sim/tdmpc/src/logger.py b/sim/tdmpc/src/logger.py new file mode 100644 index 00000000..e4111ebe --- /dev/null +++ b/sim/tdmpc/src/logger.py @@ -0,0 +1,173 @@ +import sys +import os +import datetime +import re +import cv2 +import numpy as np +import torch +import pandas as pd +from isaacgym import gymapi +from termcolor import colored + + +CONSOLE_FORMAT = [('episode', 'E', 'int'), ('step', 'S', 'int'), ('env_step', 'ES', 'int'), + ('episode_reward', 'R', 'float'), ('mean_episode_length', 'MEL', 'float'), ('total_time', 'T', 'time'),] + # ('consistency_loss', 'CL', 'float'), ('value_loss', 'PL', 'float'), ('pi_loss', 'PL', 'float'), ('total_loss', 'L', 'float')] +AGENT_METRICS = ['consistency_loss', 
'reward_loss', 'value_loss', 'total_loss', 'weighted_loss', 'pi_loss', 'grad_norm'] + + +def make_dir(dir_path): + """Create directory if it does not already exist.""" + try: + os.makedirs(dir_path) + except OSError: + pass + return dir_path + + +def print_run(cfg, reward=None): + """Pretty-printing of run information. Call at start of training.""" + prefix, color, attrs = ' ', 'green', ['bold'] + def limstr(s, maxlen=32): + return str(s[:maxlen]) + '...' if len(str(s)) > maxlen else s + def pprint(k, v): + print(prefix + colored(f'{k.capitalize()+":":<16}', color, attrs=attrs), limstr(v)) + kvs = [('task', cfg.task), + ('train steps', f'{int(cfg.train_steps*cfg.action_repeat):,}'), + ('observations', 'x'.join([str(s) for s in cfg.obs_shape])), + ('actions', cfg.action_dim), + ('experiment', cfg.exp_name)] + if reward is not None: + kvs.append(('episode reward', colored(str(int(reward)), 'white', attrs=['bold']))) + w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21 + div = '-'*w + print(div) + for k,v in kvs: + pprint(k, v) + print(div) + + +def cfg_to_group(cfg, return_list=False): + """Return a wandb-safe group name for logging. Optionally returns group name as list.""" + lst = [cfg.task, cfg.modality, re.sub('[^0-9a-zA-Z]+', '-', cfg.exp_name)] + return lst if return_list else '-'.join(lst) + + +class VideoRecorder: + """Utility class for logging evaluation videos.""" + def __init__(self, root_dir, wandb, render_size=384, fps=15): + self.save_dir = (root_dir / 'eval_video') if root_dir else None + self._wandb = wandb + self.render_size = render_size + self.fps = fps + self.frames = [] + self.enabled = False + + def init(self, env, h1, enabled=True): + self.frames = [] + self.enabled = self.save_dir and self._wandb and enabled + self.record(env, h1) + + def record(self, env, h1): + if self.enabled: + env.gym.fetch_results(env.sim, True) + env.gym.step_graphics(env.sim) + env.gym.render_all_camera_sensors(env.sim) + img = env.gym.get_camera_image(env.sim, env.envs[0], h1, gymapi.IMAGE_COLOR) + img = np.reshape(img, (120, 160, 4)) + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) + self.frames.append(img) + + def save(self, step): + if self.enabled: + frames = np.stack(self.frames).transpose(0, 3, 1, 2) + self._wandb.log({'eval_video': self._wandb.Video(frames, fps=self.fps, format='mp4')}, step=step) + + +class Logger(object): + """Primary logger object. Logs either locally or using wandb.""" + def __init__(self, log_dir, cfg): + project, entity = cfg.wandb_project, cfg.wandb_entity + run_offline = not cfg.use_wandb or project == 'none' or entity == 'none' + + self._save_model = cfg.save_model + if not run_offline or self._save_model: + self._log_dir = make_dir(log_dir) + + if self._save_model: + self._model_dir = make_dir(self._log_dir / 'models') + + self._group = cfg_to_group(cfg) + self._seed = cfg.seed + self._cfg = cfg + self._eval = [] + print_run(cfg) + if run_offline: + print(colored('Logs will be saved locally.', 'yellow', attrs=['bold'])) + self._wandb = None + else: + try: + os.environ["WANDB_SILENT"] = "true" + import wandb + wandb.init(project=project, + entity=entity, + name=str(cfg.seed), + group=self._group, + tags=cfg_to_group(cfg, return_list=True) + [f'seed:{cfg.seed}'], + dir=self._log_dir,) + print(colored('Logs will be synced with wandb.', 'blue', attrs=['bold'])) + self._wandb = wandb + except: + print(colored('Warning: failed to init wandb. 
Logs will be saved locally.', 'yellow', attrs=['bold'])) + self._wandb = None + self._video = VideoRecorder(log_dir, self._wandb) if (self._wandb and cfg.save_video) else None + + @property + def video(self): + return self._video + + def save(self, agent, model_name="model.pt"): + if self._save_model: + fp = self._model_dir / f"{model_name}" + torch.save(agent.state_dict(), fp) + + def finish(self, agent, model_name="model.pt"): + if self._save_model: + fp = self._model_dir / f"{model_name}" + torch.save(agent.state_dict(), fp) + # if self._wandb: + # artifact = self._wandb.Artifact(self._group+'-'+str(self._seed), type='model') + # artifact.add_file(fp) + # self._wandb.log_artifact(artifact) + if self._wandb: + self._wandb.finish() + print_run(self._cfg, self._eval[-1][-1]) + + def _format(self, key, value, ty): + if ty == 'int': + return f'{colored(key+":", "grey")} {int(value):,}' + elif ty == 'float': + return f'{colored(key+":", "grey")} {value:.03f}' + elif ty == 'time': + value = str(datetime.timedelta(seconds=int(value))) + return f'{colored(key+":", "grey")} {value}' + else: + raise f'invalid log format type: {ty}' + + def _print(self, d, category): + category = colored(category, 'blue' if category == 'train' else 'green') + pieces = [f' {category:<14}'] + for k, disp_k, ty in CONSOLE_FORMAT: + pieces.append(f'{self._format(disp_k, d.get(k, 0), ty):<26}') + print(' '.join(pieces)) + + def log(self, d, category='train'): + assert category in {'train', 'eval'} + if self._wandb is not None: + for k,v in d.items(): + self._wandb.log({category + '/' + k: v}, step=d['env_step']) + if category == 'eval': + keys = ['env_step', 'episode_reward'] + self._eval.append(np.array([d[keys[0]], d[keys[1]]])) + pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / 'eval.log', header=keys, index=None) + self._print(d, category) diff --git a/sim/train_tdmpc.py b/sim/train_tdmpc.py new file mode 100644 index 00000000..d8bab9e4 --- /dev/null +++ b/sim/train_tdmpc.py @@ -0,0 +1,213 @@ +"""Trains a humanoid to stand up.""" + +import argparse +import isaacgym +import torch +from sim.envs import task_registry +from sim.utils.helpers import get_args +from sim.tdmpc.src import logger +from sim.tdmpc.src.algorithm.helper import Episode, ReplayBuffer +from sim.tdmpc.src.algorithm.tdmpc import TDMPC +from dataclasses import dataclass, field, asdict +from datetime import datetime +from isaacgym import gymapi +from typing import List +import time +import numpy as np +from pathlib import Path +import random +torch.backends.cudnn.benchmark = True +__LOGS__ = "logs" + +@dataclass +class TDMPC_DoraConfigs: + seed: int = 42 + task : str = "walk" + exp_name : str = "dora" + device : str = "cuda:0" + num_envs : int = 10 + + lr : float = 1e-3 + modality : str = "state" + enc_dim: int = 512 + mlp_dim = [512, 256] + latent_dim: int = 100 + + iterations : int = 12 + num_samples : int = 512 + num_elites : int = 50 + mixture_coef : float = 0.05 + min_std : float = 0.1 + temperature : float = 0.5 + momentum : float = 0.1 + horizon : int = 5 + std_schedule: str = f"linear(0.5, {min_std}, 60000)" + horizon_schedule: str = f"linear(1, {horizon}, 15000)" + + batch_size: int = 8192 + max_buffer_size : int = int(5e6) + reward_coef : float = 1 + value_coef : float = 0.75 + consistency_coef : float = 2 + rho : float = 0.75 + kappa : float = 0.1 + per_alpha: float = 0.6 + per_beta : float = 0.4 + grad_clip_norm : float = 100 + seed_steps: int = 500 + update_freq: int = 3 + tau: int = 0.05 + + discount : float = 0.99 + buffer_device 
: str = "cpu" + train_steps : int = int(1e6) + num_q : int = 2 + + action_repeat : int = 2 + eval_freq: int = 15000 + eval_freq_episode : int = 10 + eval_episodes : int = 1 + + save_model : bool = True + save_video : bool = False + + use_wandb : bool = False + wandb_entity : str = "crajagopalan" + wandb_project : str = "xbot" + + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def evaluate(test_env, agent, h1, num_episodes, step, env_step, video, action_repeat=1): + """Evaluate a trained agent and optionally save a video.""" + episode_rewards = [] + for i in range(num_episodes): + obs, privileged_obs = test_env.reset() + critic_obs = privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0] + dones, ep_reward, t = torch.tensor([False]), 0, 0 + if video: video.init(test_env, h1, enabled=(i==0)) + while not dones[0].item(): + actions = agent.plan(state, eval_mode=True, step=step, t0=t==0) + for _ in range(action_repeat): + obs, privileged_obs, rewards, dones, infos = test_env.step(actions) + critic_obs = privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0] + ep_reward += rewards[0] + if video: video.record(test_env, h1) + t += 1 + episode_rewards.append(ep_reward) + if video: video.save(env_step) + return torch.nanmean(torch.tensor(episode_rewards)).item() + +def train(args: argparse.Namespace) -> None: + """Training script for TD-MPC. Requires a CUDA-enabled device.""" + assert torch.cuda.is_available() + env_cfg, train_cfg = task_registry.get_cfgs(name=args.task) + env, _ = task_registry.make_env(name=args.task, args=args) + + tdmpc_cfg = TDMPC_DoraConfigs() + + if tdmpc_cfg.save_video: + env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat) + + camera_properties = gymapi.CameraProperties() + camera_properties.width = 160 + camera_properties.height = 120 + h1 = env.gym.create_camera_sensor(env.envs[0], camera_properties) + camera_offset = gymapi.Vec3(3, -3, 1) + camera_rotation = gymapi.Quat.from_axis_angle(gymapi.Vec3(-0.3, 0.2, 1), np.deg2rad(135)) + actor_handle = env.gym.get_actor_handle(env.envs[0], 0) + body_handle = env.gym.get_actor_rigid_body_handle(env.envs[0], actor_handle, 0) + env.gym.attach_camera_to_body( + h1, env.envs[0], body_handle, gymapi.Transform(camera_offset, camera_rotation), gymapi.FOLLOW_POSITION + ) + + set_seed(tdmpc_cfg.seed) + now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + work_dir = Path().cwd() / __LOGS__ / f"{now}_{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}" + + obs, privileged_obs = env.reset() + critic_obs = privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs + + tdmpc_cfg.obs_shape = state.shape[1:] + tdmpc_cfg.action_shape = (env.num_actions) + tdmpc_cfg.action_dim = env.num_actions + episode_length = 100 + tdmpc_cfg.episode_length = episode_length + tdmpc_cfg.num_envs = env.num_envs + tdmpc_cfg.max_episode_length = int(env.max_episode_length) + + agent = TDMPC(tdmpc_cfg) + buffer = ReplayBuffer(tdmpc_cfg) + fp = None + + init_step = 0 + episode_idx, start_time = 0, time.time() + + if fp is not None: + agent.load(fp) + episode_idx = int(fp.split(".")[0].split("_")[-1]) + init_step = episode_idx * tdmpc_cfg.episode_length + episode = Episode(tdmpc_cfg, state) + + L = 
logger.Logger(work_dir, tdmpc_cfg) + + for step in range(init_step, tdmpc_cfg.train_steps + tdmpc_cfg.episode_length, tdmpc_cfg.episode_length): + if episode.full: + episode = Episode(tdmpc_cfg, state) + for i in range(tdmpc_cfg.episode_length): + actions = agent.plan(state, t0 = episode.first, eval_mode=False, step=step) + original_state = state.clone() + total_rewards, total_dones, total_timeouts = [], [], [] + for _ in range(tdmpc_cfg.action_repeat): + obs, privileged_obs, rewards, dones, infos = env.step(actions) + critic_obs = privileged_obs if privileged_obs is not None else obs + state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs + total_rewards.append(rewards) + total_dones.append(dones) + total_timeouts.append(infos["time_outs"]) + episode += (original_state, actions, torch.stack(total_rewards).sum(dim=0), torch.stack(total_dones).any(dim=0), torch.stack(total_timeouts).any(dim=0), torch.stack(total_dones).any(dim=0)) + buffer += episode + + # Update model + train_metrics = {} + if step >= tdmpc_cfg.seed_steps: + num_updates = tdmpc_cfg.seed_steps if step == tdmpc_cfg.seed_steps else tdmpc_cfg.episode_length + for i in range(num_updates): + train_metrics.update(agent.update(buffer, step+i)) + + # Log training episode + episode_idx += 1 + env_step = int((step + tdmpc_cfg.episode_length) * tdmpc_cfg.num_envs) + common_metrics = { + 'episode': episode_idx, + 'step': step + tdmpc_cfg.episode_length, + 'env_step': env_step, + 'total_time': time.time() - start_time, + 'episode_reward': episode.cumulative_reward.sum().item() / env.num_envs, + 'mean_episode_length' : episode.episode_length, + } + train_metrics.update(common_metrics) + L.log(train_metrics, category='train') + + + # Evaluate agent periodically + if tdmpc_cfg.save_model and episode_idx % tdmpc_cfg.eval_freq_episode == 0: + L.save(agent, f"tdmpc_policy_{int(step // tdmpc_cfg.episode_length) + 1}.pt") + # # common_metrics['episode_reward'] = evaluate(env, agent, h1 if L.video is not None else None, tdmpc_cfg.eval_episodes, step, env_step, L.video, tdmpc_cfg.action_repeat) + # # L.log(common_metrics, category='eval') + + + # print('Training completed successfully') + +if __name__ == "__main__": + # python -m sim.humanoid_gym.train + train(get_args()) \ No newline at end of file From 31ce27f3df98bb2fc1f2ccfb9e94c451ec9a5b35 Mon Sep 17 00:00:00 2001 From: chamorajg Date: Thu, 12 Sep 2024 17:35:15 -0700 Subject: [PATCH 2/4] update moved tdmpc up into the algo folder. 
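After this move the package lives under sim/algo, so callers import it as in the updated
train_tdmpc.py below:

    from sim.algo.tdmpc.src import logger
    from sim.algo.tdmpc.src.algorithm.helper import Episode, ReplayBuffer
    from sim.algo.tdmpc.src.algorithm.tdmpc import TDMPC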
--- sim/{ => algo}/tdmpc/src/algorithm/__init__.py | 0 sim/{ => algo}/tdmpc/src/algorithm/helper.py | 12 ++++++++++++ sim/{ => algo}/tdmpc/src/algorithm/tdmpc.py | 2 +- sim/{ => algo}/tdmpc/src/logger.py | 0 sim/train_tdmpc.py | 10 +++++----- 5 files changed, 18 insertions(+), 6 deletions(-) rename sim/{ => algo}/tdmpc/src/algorithm/__init__.py (100%) rename sim/{ => algo}/tdmpc/src/algorithm/helper.py (98%) rename sim/{ => algo}/tdmpc/src/algorithm/tdmpc.py (99%) rename sim/{ => algo}/tdmpc/src/logger.py (100%) diff --git a/sim/tdmpc/src/algorithm/__init__.py b/sim/algo/tdmpc/src/algorithm/__init__.py similarity index 100% rename from sim/tdmpc/src/algorithm/__init__.py rename to sim/algo/tdmpc/src/algorithm/__init__.py diff --git a/sim/tdmpc/src/algorithm/helper.py b/sim/algo/tdmpc/src/algorithm/helper.py similarity index 98% rename from sim/tdmpc/src/algorithm/helper.py rename to sim/algo/tdmpc/src/algorithm/helper.py index 87cc908a..4403d71d 100644 --- a/sim/tdmpc/src/algorithm/helper.py +++ b/sim/algo/tdmpc/src/algorithm/helper.py @@ -685,6 +685,18 @@ def _get_obs(self, arr, idxs, bs=None, frame_stack=None): _idxs[mask] -= 1 obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda(self.device) return obs.float() + + def save(self, buffer_fp): + data = { + "obs": self._obs, + "next_obs": self._next_obs, + "action": self._action, + "reward": self._reward, + "mask": self._mask, + "done": self._done, + "priorities": self._priorities, + } + torch.save(data, buffer_fp) def sample(self, bs=None): probs = ( diff --git a/sim/tdmpc/src/algorithm/tdmpc.py b/sim/algo/tdmpc/src/algorithm/tdmpc.py similarity index 99% rename from sim/tdmpc/src/algorithm/tdmpc.py rename to sim/algo/tdmpc/src/algorithm/tdmpc.py index b77e6b8f..b928eb0f 100644 --- a/sim/tdmpc/src/algorithm/tdmpc.py +++ b/sim/algo/tdmpc/src/algorithm/tdmpc.py @@ -3,7 +3,7 @@ import torch.nn as nn from copy import deepcopy from dataclasses import asdict -import sim.tdmpc.src.algorithm.helper as h +import sim.algo.tdmpc.src.algorithm.helper as h class TOLD(nn.Module): diff --git a/sim/tdmpc/src/logger.py b/sim/algo/tdmpc/src/logger.py similarity index 100% rename from sim/tdmpc/src/logger.py rename to sim/algo/tdmpc/src/logger.py diff --git a/sim/train_tdmpc.py b/sim/train_tdmpc.py index d8bab9e4..b9bce1f6 100644 --- a/sim/train_tdmpc.py +++ b/sim/train_tdmpc.py @@ -5,9 +5,9 @@ import torch from sim.envs import task_registry from sim.utils.helpers import get_args -from sim.tdmpc.src import logger -from sim.tdmpc.src.algorithm.helper import Episode, ReplayBuffer -from sim.tdmpc.src.algorithm.tdmpc import TDMPC +from sim.algo.tdmpc.src import logger +from sim.algo.tdmpc.src.algorithm.helper import Episode, ReplayBuffer +from sim.algo.tdmpc.src.algorithm.tdmpc import TDMPC from dataclasses import dataclass, field, asdict from datetime import datetime from isaacgym import gymapi @@ -202,11 +202,11 @@ def train(args: argparse.Namespace) -> None: # Evaluate agent periodically if tdmpc_cfg.save_model and episode_idx % tdmpc_cfg.eval_freq_episode == 0: L.save(agent, f"tdmpc_policy_{int(step // tdmpc_cfg.episode_length) + 1}.pt") + buffer.save(str(work_dir / "buffer.pt")) # # common_metrics['episode_reward'] = evaluate(env, agent, h1 if L.video is not None else None, tdmpc_cfg.eval_episodes, step, env_step, L.video, tdmpc_cfg.action_repeat) # # L.log(common_metrics, category='eval') - - # print('Training completed successfully') + print('Training completed successfully') if __name__ == "__main__": # python -m sim.humanoid_gym.train From 
96f0b9194664def2de1006eb9ee302e1a2c5226d Mon Sep 17 00:00:00 2001 From: chamorajg Date: Thu, 12 Sep 2024 17:48:21 -0700 Subject: [PATCH 3/4] update tdmpc fix play_tdmpc.py imports. --- sim/play_tdmpc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sim/play_tdmpc.py b/sim/play_tdmpc.py index b1939554..289d069e 100644 --- a/sim/play_tdmpc.py +++ b/sim/play_tdmpc.py @@ -7,9 +7,9 @@ import torch from sim.envs import task_registry from sim.utils.helpers import get_args -from sim.tdmpc.src import logger -from sim.tdmpc.src.algorithm.helper import Episode, ReplayBuffer -from sim.tdmpc.src.algorithm.tdmpc import TDMPC +from sim.algo.tdmpc.src import logger +from sim.algo.tdmpc.src.algorithm.helper import Episode, ReplayBuffer +from sim.algo.tdmpc.src.algorithm.tdmpc import TDMPC from dataclasses import dataclass, field from isaacgym import gymapi from typing import List @@ -145,7 +145,7 @@ def play(args: argparse.Namespace) -> None: env_cfg, _ = task_registry.get_cfgs(name=args.task) env, _ = task_registry.make_env(name=args.task, args=args) - fp = "/home/guest/sim/logs/2024-09-11_08-08-01_walk_state_dora/models/tdmpc_policy_2350.pt" + fp = "/home/guest/sim/logs/2024-09-12_17-42-05_walk_state_dora/models/tdmpc_policy_10.pt" config = torch.load(fp)["config"] tdmpc_cfg = TDMPC_DoraConfigs(**config) env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat) From f3468cf553b31dab315a02da8a711dc4be368815 Mon Sep 17 00:00:00 2001 From: chamorajg Date: Wed, 18 Sep 2024 13:47:48 -0700 Subject: [PATCH 4/4] feat: update tdmpc and buffer update. --- sim/algo/tdmpc/src/algorithm/helper.py | 223 +++--- sim/algo/tdmpc/src/algorithm/tdmpc.py | 510 ++++++------ sim/algo/tdmpc/src/logger.py | 308 ++++---- sim/play_tdmpc.py | 152 ++-- sim/resources/stompymini/__init__.py | 0 sim/resources/stompymini/robot.xml | 598 ++++++++++++++ sim/resources/stompymini/robot_fixed.xml | 961 +++++++++++++++++++++++ sim/train_tdmpc.py | 398 +++++----- 8 files changed, 2416 insertions(+), 734 deletions(-) create mode 100755 sim/resources/stompymini/__init__.py create mode 100755 sim/resources/stompymini/robot.xml create mode 100755 sim/resources/stompymini/robot_fixed.xml diff --git a/sim/algo/tdmpc/src/algorithm/helper.py b/sim/algo/tdmpc/src/algorithm/helper.py index 4403d71d..28d8093f 100644 --- a/sim/algo/tdmpc/src/algorithm/helper.py +++ b/sim/algo/tdmpc/src/algorithm/helper.py @@ -16,7 +16,6 @@ __REDUCE__ = lambda b: "mean" if b else "none" - def l1(pred, target, reduce=False): """Computes the L1-loss between predictions and targets.""" return F.l1_loss(pred, target, reduction=__REDUCE__(reduce)) @@ -30,9 +29,7 @@ def mse(pred, target, reduce=False): def bce(pred, target, logits=True, reduce=False): """Computes the BCE loss between predictions and targets.""" if logits: - return F.binary_cross_entropy_with_logits( - pred, target, reduction=__REDUCE__(reduce) - ) + return F.binary_cross_entropy_with_logits(pred, target, reduction=__REDUCE__(reduce)) return F.binary_cross_entropy(pred, target, reduction=__REDUCE__(reduce)) @@ -90,11 +87,7 @@ def mse_expectile(pred, target, expectile=0.7, reduce=False): def _get_out_shape(in_shape, layers): """Utility function. 
Returns the output shape of a network for a given input shape.""" x = torch.randn(*in_shape).unsqueeze(0) - return ( - (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x) - .squeeze(0) - .shape - ) + return (nn.Sequential(*layers) if isinstance(layers, list) else layers)(x).squeeze(0).shape def gaussian_logprob(eps, log_std): @@ -136,6 +129,7 @@ def set_requires_grad(net, value): for param in net.parameters(): param.requires_grad_(value) + def linear_schedule(schdl, step): """ Outputs values following a linear decay schedule. @@ -151,6 +145,7 @@ def linear_schedule(schdl, step): return (1.0 - mix) * init + mix * final raise NotImplementedError(schdl) + class TruncatedNormal(pyd.Normal): """Utility class implementing the truncated normal distribution.""" @@ -223,9 +218,7 @@ def enc(cfg): if cfg.modality == "pixels": return ConvExt(nn.Sequential(*pixels_enc_layers)) if cfg.modality in {"state", "all"}: - state_dim = ( - cfg.obs_shape[0] if cfg.modality == "state" else cfg.obs_shape["state"][0] - ) + state_dim = cfg.obs_shape[0] if cfg.modality == "state" else cfg.obs_shape["state"][0] state_enc_layers = [ nn.Linear(state_dim, cfg.enc_dim), nn.LayerNorm(cfg.enc_dim), @@ -259,7 +252,11 @@ def mlp(in_dim, mlp_dim, out_dim, act_fn=nn.ReLU(), layer_norm=True): mlp_dim = [mlp_dim, mlp_dim] layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn] for i in range(len(mlp_dim) - 1): - layers += [nn.Linear(mlp_dim[i], mlp_dim[i + 1]), nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), act_fn] + layers += [ + nn.Linear(mlp_dim[i], mlp_dim[i + 1]), + nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), + act_fn, + ] layers += [nn.Linear(mlp_dim[-1], out_dim)] return nn.Sequential(*layers) @@ -279,7 +276,11 @@ def q(in_dim, mlp_dim, act_fn=nn.ReLU(), layer_norm=True): mlp_dim = [mlp_dim, mlp_dim] layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn] for i in range(len(mlp_dim) - 1): - layers += [nn.Linear(mlp_dim[i], mlp_dim[i + 1]), nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), act_fn] + layers += [ + nn.Linear(mlp_dim[i], mlp_dim[i + 1]), + nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), + act_fn, + ] layers += [nn.Linear(mlp_dim[-1], 1)] return nn.Sequential(*layers) @@ -290,7 +291,11 @@ def v(in_dim, mlp_dim, act_fn=nn.ReLU(), layer_norm=True): mlp_dim = [mlp_dim, mlp_dim] layers = [nn.Linear(in_dim, mlp_dim[0]), nn.LayerNorm(mlp_dim[0]) if layer_norm else nn.Identity(), act_fn] for i in range(len(mlp_dim) - 1): - layers += [nn.Linear(mlp_dim[i], mlp_dim[i + 1]), nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), act_fn] + layers += [ + nn.Linear(mlp_dim[i], mlp_dim[i + 1]), + nn.LayerNorm(mlp_dim[i + 1]) if layer_norm else nn.Identity(), + act_fn, + ] layers += [[nn.Linear(mlp_dim[-1], 1)]] return nn.Sequential(*layers) @@ -335,7 +340,8 @@ def forward(self, x, key=None): return self.choices[key](x) return {k: self.choices[k](_x) for k, _x in x.items()} return self.choices(x) - + + class Episode(object): """Storage object for a single episode.""" @@ -376,18 +382,38 @@ def __init__(self, cfg, init_obs): device=self.device, ) self.rewards = torch.empty( - (cfg.num_envs, self.capacity,), dtype=torch.float32, device=self.device + ( + cfg.num_envs, + self.capacity, + ), + dtype=torch.float32, + device=self.device, ) self.dones = torch.empty( - (cfg.num_envs, self.capacity,), dtype=torch.bool, device=self.device + ( + 
cfg.num_envs, + self.capacity, + ), + dtype=torch.bool, + device=self.device, ) self.successes = torch.empty( - (cfg.num_envs, self.capacity,), dtype=torch.bool, device=self.device + ( + cfg.num_envs, + self.capacity, + ), + dtype=torch.bool, + device=self.device, ) self.masks = torch.zeros( - (cfg.num_envs, self.capacity,), dtype=torch.float32, device=self.device + ( + cfg.num_envs, + self.capacity, + ), + dtype=torch.float32, + device=self.device, ) - self.cumulative_reward = torch.tensor([0.] * cfg.num_envs) + self.cumulative_reward = torch.tensor([0.0] * cfg.num_envs) self.done = torch.tensor([False] * cfg.num_envs) self.success = torch.tensor([False] * cfg.num_envs) self._idx = 0 @@ -397,7 +423,7 @@ def __len__(self): @property def episode_length(self): - num_dones = self.dones[:, :self._idx].sum().item() + num_dones = self.dones[:, : self._idx].sum().item() if num_dones > 0: return float(self._idx) * self.cfg.num_envs / num_dones return float(self._idx) @@ -408,23 +434,15 @@ def from_trajectory(cls, cfg, obses, actions, rewards, dones=None, masks=None): if cfg.modality in {"pixels", "state"}: episode = cls(cfg, obses[0]) - episode.obses[1:] = torch.tensor( - obses[1:], dtype=episode.obses.dtype, device=episode.device - ) + episode.obses[1:] = torch.tensor(obses[1:], dtype=episode.obses.dtype, device=episode.device) elif cfg.modality == "all": episode = cls(cfg, {k: v[0] for k, v in obses.items()}) for k, v in obses.items(): - episode.obses[k][1:] = torch.tensor( - obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device - ) + episode.obses[k][1:] = torch.tensor(obses[k][1:], dtype=episode.obses[k].dtype, device=episode.device) else: raise NotImplementedError - episode.actions = torch.tensor( - actions, dtype=episode.actions.dtype, device=episode.device - ) - episode.rewards = torch.tensor( - rewards, dtype=episode.rewards.dtype, device=episode.device - ) + episode.actions = torch.tensor(actions, dtype=episode.actions.dtype, device=episode.device) + episode.rewards = torch.tensor(rewards, dtype=episode.rewards.dtype, device=episode.device) episode.dones = ( torch.tensor(dones, dtype=episode.dones.dtype, device=episode.device) if dones is not None @@ -480,6 +498,7 @@ def add(self, obs, action, reward, done, timeouts, success=False): self.successes[:, self._idx] = torch.tensor(self.success).to(self.device) self._idx += 1 + class ReplayBuffer: """ Storage and sampling functionality for training TD-MPC / TOLD. 
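# Side note (not part of the patch): the prioritized sampling in ReplayBuffer.sample()
# below reduces to standard PER weighting. A standalone sketch with hypothetical
# priorities, using the per_alpha / per_beta defaults from TDMPC_DoraConfigs:
import torch

per_alpha, per_beta = 0.6, 0.4
priorities = torch.tensor([1.0, 0.5, 2.0, 0.0])     # made-up per-transition priorities
probs = priorities ** per_alpha                     # priority exponent
probs /= probs.sum()                                # sampling distribution over transitions
idxs = torch.multinomial(probs, num_samples=2, replacement=True)
weights = (len(probs) * probs[idxs]) ** (-per_beta) # importance-sampling correction
weights /= weights.max()                            # normalize so weights lie in (0, 1]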
@@ -496,9 +515,7 @@ def __init__(self, cfg, dataset=None): print("Replay buffer sample device: ", self.device) if dataset is not None: - self.capacity = max( - dataset["rewards"].shape[0], cfg.max_offline_buffer_size - ) + self.capacity = max(dataset["rewards"].shape[0], cfg.max_offline_buffer_size) print("Offline dataset size: ", dataset["rewards"].shape[0]) else: self.capacity = max(cfg.train_steps, cfg.max_buffer_size) @@ -508,15 +525,9 @@ def __init__(self, cfg, dataset=None): if cfg.modality in {"pixels", "state"}: dtype = torch.float32 if cfg.modality == "state" else torch.uint8 # Note self.obs_shape always has single frame, which is different from cfg.obs_shape - self.obs_shape = ( - cfg.obs_shape if cfg.modality == "state" else (3, *cfg.obs_shape[-2:]) - ) - self._obs = torch.empty( - (self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device - ) - self._next_obs = torch.empty( - (self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device - ) + self.obs_shape = cfg.obs_shape if cfg.modality == "state" else (3, *cfg.obs_shape[-2:]) + self._obs = torch.empty((self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device) + self._next_obs = torch.empty((self.capacity, *self.obs_shape), dtype=dtype, device=self.buffer_device) elif cfg.modality == "all": self.obs_shape = {} self._obs, self._next_obs = {}, {} @@ -538,21 +549,11 @@ def __init__(self, cfg, dataset=None): dtype=torch.float32, device=self.buffer_device, ) - self._reward = torch.empty( - (self.capacity,), dtype=torch.float32, device=self.buffer_device - ) - self._mask = torch.empty( - (self.capacity,), dtype=torch.float32, device=self.buffer_device - ) - self._done = torch.empty( - (self.capacity,), dtype=torch.bool, device=self.buffer_device - ) - self._success = torch.empty( - (self.capacity,), dtype=torch.bool, device=self.buffer_device - ) - self._priorities = torch.ones( - (self.capacity,), dtype=torch.float32, device=self.buffer_device - ) + self._reward = torch.empty((self.capacity,), dtype=torch.float32, device=self.buffer_device) + self._mask = torch.empty((self.capacity,), dtype=torch.float32, device=self.buffer_device) + self._done = torch.empty((self.capacity,), dtype=torch.bool, device=self.buffer_device) + self._success = torch.empty((self.capacity,), dtype=torch.bool, device=self.buffer_device) + self._priorities = torch.ones((self.capacity,), dtype=torch.float32, device=self.buffer_device) self.ep_len = int(self.cfg.max_episode_length // self.cfg.action_repeat) self._eps = 1e-6 self._full = False @@ -582,9 +583,7 @@ def copy_data(dst, src, n): # success = self._calc_sparse_success(dataset['success']) # copy_data(self._reward, success.astype(np.float32), n_transitions) if self.cfg.sparse_reward: - copy_data( - self._reward, dataset["success"].astype(np.float32), n_transitions - ) + copy_data(self._reward, dataset["success"].astype(np.float32), n_transitions) copy_data(self._success, dataset["success"], n_transitions) else: copy_data(self._reward, dataset["rewards"], n_transitions) @@ -602,16 +601,14 @@ def __add__(self, episode: Episode): self.add(episode) return self - def add(self, episode: Episode): + def add(self, episode: Episode): idxs = torch.arange(self.idx, self.idx + self.cfg.num_envs * self.ep_len) % self.capacity self._sampling_idx = (self.idx + self.cfg.num_envs * self.ep_len) % self.capacity mask_copy = episode.masks.clone() - mask_copy[:, episode._idx - self.cfg.horizon:] = 0. 
+ mask_copy[:, episode._idx - self.cfg.horizon :] = 0.0 if self.cfg.modality in {"pixels", "state"}: self._obs[idxs] = ( - episode.obses.flatten(0, 1) - if self.cfg.modality == "state" - else episode.obses[:, -3:].flatten(0, 1) + episode.obses.flatten(0, 1) if self.cfg.modality == "state" else episode.obses[:, -3:].flatten(0, 1) ) self._next_obs[idxs] = ( episode.next_obses.flatten(0, 1) @@ -637,14 +634,10 @@ def add(self, episode: Episode): if self._full: max_priority = self._priorities.max().to(self.device).item() else: - max_priority = ( - 1.0 - if self.idx == 0 - else self._priorities[: self.idx].max().to(self.device).item() - ) + max_priority = 1.0 if self.idx == 0 else self._priorities[: self.idx].max().to(self.device).item() mask = torch.arange(self.ep_len) > self.ep_len - self.cfg.horizon mask = torch.cat([mask] * self.cfg.num_envs) - new_priorities = torch.full((self.ep_len * self.cfg.num_envs,), max_priority, device=self.buffer_device) + new_priorities = torch.full((self.ep_len * self.cfg.num_envs,), max_priority, device=self.buffer_device) new_priorities[mask] = 0 new_priorities = new_priorities * self._mask[idxs] self._priorities[idxs] = new_priorities @@ -656,16 +649,11 @@ def _set_bs(self, bs): self.batch_size = bs def update_priorities(self, idxs, priorities): - self._priorities[idxs.to(self.buffer_device)] = ( - priorities.squeeze(1).to(self.buffer_device) + self._eps - ) + self._priorities[idxs.to(self.buffer_device)] = priorities.squeeze(1).to(self.buffer_device) + self._eps def _get_obs(self, arr, idxs, bs=None, frame_stack=None): if isinstance(arr, dict): - return { - k: self._get_obs(v, idxs, bs=bs, frame_stack=frame_stack) - for k, v in arr.items() - } + return {k: self._get_obs(v, idxs, bs=bs, frame_stack=frame_stack) for k, v in arr.items()} if arr.ndim <= 2: # if self.cfg.modality == 'state': return arr[idxs].cuda(self.device) obs = torch.empty( @@ -685,46 +673,69 @@ def _get_obs(self, arr, idxs, bs=None, frame_stack=None): _idxs[mask] -= 1 obs[:, -(i + 1) * 3 : -i * 3] = arr[_idxs].cuda(self.device) return obs.float() - + def save(self, buffer_fp): data = { - "obs": self._obs, - "next_obs": self._next_obs, - "action": self._action, - "reward": self._reward, - "mask": self._mask, - "done": self._done, - "priorities": self._priorities, + "obs": self._obs.cpu(), + "next_obs": self._next_obs.cpu(), + "action": self._action.cpu(), + "reward": self._reward.cpu(), + "mask": self._mask.cpu(), + "done": self._done.cpu(), + "priorities": self._priorities.cpu(), } torch.save(data, buffer_fp) + def load(self, buffer_fp): + data = torch.load(buffer_fp) + n_transitions = data["obs"].shape[0] + + if n_transitions >= self.capacity: + self._obs = self._obs[-self.capacity :] + self._next_obs = data["next_obs"][-self.capacity :].to(self.buffer_device) + self._action = data["action"][-self.capacity :].to(self.buffer_device) + self._reward = data["reward"][-self.capacity :].to(self.buffer_device) + self._mask = data["mask"][-self.capacity :].to(self.buffer_device) + self._done = data["done"][-self.capacity :].to(self.buffer_device) + self._priorities = data["priorities"][-self.capacity :].to(self.buffer_device) + self.idx = 0 + self._full = True + else: + self._obs[:n_transitions] = self._obs + self._next_obs[:n_transitions] = data["next_obs"].to(self.buffer_device) + self._action[:n_transitions] = data["action"].to(self.buffer_device) + self._reward[:n_transitions] = data["reward"].to(self.buffer_device) + self._mask[:n_transitions] = data["mask"].to(self.buffer_device) + 
self._done[:n_transitions] = data["done"].to(self.buffer_device) + self._priorities[:n_transitions] = data["priorities"].to(self.buffer_device) + self.idx = n_transitions + self._full = False + def sample(self, bs=None): - probs = ( - self._priorities if self._full else self._priorities[:self._sampling_idx] - ) ** self.cfg.per_alpha + probs = (self._priorities if self._full else self._priorities[: self._sampling_idx]) ** self.cfg.per_alpha probs /= probs.sum() total = len(probs) if torch.isnan(self._priorities).any(): print(torch.isnan(self._priorities).any()) print(torch.where(torch.isnan(self._priorities))) - idxs = torch.from_numpy( - np.random.choice( - total, - self.cfg.batch_size if bs is None else bs, - p=probs.cpu().numpy(), - replace=((not self._full) or (self.cfg.batch_size > self.capacity)), - ) - ).to(self.buffer_device) % self.capacity + idxs = ( + torch.from_numpy( + np.random.choice( + total, + self.cfg.batch_size if bs is None else bs, + p=probs.cpu().numpy(), + replace=((not self._full) or (self.cfg.batch_size > self.capacity)), + ) + ).to(self.buffer_device) + % self.capacity + ) weights = (total * probs[idxs]) ** (-self.cfg.per_beta) weights /= weights.max() idxs_in_horizon = torch.stack([idxs + t for t in range(self.cfg.horizon)]) % self.capacity obs = self._aug(self._get_obs(self._obs, idxs, bs=bs)) - next_obs = [ - self._aug(self._get_obs(self._next_obs, _idxs, bs=bs)) - for _idxs in idxs_in_horizon - ] + next_obs = [self._aug(self._get_obs(self._next_obs, _idxs, bs=bs)) for _idxs in idxs_in_horizon] if isinstance(next_obs[0], dict): next_obs = {k: torch.stack([o[k] for o in next_obs]) for k in next_obs[0]} else: @@ -750,4 +761,4 @@ def sample(self, bs=None): done.unsqueeze(2), idxs, weights, - ) \ No newline at end of file + ) diff --git a/sim/algo/tdmpc/src/algorithm/tdmpc.py b/sim/algo/tdmpc/src/algorithm/tdmpc.py index b928eb0f..09028197 100644 --- a/sim/algo/tdmpc/src/algorithm/tdmpc.py +++ b/sim/algo/tdmpc/src/algorithm/tdmpc.py @@ -7,242 +7,274 @@ class TOLD(nn.Module): - """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" - def __init__(self, cfg): - super().__init__() - self.cfg = cfg - self._encoder = h.enc(cfg) - self._dynamics = h.mlp(cfg.latent_dim+cfg.action_dim, cfg.mlp_dim, cfg.latent_dim) - self._reward = h.mlp(cfg.latent_dim+cfg.action_dim, cfg.mlp_dim, 1) - self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, cfg.action_dim) - self._Qs = nn.ModuleList( - [h.q(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim) for _ in range(cfg.num_q)]) - self.apply(h.orthogonal_init) - for m in [self._reward, *self._Qs]: - m[-1].weight.data.fill_(0) - m[-1].bias.data.fill_(0) - - def track_q_grad(self, enable=True): - """Utility function. 
Enables/disables gradient tracking of Q-networks.""" - for m in self._Qs: - h.set_requires_grad(m, enable) - - def h(self, obs): - """Encodes an observation into its latent representation (h).""" - return self._encoder(obs) - - def next(self, z, a): - """Predicts next latent state (d) and single-step reward (R).""" - x = torch.cat([z, a], dim=-1) - return self._dynamics(x), self._reward(x) - - def pi(self, z, std=0): - """Samples an action from the learned policy (pi).""" - mu = torch.tanh(self._pi(z)) - if std > 0: - std = torch.ones_like(mu) * std - return h.TruncatedNormal(mu, std).sample(clip=0.3) - return mu - - def Q(self, z, a): - """Predict state-action value (Q).""" - x = torch.cat([z, a], dim=-1) - Qs = torch.stack([self._Qs[i](x) for i in range(self.cfg.num_q)], dim=0) - return Qs - - -class TDMPC(): - """Implementation of TD-MPC learning + inference.""" - def __init__(self, cfg): - self.cfg = cfg - self.device = torch.device(cfg.device) - self.std = h.linear_schedule(cfg.std_schedule, 0) - self.model = TOLD(cfg).to(self.device) - self.model_target = deepcopy(self.model) - self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) - self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=3 * self.cfg.lr) - self.aug = nn.Identity() - self.model.eval() - self.model_target.eval() - - def state_dict(self): - """Retrieve state dict of TOLD model, including slow-moving target network.""" - return { - 'model': self.model.state_dict(), - 'model_target': self.model_target.state_dict(), - 'config': asdict(self.cfg), - } - - def save(self, fp): - """Save state dict of TOLD model to filepath.""" - torch.save(self.save_dict(), fp) - - def load(self, fp): - """Load a saved state dict from filepath into current agent.""" - d = torch.load(fp) - self.model.load_state_dict(d['model']) - self.model_target.load_state_dict(d['model_target']) - - @torch.no_grad() - def estimate_value(self, z, actions, horizon): - """Estimate value of a trajectory starting at latent state z and executing given actions.""" - G, discount = 0, 1 - for t in range(horizon): - z, reward = self.model.next(z, actions[t]) - G += discount * reward - discount *= self.cfg.discount - G += discount * torch.min(self.model.Q(z, self.model.pi(z, self.cfg.min_std)), dim=0)[0] - return G - - @torch.no_grad() - def plan(self, obs, eval_mode=False, step=None, t0=True): - """ - Plan next action using TD-MPC inference. - obs: raw input observation. - eval_mode: uniform sampling and action noise is disabled during evaluation. - step: current time step. determines e.g. planning horizon. - t0: whether current step is the first step of an episode. 
- """ - # Seed steps - if step < self.cfg.seed_steps and not eval_mode: - return torch.empty(self.cfg.action_dim, dtype=torch.float32, device=self.device).uniform_(-1, 1) - - # Sample policy trajectories - obs = obs.clone().to(self.device, dtype=torch.float32).unsqueeze(1) - horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))) - num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples) - if num_pi_trajs > 0: - pi_actions = torch.empty(horizon, self.cfg.num_envs, num_pi_trajs, self.cfg.action_dim, device=self.device) - z = self.model.h(obs).repeat(1, num_pi_trajs, 1) - for t in range(horizon): - pi_actions[t] = self.model.pi(z, self.cfg.min_std) - z, _ = self.model.next(z, pi_actions[t]) - - # Initialize state and parameters - z = self.model.h(obs).repeat(1, self.cfg.num_samples + num_pi_trajs, 1) - mean = torch.zeros(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device) - std = 2 * torch.ones(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device) - - if isinstance(t0, bool) and t0 and hasattr(self, '_prev_mean') and self._prev_mean.shape[0] > 1: - _prev_h = self._prev_mean.shape[0] - 1 - mean[:_prev_h] = self._prev_mean[1:] - elif torch.is_tensor(t0) and t0.any() and hasattr(self, '_prev_mean') and self._prev_mean.shape[0] > 1: - _prev_h = self._prev_mean.shape[0] - 1 - mean[:_prev_h] = self._prev_mean[1:] - - # Iterate CEM - for i in range(self.cfg.iterations): - actions = torch.clamp(mean.unsqueeze(2) + std.unsqueeze(2) * \ - torch.randn(horizon, self.cfg.num_envs, self.cfg.num_samples, self.cfg.action_dim, device=std.device), -1, 1) - if num_pi_trajs > 0: - actions = torch.cat([actions, pi_actions], dim=-2) - - # Compute elite actions - value = self.estimate_value(z, actions, horizon).nan_to_num_(0) - elite_idxs = torch.topk(value.squeeze(-1), self.cfg.num_elites, dim=-1).indices - elite_value, elite_actions = value.squeeze(-1).gather(-1, elite_idxs), actions.gather(-2, elite_idxs.unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim)) - - # Update parameters - max_value = elite_value.max(1, keepdim=True)[0] - score = torch.exp(self.cfg.temperature*(elite_value - max_value)) - score /= score.sum(1, keepdim=True) - _mean = torch.sum(score.unsqueeze(0).unsqueeze(-1) * elite_actions, dim=-2) / (score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9) - _std = torch.sqrt(torch.sum(score.unsqueeze(0).unsqueeze(-1) * (elite_actions - _mean.unsqueeze(2)) ** 2, dim=-2) / (score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9)) - _std = _std.clamp_(self.std, 2) - mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std - - # Outputs - select_indices = torch.multinomial(score, 1) - actions = elite_actions.gather(-2, select_indices.unsqueeze(0).unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim)).squeeze(-2) - self._prev_mean = mean - mean, std = actions[0], _std[0] - a = mean - if not eval_mode: - a += std * torch.randn(self.cfg.action_dim, device=std.device) - return a - - def update_pi(self, zs): - """Update policy using a sequence of latent states.""" - self.pi_optim.zero_grad(set_to_none=True) - self.model.track_q_grad(False) - - # Loss is a weighted sum of Q-values - pi_loss = 0 - for t,z in enumerate(zs): - a = self.model.pi(z, self.cfg.min_std) - Q = torch.min(self.model.Q(z, a), dim=0)[0] - pi_loss += -Q.mean() * (self.cfg.rho ** t) - - pi_loss.backward() - torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False) - self.pi_optim.step() - 
self.model.track_q_grad(True) - return pi_loss.item() - - @torch.no_grad() - def _td_target(self, next_obs, reward, mask=1.0): - """Compute the TD-target from a reward and the observation at the following time step.""" - next_z = self.model.h(next_obs) - td_target = reward + self.cfg.discount * mask * \ - torch.min(self.model_target.Q(next_z, self.model.pi(next_z, self.cfg.min_std)), dim=0)[0] - return td_target - - def update(self, replay_buffer, step): - """Main update function. Corresponds to one iteration of the TOLD model learning.""" - obs, next_obses, action, reward, mask, idxs, weights = replay_buffer.sample() - self.optim.zero_grad(set_to_none=True) - self.std = h.linear_schedule(self.cfg.std_schedule, step) - self.model.train() - - # Representation - z = self.model.h(self.aug(obs)) - zs = [z.detach()] - - loss_mask = torch.ones_like(mask[0], device=self.device) - - - consistency_loss, reward_loss, value_loss, priority_loss = 0, 0, 0, 0 - for t in range(self.cfg.horizon): - if t > 0: - loss_mask = loss_mask * mask[t - 1] - # Predictions - Qs = self.model.Q(z, action[t]) - z, reward_pred = self.model.next(z, action[t]) - with torch.no_grad(): - next_obs = self.aug(next_obses[t]) - next_z = self.model_target.h(next_obs) - td_target = self._td_target(next_obs, mask[t], reward[t]) - zs.append(z.detach()) - - # Losses - rho = (self.cfg.rho ** t) - consistency_loss += loss_mask[t] * rho * torch.mean(h.mse(z, next_z), dim=1, keepdim=True) - reward_loss += loss_mask[t] * rho * h.mse(reward_pred, reward[t]) - for i in range(self.cfg.num_q): - value_loss += loss_mask[t] * rho * h.mse(Qs[i], td_target) - priority_loss += loss_mask[t] * rho * h.l1(Qs[i], td_target) - - # Optimize model - total_loss = self.cfg.consistency_coef * consistency_loss.clamp(max=1e4) + \ - self.cfg.reward_coef * reward_loss.clamp(max=1e4) + \ - self.cfg.value_coef * value_loss.clamp(max=1e4) - weighted_loss = (total_loss.squeeze(1) * weights).mean() - weighted_loss.register_hook(lambda grad: grad * (1/self.cfg.horizon)) - weighted_loss.backward() - grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False) - self.optim.step() - replay_buffer.update_priorities(idxs, priority_loss.clamp(max=1e4).detach()) - - if step % self.cfg.update_freq == 0: - # Update policy + target network - pi_loss = self.update_pi(zs) - h.ema(self.model, self.model_target, self.cfg.tau) - - self.model.eval() - return {'consistency_loss': float(consistency_loss.mean().item()), - 'reward_loss': float(reward_loss.mean().item()), - 'value_loss': float(value_loss.mean().item()), - 'pi_loss': pi_loss if step % self.cfg.update_freq == 0 else 0., - 'total_loss': float(total_loss.mean().item()), - 'weighted_loss': float(weighted_loss.mean().item()), - 'grad_norm': float(grad_norm)} + """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" + + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self._encoder = h.enc(cfg) + self._dynamics = h.mlp(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim, cfg.latent_dim) + self._reward = h.mlp(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim, 1) + self._pi = h.mlp(cfg.latent_dim, cfg.mlp_dim, cfg.action_dim) + self._Qs = nn.ModuleList([h.q(cfg.latent_dim + cfg.action_dim, cfg.mlp_dim) for _ in range(cfg.num_q)]) + self.apply(h.orthogonal_init) + for m in [self._reward, *self._Qs]: + m[-1].weight.data.fill_(0) + m[-1].bias.data.fill_(0) + + def track_q_grad(self, enable=True): + """Utility function. 
Enables/disables gradient tracking of Q-networks.""" + for m in self._Qs: + h.set_requires_grad(m, enable) + + def h(self, obs): + """Encodes an observation into its latent representation (h).""" + return self._encoder(obs) + + def next(self, z, a): + """Predicts next latent state (d) and single-step reward (R).""" + x = torch.cat([z, a], dim=-1) + return self._dynamics(x), self._reward(x) + + def pi(self, z, std=0): + """Samples an action from the learned policy (pi).""" + mu = self.cfg.max_clip_actions * torch.tanh(self._pi(z)) + if std > 0: + std = torch.ones_like(mu) * std + return h.TruncatedNormal(mu, std, low=-self.cfg.max_clip_actions, high=self.cfg.max_clip_actions).sample( + clip=0.3 + ) + return mu + + def Q(self, z, a): + """Predict state-action value (Q).""" + x = torch.cat([z, a], dim=-1) + Qs = torch.stack([self._Qs[i](x) for i in range(self.cfg.num_q)], dim=0) + return Qs + + +class TDMPC: + """Implementation of TD-MPC learning + inference.""" + + def __init__(self, cfg): + self.cfg = cfg + self.device = torch.device(cfg.device) + self.std = h.linear_schedule(cfg.std_schedule, 0) + self.model = TOLD(cfg).to(self.device) + self.model_target = deepcopy(self.model) + self.optim = torch.optim.Adam(self.model.parameters(), lr=self.cfg.lr) + self.pi_optim = torch.optim.Adam(self.model._pi.parameters(), lr=3 * self.cfg.lr) + self.aug = nn.Identity() + self.model.eval() + self.model_target.eval() + + def state_dict(self, step=0, env_step=0): + """Retrieve state dict of TOLD model, including slow-moving target network.""" + return { + "model": self.model.state_dict(), + "model_target": self.model_target.state_dict(), + "config": asdict(self.cfg), + "step": step, + "env_step": env_step, + } + + def save(self, fp): + """Save state dict of TOLD model to filepath.""" + torch.save(self.save_dict(), fp) + + def load(self, fp): + """Load a saved state dict from filepath into current agent.""" + d = torch.load(fp) + self.model.load_state_dict(d["model"]) + self.model_target.load_state_dict(d["model_target"]) + return d["step"], d["env_step"] + + @torch.no_grad() + def estimate_value(self, z, actions, horizon): + """Estimate value of a trajectory starting at latent state z and executing given actions.""" + G, discount = 0, 1 + for t in range(horizon): + z, reward = self.model.next(z, actions[t]) + G += discount * reward + discount *= self.cfg.discount + G += discount * torch.min(self.model.Q(z, self.model.pi(z, self.cfg.min_std)), dim=0)[0] + return G + + @torch.no_grad() + def plan(self, obs, eval_mode=False, step=None, t0=True): + """ + Plan next action using TD-MPC inference. + obs: raw input observation. + eval_mode: uniform sampling and action noise is disabled during evaluation. + step: current time step. determines e.g. planning horizon. + t0: whether current step is the first step of an episode. 
+ """ + clip_actions = h.linear_schedule(self.cfg.clip_actions, step=step) + # Seed steps + if step < self.cfg.seed_steps and not eval_mode: + return torch.empty(self.cfg.action_dim, dtype=torch.float32, device=self.device).uniform_( + -clip_actions, clip_actions + ) + + # Sample policy trajectories + obs = obs.clone().to(self.device, dtype=torch.float32).unsqueeze(1) + horizon = int(min(self.cfg.horizon, h.linear_schedule(self.cfg.horizon_schedule, step))) + num_pi_trajs = int(self.cfg.mixture_coef * self.cfg.num_samples) + if num_pi_trajs > 0: + pi_actions = torch.empty(horizon, self.cfg.num_envs, num_pi_trajs, self.cfg.action_dim, device=self.device) + z = self.model.h(obs).repeat(1, num_pi_trajs, 1) + for t in range(horizon): + pi_actions[t] = self.model.pi(z, self.cfg.min_std) + z, _ = self.model.next(z, pi_actions[t]) + + # Initialize state and parameters + z = self.model.h(obs).repeat(1, self.cfg.num_samples + num_pi_trajs, 1) + mean = torch.zeros(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device) + std = 1.5 * torch.ones(horizon, self.cfg.num_envs, self.cfg.action_dim, device=self.device) * clip_actions + + if isinstance(t0, bool) and t0 and hasattr(self, "_prev_mean") and self._prev_mean.shape[0] > 1: + _prev_h = self._prev_mean.shape[0] - 1 + mean[:_prev_h] = self._prev_mean[1:] + elif torch.is_tensor(t0) and t0.any() and hasattr(self, "_prev_mean") and self._prev_mean.shape[0] > 1: + _prev_h = self._prev_mean.shape[0] - 1 + mean[:_prev_h] = self._prev_mean[1:] + + # Iterate CEM + for i in range(self.cfg.iterations): + actions = torch.clamp( + mean.unsqueeze(2) + + std.unsqueeze(2) + * torch.randn(horizon, self.cfg.num_envs, self.cfg.num_samples, self.cfg.action_dim, device=std.device), + -clip_actions, + clip_actions, + ) + if num_pi_trajs > 0: + actions = torch.cat([actions, pi_actions], dim=-2) + + # Compute elite actions + value = self.estimate_value(z, actions, horizon).nan_to_num_(0) + elite_idxs = torch.topk(value.squeeze(-1), self.cfg.num_elites, dim=-1).indices + elite_value, elite_actions = value.squeeze(-1).gather(-1, elite_idxs), actions.gather( + -2, elite_idxs.unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim) + ) + + # Update parameters + max_value = elite_value.max(1, keepdim=True)[0] + score = torch.exp(self.cfg.temperature * (elite_value - max_value)) + score /= score.sum(1, keepdim=True) + _mean = torch.sum(score.unsqueeze(0).unsqueeze(-1) * elite_actions, dim=-2) / ( + score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9 + ) + _std = torch.sqrt( + torch.sum(score.unsqueeze(0).unsqueeze(-1) * (elite_actions - _mean.unsqueeze(2)) ** 2, dim=-2) + / (score.sum(-1, keepdim=True).unsqueeze(0) + 1e-9) + ) + _std = _std.clamp_(self.std, 1.5 * clip_actions) + mean, std = self.cfg.momentum * mean + (1 - self.cfg.momentum) * _mean, _std + + # Outputs + select_indices = torch.multinomial(score, 1) + actions = elite_actions.gather( + -2, select_indices.unsqueeze(0).unsqueeze(-1).repeat(horizon, 1, 1, self.cfg.action_dim) + ).squeeze(-2) + self._prev_mean = mean + mean, std = actions[0], _std[0] + a = mean + if not eval_mode: + a = h.TruncatedNormal(a, std, low=-clip_actions, high=clip_actions).sample() + return a + + def update_pi(self, zs): + """Update policy using a sequence of latent states.""" + self.pi_optim.zero_grad(set_to_none=True) + self.model.track_q_grad(False) + + # Loss is a weighted sum of Q-values + pi_loss = 0 + for t, z in enumerate(zs): + a = self.model.pi(z, self.cfg.min_std) + Q = torch.min(self.model.Q(z, a), dim=0)[0] + pi_loss += 
-Q.mean() * (self.cfg.rho**t) + + pi_loss.backward() + torch.nn.utils.clip_grad_norm_(self.model._pi.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False) + self.pi_optim.step() + self.model.track_q_grad(True) + return pi_loss.item() + + @torch.no_grad() + def _td_target(self, next_obs, reward, mask=1.0): + """Compute the TD-target from a reward and the observation at the following time step.""" + next_z = self.model.h(next_obs) + td_target = ( + reward + + self.cfg.discount + * mask + * torch.min(self.model_target.Q(next_z, self.model.pi(next_z, self.cfg.min_std)), dim=0)[0] + ) + return td_target + + def update(self, replay_buffer, step): + """Main update function. Corresponds to one iteration of the TOLD model learning.""" + obs, next_obses, action, reward, mask, idxs, weights = replay_buffer.sample() + self.optim.zero_grad(set_to_none=True) + self.std = h.linear_schedule(self.cfg.std_schedule, step) + self.model.train() + + # Representation + z = self.model.h(self.aug(obs)) + zs = [z.detach()] + + loss_mask = torch.ones_like(mask[0], device=self.device) + + consistency_loss, reward_loss, value_loss, priority_loss = 0, 0, 0, 0 + for t in range(self.cfg.horizon): + if t > 0: + loss_mask = loss_mask * mask[t - 1] + # Predictions + Qs = self.model.Q(z, action[t]) + z, reward_pred = self.model.next(z, action[t]) + with torch.no_grad(): + next_obs = self.aug(next_obses[t]) + next_z = self.model_target.h(next_obs) + td_target = self._td_target(next_obs, mask[t], reward[t]) + zs.append(z.detach()) + + # Losses + rho = self.cfg.rho**t + consistency_loss += loss_mask[t] * rho * torch.mean(h.mse(z, next_z), dim=1, keepdim=True) + reward_loss += loss_mask[t] * rho * h.mse(reward_pred, reward[t]) + for i in range(self.cfg.num_q): + value_loss += loss_mask[t] * rho * h.mse(Qs[i], td_target) + priority_loss += loss_mask[t] * rho * h.l1(Qs[i], td_target) + + # Optimize model + total_loss = ( + self.cfg.consistency_coef * consistency_loss.clamp(max=1e4) + + self.cfg.reward_coef * reward_loss.clamp(max=1e4) + + self.cfg.value_coef * value_loss.clamp(max=1e4) + ) + weighted_loss = (total_loss.squeeze(1) * weights).mean() + weighted_loss.register_hook(lambda grad: grad * (1 / self.cfg.horizon)) + weighted_loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.cfg.grad_clip_norm, error_if_nonfinite=False + ) + self.optim.step() + replay_buffer.update_priorities(idxs, priority_loss.clamp(max=1e4).detach()) + + if step % self.cfg.update_freq == 0: + # Update policy + target network + pi_loss = self.update_pi(zs) + h.ema(self.model, self.model_target, self.cfg.tau) + + self.model.eval() + return { + "consistency_loss": float(consistency_loss.mean().item()), + "reward_loss": float(reward_loss.mean().item()), + "value_loss": float(value_loss.mean().item()), + "pi_loss": pi_loss if step % self.cfg.update_freq == 0 else 0.0, + "total_loss": float(total_loss.mean().item()), + "weighted_loss": float(weighted_loss.mean().item()), + "grad_norm": float(grad_norm), + } diff --git a/sim/algo/tdmpc/src/logger.py b/sim/algo/tdmpc/src/logger.py index e4111ebe..cdc6f9b7 100644 --- a/sim/algo/tdmpc/src/logger.py +++ b/sim/algo/tdmpc/src/logger.py @@ -10,164 +10,180 @@ from termcolor import colored -CONSOLE_FORMAT = [('episode', 'E', 'int'), ('step', 'S', 'int'), ('env_step', 'ES', 'int'), - ('episode_reward', 'R', 'float'), ('mean_episode_length', 'MEL', 'float'), ('total_time', 'T', 'time'),] - # ('consistency_loss', 'CL', 'float'), ('value_loss', 'PL', 'float'), ('pi_loss', 
'PL', 'float'), ('total_loss', 'L', 'float')] -AGENT_METRICS = ['consistency_loss', 'reward_loss', 'value_loss', 'total_loss', 'weighted_loss', 'pi_loss', 'grad_norm'] +CONSOLE_FORMAT = [ + ("episode", "E", "int"), + ("step", "S", "int"), + ("env_step", "ES", "int"), + ("episode_reward", "R", "float"), + ("mean_episode_length", "MEL", "float"), + ("total_time", "T", "time"), +] +# ('consistency_loss', 'CL', 'float'), ('value_loss', 'PL', 'float'), ('pi_loss', 'PL', 'float'), ('total_loss', 'L', 'float')] +AGENT_METRICS = ["consistency_loss", "reward_loss", "value_loss", "total_loss", "weighted_loss", "pi_loss", "grad_norm"] def make_dir(dir_path): - """Create directory if it does not already exist.""" - try: - os.makedirs(dir_path) - except OSError: - pass - return dir_path + """Create directory if it does not already exist.""" + try: + os.makedirs(dir_path) + except OSError: + pass + return dir_path def print_run(cfg, reward=None): - """Pretty-printing of run information. Call at start of training.""" - prefix, color, attrs = ' ', 'green', ['bold'] - def limstr(s, maxlen=32): - return str(s[:maxlen]) + '...' if len(str(s)) > maxlen else s - def pprint(k, v): - print(prefix + colored(f'{k.capitalize()+":":<16}', color, attrs=attrs), limstr(v)) - kvs = [('task', cfg.task), - ('train steps', f'{int(cfg.train_steps*cfg.action_repeat):,}'), - ('observations', 'x'.join([str(s) for s in cfg.obs_shape])), - ('actions', cfg.action_dim), - ('experiment', cfg.exp_name)] - if reward is not None: - kvs.append(('episode reward', colored(str(int(reward)), 'white', attrs=['bold']))) - w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21 - div = '-'*w - print(div) - for k,v in kvs: - pprint(k, v) - print(div) + """Pretty-printing of run information. Call at start of training.""" + prefix, color, attrs = " ", "green", ["bold"] + + def limstr(s, maxlen=32): + return str(s[:maxlen]) + "..." if len(str(s)) > maxlen else s + + def pprint(k, v): + print(prefix + colored(f'{k.capitalize()+":":<16}', color, attrs=attrs), limstr(v)) + + kvs = [ + ("task", cfg.task), + ("train steps", f"{int(cfg.train_steps*cfg.action_repeat):,}"), + ("observations", "x".join([str(s) for s in cfg.obs_shape])), + ("actions", cfg.action_dim), + ("experiment", cfg.exp_name), + ] + if reward is not None: + kvs.append(("episode reward", colored(str(int(reward)), "white", attrs=["bold"]))) + w = np.max([len(limstr(str(kv[1]))) for kv in kvs]) + 21 + div = "-" * w + print(div) + for k, v in kvs: + pprint(k, v) + print(div) def cfg_to_group(cfg, return_list=False): - """Return a wandb-safe group name for logging. Optionally returns group name as list.""" - lst = [cfg.task, cfg.modality, re.sub('[^0-9a-zA-Z]+', '-', cfg.exp_name)] - return lst if return_list else '-'.join(lst) + """Return a wandb-safe group name for logging. 
Optionally returns group name as list.""" + lst = [cfg.task, cfg.modality, re.sub("[^0-9a-zA-Z]+", "-", cfg.exp_name)] + return lst if return_list else "-".join(lst) class VideoRecorder: - """Utility class for logging evaluation videos.""" - def __init__(self, root_dir, wandb, render_size=384, fps=15): - self.save_dir = (root_dir / 'eval_video') if root_dir else None - self._wandb = wandb - self.render_size = render_size - self.fps = fps - self.frames = [] - self.enabled = False - - def init(self, env, h1, enabled=True): - self.frames = [] - self.enabled = self.save_dir and self._wandb and enabled - self.record(env, h1) - - def record(self, env, h1): - if self.enabled: - env.gym.fetch_results(env.sim, True) - env.gym.step_graphics(env.sim) - env.gym.render_all_camera_sensors(env.sim) - img = env.gym.get_camera_image(env.sim, env.envs[0], h1, gymapi.IMAGE_COLOR) - img = np.reshape(img, (120, 160, 4)) - img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) - self.frames.append(img) - - def save(self, step): - if self.enabled: - frames = np.stack(self.frames).transpose(0, 3, 1, 2) - self._wandb.log({'eval_video': self._wandb.Video(frames, fps=self.fps, format='mp4')}, step=step) + """Utility class for logging evaluation videos.""" + + def __init__(self, root_dir, wandb, render_size=384, fps=15): + self.save_dir = (root_dir / "eval_video") if root_dir else None + self._wandb = wandb + self.render_size = render_size + self.fps = fps + self.frames = [] + self.enabled = False + + def init(self, env, h1, enabled=True): + self.frames = [] + self.enabled = self.save_dir and self._wandb and enabled + self.record(env, h1) + + def record(self, env, h1): + if self.enabled: + env.gym.fetch_results(env.sim, True) + env.gym.step_graphics(env.sim) + env.gym.render_all_camera_sensors(env.sim) + img = env.gym.get_camera_image(env.sim, env.envs[0], h1, gymapi.IMAGE_COLOR) + img = np.reshape(img, (120, 160, 4)) + img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) + self.frames.append(img) + + def save(self, step): + if self.enabled: + frames = np.stack(self.frames).transpose(0, 3, 1, 2) + self._wandb.log({"eval_video": self._wandb.Video(frames, fps=self.fps, format="mp4")}, step=step) class Logger(object): - """Primary logger object. Logs either locally or using wandb.""" - def __init__(self, log_dir, cfg): - project, entity = cfg.wandb_project, cfg.wandb_entity - run_offline = not cfg.use_wandb or project == 'none' or entity == 'none' - - self._save_model = cfg.save_model - if not run_offline or self._save_model: - self._log_dir = make_dir(log_dir) - - if self._save_model: - self._model_dir = make_dir(self._log_dir / 'models') - - self._group = cfg_to_group(cfg) - self._seed = cfg.seed - self._cfg = cfg - self._eval = [] - print_run(cfg) - if run_offline: - print(colored('Logs will be saved locally.', 'yellow', attrs=['bold'])) - self._wandb = None - else: - try: - os.environ["WANDB_SILENT"] = "true" - import wandb - wandb.init(project=project, - entity=entity, - name=str(cfg.seed), - group=self._group, - tags=cfg_to_group(cfg, return_list=True) + [f'seed:{cfg.seed}'], - dir=self._log_dir,) - print(colored('Logs will be synced with wandb.', 'blue', attrs=['bold'])) - self._wandb = wandb - except: - print(colored('Warning: failed to init wandb. 
Logs will be saved locally.', 'yellow', attrs=['bold'])) - self._wandb = None - self._video = VideoRecorder(log_dir, self._wandb) if (self._wandb and cfg.save_video) else None - - @property - def video(self): - return self._video - - def save(self, agent, model_name="model.pt"): - if self._save_model: - fp = self._model_dir / f"{model_name}" - torch.save(agent.state_dict(), fp) - - def finish(self, agent, model_name="model.pt"): - if self._save_model: - fp = self._model_dir / f"{model_name}" - torch.save(agent.state_dict(), fp) - # if self._wandb: - # artifact = self._wandb.Artifact(self._group+'-'+str(self._seed), type='model') - # artifact.add_file(fp) - # self._wandb.log_artifact(artifact) - if self._wandb: - self._wandb.finish() - print_run(self._cfg, self._eval[-1][-1]) - - def _format(self, key, value, ty): - if ty == 'int': - return f'{colored(key+":", "grey")} {int(value):,}' - elif ty == 'float': - return f'{colored(key+":", "grey")} {value:.03f}' - elif ty == 'time': - value = str(datetime.timedelta(seconds=int(value))) - return f'{colored(key+":", "grey")} {value}' - else: - raise f'invalid log format type: {ty}' - - def _print(self, d, category): - category = colored(category, 'blue' if category == 'train' else 'green') - pieces = [f' {category:<14}'] - for k, disp_k, ty in CONSOLE_FORMAT: - pieces.append(f'{self._format(disp_k, d.get(k, 0), ty):<26}') - print(' '.join(pieces)) - - def log(self, d, category='train'): - assert category in {'train', 'eval'} - if self._wandb is not None: - for k,v in d.items(): - self._wandb.log({category + '/' + k: v}, step=d['env_step']) - if category == 'eval': - keys = ['env_step', 'episode_reward'] - self._eval.append(np.array([d[keys[0]], d[keys[1]]])) - pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / 'eval.log', header=keys, index=None) - self._print(d, category) + """Primary logger object. Logs either locally or using wandb.""" + + def __init__(self, log_dir, cfg): + project, entity = cfg.wandb_project, cfg.wandb_entity + run_offline = not cfg.use_wandb or project == "none" or entity == "none" + + self._save_model = cfg.save_model + if not run_offline or self._save_model: + self._log_dir = make_dir(log_dir) + + if self._save_model: + self._model_dir = make_dir(self._log_dir / "models") + + self._group = cfg_to_group(cfg) + self._seed = cfg.seed + self._cfg = cfg + self._eval = [] + print_run(cfg) + if run_offline: + print(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + self._wandb = None + else: + try: + os.environ["WANDB_SILENT"] = "true" + import wandb + + wandb.init( + project=project, + entity=entity, + name=str(cfg.seed), + group=self._group, + tags=cfg_to_group(cfg, return_list=True) + [f"seed:{cfg.seed}"], + dir=self._log_dir, + ) + print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) + self._wandb = wandb + except: + print(colored("Warning: failed to init wandb. 
Logs will be saved locally.", "yellow", attrs=["bold"])) + self._wandb = None + self._video = VideoRecorder(log_dir, self._wandb) if (self._wandb and cfg.save_video) else None + + @property + def video(self): + return self._video + + def save(self, agent, model_name="model.pt", step=None, env_step=None): + if self._save_model: + fp = self._model_dir / f"{model_name}" + torch.save(agent.state_dict(step=step, env_step=env_step), fp) + + def finish(self, agent, model_name="model.pt"): + if self._save_model: + fp = self._model_dir / f"{model_name}" + torch.save(agent.state_dict(), fp) + # if self._wandb: + # artifact = self._wandb.Artifact(self._group+'-'+str(self._seed), type='model') + # artifact.add_file(fp) + # self._wandb.log_artifact(artifact) + if self._wandb: + self._wandb.finish() + print_run(self._cfg, self._eval[-1][-1]) + + def _format(self, key, value, ty): + if ty == "int": + return f'{colored(key+":", "grey")} {int(value):,}' + elif ty == "float": + return f'{colored(key+":", "grey")} {value:.03f}' + elif ty == "time": + value = str(datetime.timedelta(seconds=int(value))) + return f'{colored(key+":", "grey")} {value}' + else: + raise f"invalid log format type: {ty}" + + def _print(self, d, category): + category = colored(category, "blue" if category == "train" else "green") + pieces = [f" {category:<14}"] + for k, disp_k, ty in CONSOLE_FORMAT: + pieces.append(f"{self._format(disp_k, d.get(k, 0), ty):<26}") + print(" ".join(pieces)) + + def log(self, d, category="train"): + assert category in {"train", "eval"} + if self._wandb is not None: + for k, v in d.items(): + self._wandb.log({category + "/" + k: v}, step=d["env_step"]) + if category == "eval": + keys = ["env_step", "episode_reward"] + self._eval.append(np.array([d[keys[0]], d[keys[1]]])) + pd.DataFrame(np.array(self._eval)).to_csv(self._log_dir / "eval.log", header=keys, index=None) + self._print(d, category) diff --git a/sim/play_tdmpc.py b/sim/play_tdmpc.py index 289d069e..29670401 100644 --- a/sim/play_tdmpc.py +++ b/sim/play_tdmpc.py @@ -18,77 +18,91 @@ from pathlib import Path import random from datetime import datetime +from tqdm import tqdm + torch.backends.cudnn.benchmark = True __LOGS__ = "logs" + @dataclass class TDMPC_DoraConfigs: seed: int = 42 - task : str = "walk" - exp_name : str = "dora" - device : str = "cuda:0" - num_envs : int = 10 - - lr : float = 1e-3 - modality : str = "state" - enc_dim: int = 512 # 256 - mlp_dim = [512, 256] # [256, 256] + task: str = "walk" + exp_name: str = "dora" + device: str = "cuda:0" + num_envs: int = 10 + episode_length: int = 100 + max_episode_length: int = 1000 + max_clip_actions: float = 1.0 + clip_actions: str = f"linear(1, {max_clip_actions}, 100000)" + + lr: float = 1e-3 + modality: str = "state" + enc_dim: int = 512 # 256 + mlp_dim = [512, 256] # [256, 256] latent_dim: int = 100 - iterations : int = 12 - num_samples : int = 512 - num_elites : int = 50 - mixture_coef : float = 0.05 - min_std : float = 0.05 - temperature : float = 0.5 - momentum : float = 0.1 - horizon : int = 5 + iterations: int = 12 + num_samples: int = 512 + num_elites: int = 50 + mixture_coef: float = 0.05 + min_std: float = 0.05 + temperature: float = 0.5 + momentum: float = 0.1 + horizon: int = 5 std_schedule: str = f"linear(0.5, {min_std}, 3000)" horizon_schedule: str = f"linear(1, {horizon}, 2500)" batch_size: int = 1024 - max_buffer_size : int = 1000000 - reward_coef : float = 1 - value_coef : float = 0.5 - consistency_coef : float = 2 - rho : float = 0.5 - kappa : float = 0.1 + max_buffer_size: 
int = 1000000 + reward_coef: float = 1 + value_coef: float = 0.5 + consistency_coef: float = 2 + rho: float = 0.5 + kappa: float = 0.1 per_alpha: float = 0.6 - per_beta : float = 0.4 - grad_clip_norm : float = 10 + per_beta: float = 0.4 + grad_clip_norm: float = 10 seed_steps: int = 750 update_freq: int = 2 tau: int = 0.01 - discount : float = 0.99 - buffer_device : str = "cpu" - train_steps : int = int(1e6) - num_q : int = 3 + discount: float = 0.99 + buffer_device: str = "cpu" + train_steps: int = int(1e6) + num_q: int = 3 - action_repeat : int = 2 + action_repeat: int = 2 eval_freq: int = 15000 - eval_freq_episode : int = 10 - eval_episodes : int = 1 + eval_freq_episode: int = 10 + eval_episodes: int = 1 - save_model : bool = True - save_video : bool = False + save_model: bool = True + save_video: bool = False + eval_model: bool = False + save_buffer: bool = True + eval_freq_episode: int = 10 + eval_episodes: int = 1 + save_buffer_freq_episode: int = 50 + save_model_freq_episode: int = 10 - use_wandb : bool = False - wandb_entity : str = "crajagopalan" - wandb_project : str = "xbot" - + use_wandb: bool = False + wandb_entity: str = "crajagopalan" + wandb_project: str = "xbot" def set_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + class VideoRecorder: """Utility class for logging evaluation videos.""" + def __init__(self, root_dir, render_size=384, fps=25): - self.save_dir = (root_dir / 'eval_video') if root_dir else None + self.save_dir = (root_dir / "eval_video") if root_dir else None self.render_size = render_size self.fps = fps self.frames = [] @@ -114,38 +128,45 @@ def record(self, env, h1): img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGR) self.video.write(img) - def save(self,): - self.video.release() - + def save( + self, + ): + self.video.release() + + def evaluate(test_env, agent, h1, step, video, action_repeat=1): """Evaluate a trained agent and optionally save a video.""" episode_rewards = [] obs, privileged_obs = test_env.reset() - critic_obs = privileged_obs if privileged_obs is not None else obs + critic_obs = privileged_obs if privileged_obs is not None else obs state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs - dones, ep_reward, t = torch.tensor([False] * test_env.num_envs), torch.tensor([0.] 
* test_env.num_envs), 0 - if video: video.init(test_env, h1, enabled=True) - for i in range(int(1000 // action_repeat)): - actions = agent.plan(state, eval_mode=True, step=step, t0=t==0) + dones, ep_reward, t = torch.tensor([False] * test_env.num_envs), torch.tensor([0.0] * test_env.num_envs), 0 + if video: + video.init(test_env, h1, enabled=True) + for i in tqdm(range(int(1000 // action_repeat))): + actions = agent.plan(state, eval_mode=True, step=step, t0=t == 0) for _ in range(action_repeat): obs, privileged_obs, rewards, dones, infos = test_env.step(actions) critic_obs = privileged_obs if privileged_obs is not None else obs ep_reward += rewards.cpu() t += 1 - if video: video.record(test_env, h1) + if video: + video.record(test_env, h1) state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs episode_rewards.append(ep_reward) - if video: video.save() + if video: + video.save() print(f"Timestep : {t} Episode Rewards - {torch.cat(episode_rewards).mean().item()}") return torch.nanmean(torch.cat(episode_rewards)).item() + def play(args: argparse.Namespace) -> None: """Training script for TD-MPC. Requires a CUDA-enabled device.""" - assert torch.cuda.is_available() + assert torch.cuda.is_available() env_cfg, _ = task_registry.get_cfgs(name=args.task) - env, _ = task_registry.make_env(name=args.task, args=args) + env, _ = task_registry.make_env(name=args.task, args=args) - fp = "/home/guest/sim/logs/2024-09-12_17-42-05_walk_state_dora/models/tdmpc_policy_10.pt" + fp = "/home/guest/sim/logs/2024-09-17_00-45-33_walk_state_dora/models/tdmpc_policy_350.pt" config = torch.load(fp)["config"] tdmpc_cfg = TDMPC_DoraConfigs(**config) env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat) @@ -163,24 +184,26 @@ def play(args: argparse.Namespace) -> None: ) set_seed(tdmpc_cfg.seed) - work_dir = Path().cwd() / __LOGS__ / f"{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}_{str(tdmpc_cfg.seed)}" + work_dir = ( + Path().cwd() / __LOGS__ / f"{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}_{str(tdmpc_cfg.seed)}" + ) obs, privileged_obs = env.reset() - critic_obs = privileged_obs if privileged_obs is not None else obs + critic_obs = privileged_obs if privileged_obs is not None else obs state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0] tdmpc_cfg.obs_shape = [state.shape[0]] - tdmpc_cfg.action_shape = (env.num_actions) + tdmpc_cfg.action_shape = env.num_actions tdmpc_cfg.action_dim = env.num_actions - tdmpc_cfg.episode_length = 100 # int(env.max_episode_length // tdmpc_cfg.action_repeat) + tdmpc_cfg.episode_length = 100 # int(env.max_episode_length // tdmpc_cfg.action_repeat) tdmpc_cfg.num_envs = env.num_envs L = logger.Logger(work_dir, tdmpc_cfg) - log_dir = logger.make_dir(work_dir) + log_dir = logger.make_dir(work_dir) video = VideoRecorder(log_dir) agent = TDMPC(tdmpc_cfg) - + agent.load(fp) step = 0 episode_idx, start_time = 0, time.time() @@ -190,7 +213,8 @@ def play(args: argparse.Namespace) -> None: # Log training episode evaluate(env, agent, h1, step, video, tdmpc_cfg.action_repeat) - print('Testing completed successfully') + print("Testing completed successfully") + if __name__ == "__main__": - play(get_args()) \ No newline at end of file + play(get_args()) diff --git a/sim/resources/stompymini/__init__.py b/sim/resources/stompymini/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/sim/resources/stompymini/robot.xml b/sim/resources/stompymini/robot.xml new file mode 100755 index 
diff --git a/sim/resources/stompymini/__init__.py b/sim/resources/stompymini/__init__.py
new file mode 100755
index 00000000..e69de29b
diff --git a/sim/resources/stompymini/robot.xml b/sim/resources/stompymini/robot.xml
new file mode 100755
index 00000000..db5f3f38
--- /dev/null
+++ b/sim/resources/stompymini/robot.xml
@@ -0,0 +1,598 @@
[598 added lines of XML robot description; element content not preserved in this extract]
diff --git a/sim/resources/stompymini/robot_fixed.xml b/sim/resources/stompymini/robot_fixed.xml
new file mode 100755
index 00000000..aa4981d9
--- /dev/null
+++ b/sim/resources/stompymini/robot_fixed.xml
@@ -0,0 +1,961 @@
[961 added lines of XML robot description; element content not preserved in this extract]
\ No newline at end of file
diff --git a/sim/train_tdmpc.py b/sim/train_tdmpc.py
index b9bce1f6..d8391882 100644
--- a/sim/train_tdmpc.py
+++ b/sim/train_tdmpc.py
@@ -16,198 +16,238 @@
 import numpy as np
 from pathlib import Path
 import random
+
 torch.backends.cudnn.benchmark = True
 __LOGS__ = "logs"
+

 @dataclass
 class TDMPC_DoraConfigs:
-    seed: int = 42
-    task : str = "walk"
-    exp_name : str = "dora"
-    device : str = "cuda:0"
-    num_envs : int = 10
-
-    lr : float = 1e-3
-    modality : str = "state"
-    enc_dim: int = 512
-    mlp_dim = [512, 256]
-    latent_dim: int = 100
-
-    iterations : int = 12
-    num_samples : int = 512
-    num_elites : int = 50
-    mixture_coef : float = 0.05
-    min_std : float = 0.1
-    temperature : float = 0.5
-    momentum : float = 0.1
-    horizon : int = 5
-    std_schedule: str = f"linear(0.5, {min_std}, 60000)"
-    horizon_schedule: str = f"linear(1, {horizon}, 15000)"
-
-    batch_size: int = 8192
-    max_buffer_size : int = int(5e6)
-    reward_coef : float = 1
-    value_coef : float = 0.75
-    consistency_coef : float = 2
-    rho : float = 0.75
-    kappa : float = 0.1
-    per_alpha: float = 0.6
-    per_beta : float = 0.4
-    grad_clip_norm : float = 100
-    seed_steps: int = 500
-    update_freq: int = 3
-    tau: int = 0.05
-
-    discount : float = 0.99
-    buffer_device : str = "cpu"
-    train_steps : int = int(1e6)
-    num_q : int = 2
-
-    action_repeat : int = 2
-    eval_freq: int = 15000
-    eval_freq_episode : int = 10
-    eval_episodes : int = 1
-
-    save_model : bool = True
-    save_video : bool = False
-
-    use_wandb : bool = False
-    wandb_entity : str = "crajagopalan"
-    wandb_project : str = "xbot"
-
+    seed: int = 42
+    task: str = "walk"
+    exp_name: str = "dora"
+    device: str = "cuda:0"
+    num_envs: int = 10
+    max_clip_actions: float = 1.0
+    clip_actions: str = f"linear(1, {max_clip_actions}, 100000)"
+    episode_length: int = 100
+    max_episode_length: int = 1000
+
+    lr: float = 1e-3
+    modality: str = "state"
+    enc_dim: int = 512
+    mlp_dim = [512, 256]
+    latent_dim: int = 100
+
+    iterations: int = 12
+    num_samples: int = 512
+    num_elites: int = 50
+    mixture_coef: float = 0.05
+    min_std: float = 0.1
+    temperature: float = 0.5
+    momentum: float = 0.1
+    horizon: int = 5
+    std_schedule: str = f"linear(0.5, {min_std}, 200000)"
+    horizon_schedule: str = f"linear(1, {horizon}, 15000)"
+
+    batch_size: int = 8192
+    max_buffer_size: int = int(5e6)
+    reward_coef: float = 1
+    value_coef: float = 0.75
+    consistency_coef: float = 2
+    rho: float = 0.75
+    kappa: float = 0.1
+    per_alpha: float = 0.6
+    per_beta: float = 0.4
+    grad_clip_norm: float = 50
+    seed_steps: int = 500
+    update_freq: int = 3
+    tau: int = 0.05
+
+    discount: float = 0.99
+    buffer_device: str = "cpu"
+    train_steps: int = int(1e6)
+    num_q: int = 2
+
+    action_repeat: int = 2
+
+    save_model: bool = True
+    save_video: bool = False
+    save_buffer: bool = True
+    eval_model: bool = False
+    eval_freq_episode: int = 10
+    eval_episodes: int = 1
+    save_buffer_freq_episode: int = 50
+    save_model_freq_episode: int = 50
+
+    use_wandb: bool = False
+    wandb_entity: str = "crajagopalan"
+    wandb_project: str = "xbot"
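The std_schedule, horizon_schedule, and clip_actions strings above all follow a linear(init, final, duration) convention that is interpreted at training time; the actual parser lives in sim/tdmpc/src/algorithm/helper.py, which this patch does not touch. A minimal sketch of how such a string can be evaluated at a given step, under the assumption that it is a simple linear ramp clamped at its endpoints:

    # Hypothetical parser for "linear(init, final, duration)" schedule strings.
    import re

    def linear_schedule(schedule: str, step: int) -> float:
        match = re.match(r"linear\((.+),(.+),(.+)\)", schedule)
        if match is None:
            return float(schedule)  # a bare number is treated as a constant schedule
        init, final, duration = (float(g) for g in match.groups())
        mix = min(max(step / duration, 0.0), 1.0)  # progress clamped to [0, 1]
        return (1.0 - mix) * init + mix * final

    # Example: std_schedule = "linear(0.5, 0.1, 200000)" gives 0.5 at step 0,
    # 0.3 at step 100000, and stays at 0.1 from step 200000 onward.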
 def set_seed(seed):
-    random.seed(seed)
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    torch.cuda.manual_seed_all(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)


 def evaluate(test_env, agent, h1, num_episodes, step, env_step, video, action_repeat=1):
-    """Evaluate a trained agent and optionally save a video."""
-    episode_rewards = []
-    for i in range(num_episodes):
-        obs, privileged_obs = test_env.reset()
-        critic_obs = privileged_obs if privileged_obs is not None else obs
-        state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0]
-        dones, ep_reward, t = torch.tensor([False]), 0, 0
-        if video: video.init(test_env, h1, enabled=(i==0))
-        while not dones[0].item():
-            actions = agent.plan(state, eval_mode=True, step=step, t0=t==0)
-            for _ in range(action_repeat):
-                obs, privileged_obs, rewards, dones, infos = test_env.step(actions)
-                critic_obs = privileged_obs if privileged_obs is not None else obs
-                state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0]
-                ep_reward += rewards[0]
-            if video: video.record(test_env, h1)
-            t += 1
-        episode_rewards.append(ep_reward)
-        if video: video.save(env_step)
-    return torch.nanmean(torch.tensor(episode_rewards)).item()
+    """Evaluate a trained agent and optionally save a video."""
+    episode_rewards = []
+    for i in range(num_episodes):
+        obs, privileged_obs = test_env.reset()
+        critic_obs = privileged_obs if privileged_obs is not None else obs
+        state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0]
+        dones, ep_reward, t = torch.tensor([False]), 0, 0
+        if video:
+            video.init(test_env, h1, enabled=(i == 0))
+        while not dones[0].item():
+            actions = agent.plan(state, eval_mode=True, step=step, t0=t == 0)
+            for _ in range(action_repeat):
+                obs, privileged_obs, rewards, dones, infos = test_env.step(actions)
+                critic_obs = privileged_obs if privileged_obs is not None else obs
+                state = torch.cat([obs, critic_obs], dim=-1)[0] if privileged_obs is not None else obs[0]
+                ep_reward += rewards[0]
+            if video:
+                video.record(test_env, h1)
+            t += 1
+        episode_rewards.append(ep_reward)
+        if video:
+            video.save(env_step)
+    return torch.nanmean(torch.tensor(episode_rewards)).item()
+

 def train(args: argparse.Namespace) -> None:
-    """Training script for TD-MPC. Requires a CUDA-enabled device."""
-    assert torch.cuda.is_available()
-    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
-    env, _ = task_registry.make_env(name=args.task, args=args)
-
-    tdmpc_cfg = TDMPC_DoraConfigs()
-
-    if tdmpc_cfg.save_video:
-        env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat)
-
-        camera_properties = gymapi.CameraProperties()
-        camera_properties.width = 160
-        camera_properties.height = 120
-        h1 = env.gym.create_camera_sensor(env.envs[0], camera_properties)
-        camera_offset = gymapi.Vec3(3, -3, 1)
-        camera_rotation = gymapi.Quat.from_axis_angle(gymapi.Vec3(-0.3, 0.2, 1), np.deg2rad(135))
-        actor_handle = env.gym.get_actor_handle(env.envs[0], 0)
-        body_handle = env.gym.get_actor_rigid_body_handle(env.envs[0], actor_handle, 0)
-        env.gym.attach_camera_to_body(
-            h1, env.envs[0], body_handle, gymapi.Transform(camera_offset, camera_rotation), gymapi.FOLLOW_POSITION
-        )
-
-    set_seed(tdmpc_cfg.seed)
-    now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
-    work_dir = Path().cwd() / __LOGS__ / f"{now}_{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}"
-
-    obs, privileged_obs = env.reset()
-    critic_obs = privileged_obs if privileged_obs is not None else obs
-    state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs
-
-    tdmpc_cfg.obs_shape = state.shape[1:]
-    tdmpc_cfg.action_shape = (env.num_actions)
-    tdmpc_cfg.action_dim = env.num_actions
-    episode_length = 100
-    tdmpc_cfg.episode_length = episode_length
-    tdmpc_cfg.num_envs = env.num_envs
-    tdmpc_cfg.max_episode_length = int(env.max_episode_length)
-
-    agent = TDMPC(tdmpc_cfg)
-    buffer = ReplayBuffer(tdmpc_cfg)
-    fp = None
-
-    init_step = 0
-    episode_idx, start_time = 0, time.time()
-
-    if fp is not None:
-        agent.load(fp)
-        episode_idx = int(fp.split(".")[0].split("_")[-1])
-        init_step = episode_idx * tdmpc_cfg.episode_length
-    episode = Episode(tdmpc_cfg, state)
-
-    L = logger.Logger(work_dir, tdmpc_cfg)
-
-    for step in range(init_step, tdmpc_cfg.train_steps + tdmpc_cfg.episode_length, tdmpc_cfg.episode_length):
-        if episode.full:
-            episode = Episode(tdmpc_cfg, state)
-        for i in range(tdmpc_cfg.episode_length):
-            actions = agent.plan(state, t0 = episode.first, eval_mode=False, step=step)
-            original_state = state.clone()
-            total_rewards, total_dones, total_timeouts = [], [], []
-            for _ in range(tdmpc_cfg.action_repeat):
-                obs, privileged_obs, rewards, dones, infos = env.step(actions)
-                critic_obs = privileged_obs if privileged_obs is not None else obs
-                state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs
-                total_rewards.append(rewards)
-                total_dones.append(dones)
-                total_timeouts.append(infos["time_outs"])
-            episode += (original_state, actions, torch.stack(total_rewards).sum(dim=0), torch.stack(total_dones).any(dim=0), torch.stack(total_timeouts).any(dim=0), torch.stack(total_dones).any(dim=0))
-        buffer += episode
-
-        # Update model
-        train_metrics = {}
-        if step >= tdmpc_cfg.seed_steps:
-            num_updates = tdmpc_cfg.seed_steps if step == tdmpc_cfg.seed_steps else tdmpc_cfg.episode_length
-            for i in range(num_updates):
-                train_metrics.update(agent.update(buffer, step+i))
-
-        # Log training episode
-        episode_idx += 1
-        env_step = int((step + tdmpc_cfg.episode_length) * tdmpc_cfg.num_envs)
-        common_metrics = {
-            'episode': episode_idx,
-            'step': step + tdmpc_cfg.episode_length,
-            'env_step': env_step,
-            'total_time': time.time() - start_time,
-            'episode_reward': episode.cumulative_reward.sum().item() / env.num_envs,
-            'mean_episode_length' : episode.episode_length,
-        }
-        train_metrics.update(common_metrics)
-        L.log(train_metrics, category='train')
-
-
-        # Evaluate agent periodically
-        if tdmpc_cfg.save_model and episode_idx % tdmpc_cfg.eval_freq_episode == 0:
-            L.save(agent, f"tdmpc_policy_{int(step // tdmpc_cfg.episode_length) + 1}.pt")
-            buffer.save(str(work_dir / "buffer.pt"))
-        # # common_metrics['episode_reward'] = evaluate(env, agent, h1 if L.video is not None else None, tdmpc_cfg.eval_episodes, step, env_step, L.video, tdmpc_cfg.action_repeat)
-        # # L.log(common_metrics, category='eval')
-
-    print('Training completed successfully')
+    """Training script for TD-MPC. Requires a CUDA-enabled device."""
+    assert torch.cuda.is_available()
+    env_cfg, train_cfg = task_registry.get_cfgs(name=args.task)
+    env, _ = task_registry.make_env(name=args.task, args=args)
+
+    tdmpc_cfg = TDMPC_DoraConfigs()
+
+    if tdmpc_cfg.save_video:
+        env.set_camera(env_cfg.viewer.pos, env_cfg.viewer.lookat)
+
+        camera_properties = gymapi.CameraProperties()
+        camera_properties.width = 160
+        camera_properties.height = 120
+        h1 = env.gym.create_camera_sensor(env.envs[0], camera_properties)
+        camera_offset = gymapi.Vec3(3, -3, 1)
+        camera_rotation = gymapi.Quat.from_axis_angle(gymapi.Vec3(-0.3, 0.2, 1), np.deg2rad(135))
+        actor_handle = env.gym.get_actor_handle(env.envs[0], 0)
+        body_handle = env.gym.get_actor_rigid_body_handle(env.envs[0], actor_handle, 0)
+        env.gym.attach_camera_to_body(
+            h1, env.envs[0], body_handle, gymapi.Transform(camera_offset, camera_rotation), gymapi.FOLLOW_POSITION
+        )
+
+    set_seed(tdmpc_cfg.seed)
+    now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+    work_dir = Path().cwd() / __LOGS__ / f"{now}_{tdmpc_cfg.task}_{tdmpc_cfg.modality}_{tdmpc_cfg.exp_name}"
+
+    obs, privileged_obs = env.reset()
+    critic_obs = privileged_obs if privileged_obs is not None else obs
+    state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs
+
+    tdmpc_cfg.obs_shape = state.shape[1:]
+    tdmpc_cfg.action_shape = env.num_actions
+    tdmpc_cfg.action_dim = env.num_actions
+    episode_length = 100
+    tdmpc_cfg.episode_length = episode_length
+    tdmpc_cfg.num_envs = env.num_envs
+    tdmpc_cfg.max_episode_length = int(env.max_episode_length)
+    tdmpc_cfg.max_clip_actions = env.cfg.normalization.clip_actions
+    tdmpc_cfg.clip_actions = f"linear(1, {env.cfg.normalization.clip_actions}, 50000)"
+    print(tdmpc_cfg.clip_actions)
+
+    agent = TDMPC(tdmpc_cfg)
+    buffer = ReplayBuffer(tdmpc_cfg)
+    fp = None
+
+    init_step, env_step = 0, 0
+    episode_idx, start_time = 0, time.time()
+
+    if fp is not None:
+        init_step, env_step = agent.load(fp)
+        if init_step is None:
+            episode_idx = int(fp.split(".")[0].split("_")[-1])
+            init_step = episode_idx * tdmpc_cfg.episode_length
+            env_step = init_step * tdmpc_cfg.episode_length
+    episode = Episode(tdmpc_cfg, state)
+
+    L = logger.Logger(work_dir, tdmpc_cfg)
+
+    for step in range(init_step, tdmpc_cfg.train_steps + tdmpc_cfg.episode_length, tdmpc_cfg.episode_length):
+        if episode.full:
+            episode = Episode(tdmpc_cfg, state)
+        for i in range(tdmpc_cfg.episode_length):
+            actions = agent.plan(state, t0=episode.first, eval_mode=False, step=step)
+            original_state = state.clone()
+            total_rewards, total_dones, total_timeouts = [], [], []
+            for _ in range(tdmpc_cfg.action_repeat):
+                obs, privileged_obs, rewards, dones, infos = env.step(actions)
+                critic_obs = privileged_obs if privileged_obs is not None else obs
+                state = torch.cat([obs, critic_obs], dim=-1) if privileged_obs is not None else obs
+                total_rewards.append(rewards)
+                total_dones.append(dones)
+                total_timeouts.append(infos["time_outs"])
+            episode += (
+                original_state,
+                actions,
+                torch.stack(total_rewards).sum(dim=0),
+                torch.stack(total_dones).any(dim=0),
+                torch.stack(total_timeouts).any(dim=0),
+                torch.stack(total_dones).any(dim=0),
+            )
+        buffer += episode
+
+        # Update model
+        train_metrics = {}
+        if step >= tdmpc_cfg.seed_steps:
+            num_updates = tdmpc_cfg.seed_steps if step == tdmpc_cfg.seed_steps else tdmpc_cfg.episode_length
+            for i in range(num_updates):
+                train_metrics.update(agent.update(buffer, step + i))
+
+        # Log training episode
+        episode_idx += 1
+        env_step += int(tdmpc_cfg.episode_length * tdmpc_cfg.num_envs)
+        common_metrics = {
+            "episode": episode_idx,
+            "step": step + tdmpc_cfg.episode_length,
+            "env_step": env_step,
+            "total_time": time.time() - start_time,
+            "episode_reward": episode.cumulative_reward.sum().item() / env.num_envs,
+            "mean_episode_length": episode.episode_length,
+        }
+        train_metrics.update(common_metrics)
+        L.log(train_metrics, category="train")
+
+        # Evaluate agent periodically
+        if tdmpc_cfg.save_model and episode_idx % tdmpc_cfg.save_model_freq_episode == 0:
+            L.save(
+                agent,
+                f"tdmpc_policy_{int(step // tdmpc_cfg.episode_length) + 1}.pt",
+                step + tdmpc_cfg.episode_length,
+                env_step,
+            )
+        if tdmpc_cfg.save_buffer and episode_idx % tdmpc_cfg.save_buffer_freq_episode == 0:
+            buffer.save(str(work_dir / "buffer.pt"))
+        if tdmpc_cfg.eval_model and episode_idx % tdmpc_cfg.eval_freq_episode == 0:
+            common_metrics["episode_reward"] = evaluate(
+                env,
+                agent,
+                h1 if L.video is not None else None,
+                tdmpc_cfg.eval_episodes,
+                step,
+                env_step,
+                L.video,
+                tdmpc_cfg.action_repeat,
+            )
+            L.log(common_metrics, category="eval")
+
+    print("Training completed successfully")
+

 if __name__ == "__main__":
     # python -m sim.humanoid_gym.train
-    train(get_args())
\ No newline at end of file
+    train(get_args())
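The collection loop above leans on the Episode and ReplayBuffer containers imported from sim/tdmpc/src/algorithm/helper.py, which this patch does not modify. As a rough mental model only (a schematic stand-in with assumed tensor shapes, not the repository's actual class), an Episode behaves like a fixed-length, batched transition store:

    # Schematic stand-in for the Episode container; the real implementation in
    # sim/tdmpc/src/algorithm/helper.py will differ in detail.
    import torch

    class EpisodeSketch:
        def __init__(self, length: int, num_envs: int, obs_dim: int, act_dim: int, device: str = "cpu"):
            self.obs = torch.zeros(length + 1, num_envs, obs_dim, device=device)
            self.actions = torch.zeros(length, num_envs, act_dim, device=device)
            self.rewards = torch.zeros(length, num_envs, device=device)
            self.dones = torch.zeros(length, num_envs, dtype=torch.bool, device=device)
            self.episode_length = length
            self._idx = 0

        @property
        def first(self) -> bool:  # True until the first transition is stored
            return self._idx == 0

        @property
        def full(self) -> bool:  # True once `episode_length` transitions are stored
            return self._idx >= self.episode_length

        @property
        def cumulative_reward(self) -> torch.Tensor:  # per-env return so far
            return self.rewards[: self._idx].sum(dim=0)

        def __iadd__(self, transition):
            obs, action, reward, done, *_ = transition
            self.obs[self._idx] = obs
            self.actions[self._idx] = action
            self.rewards[self._idx] = reward
            self.dones[self._idx] = done
            self._idx += 1
            return self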